dagloop5 commited on
Commit
844ace6
·
verified ·
1 Parent(s): c43c959

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -68
app.py CHANGED
@@ -267,8 +267,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
267
  # Model repos
268
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
269
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
270
- GEMMA_ABLITERATED_REPO = "Sikaworld1990/gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition-Ltx-2"
271
- GEMMA_ABLITERATED_FILE = "gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition.safetensors"
272
 
273
  # Download model checkpoints
274
  print("=" * 80)
@@ -293,72 +291,7 @@ checkpoint_path = hf_hub_download(
293
  local_dir_use_symlinks=False,
294
  )
295
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
296
-
297
- print("[Gemma] Setting up abliterated Gemma text encoder...")
298
- MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
299
- gemma_root = "/tmp/abliterated_gemma"
300
- os.makedirs(gemma_root, exist_ok=True)
301
-
302
- gemma_official_dir = snapshot_download(
303
- repo_id=GEMMA_REPO,
304
- ignore_patterns=["*.safetensors", "*.safetensors.index.json"],
305
- )
306
-
307
- for fname in os.listdir(gemma_official_dir):
308
- src = os.path.join(gemma_official_dir, fname)
309
- dst = os.path.join(gemma_root, fname)
310
- if os.path.isfile(src) and not fname.endswith(".safetensors") and fname != "model.safetensors.index.json":
311
- if not os.path.exists(dst):
312
- os.symlink(src, dst)
313
-
314
- if os.path.exists(MERGED_WEIGHTS):
315
- print("[Gemma] Using cached merged weights")
316
- else:
317
- abliterated_weights_path = hf_hub_download(
318
- repo_id=GEMMA_ABLITERATED_REPO,
319
- filename=GEMMA_ABLITERATED_FILE,
320
- )
321
- index_path = hf_hub_download(
322
- repo_id=GEMMA_REPO,
323
- filename="model.safetensors.index.json"
324
- )
325
- with open(index_path) as f:
326
- weight_index = json.load(f)
327
-
328
- vision_keys = {}
329
- for key, shard in weight_index["weight_map"].items():
330
- if "vision_tower" in key or "multi_modal_projector" in key:
331
- vision_keys[key] = shard
332
- needed_shards = set(vision_keys.values())
333
-
334
- shard_paths = {}
335
- for shard_name in needed_shards:
336
- shard_paths[shard_name] = hf_hub_download(
337
- repo_id=GEMMA_REPO,
338
- filename=shard_name
339
- )
340
-
341
- _fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
342
- raw = load_file(abliterated_weights_path)
343
- merged = {}
344
- for key, tensor in raw.items():
345
- t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
346
- merged[f"language_model.{key}"] = t
347
- del raw
348
-
349
- for key, shard_name in vision_keys.items():
350
- with safe_open(shard_paths[shard_name], framework="pt") as f:
351
- merged[key] = f.get_tensor(key)
352
-
353
- save_file(merged, MERGED_WEIGHTS)
354
- del merged
355
- gc.collect()
356
-
357
- weight_link = os.path.join(gemma_root, "model.safetensors")
358
- if os.path.exists(weight_link):
359
- os.remove(weight_link)
360
- os.symlink(MERGED_WEIGHTS, weight_link)
361
- print(f"[Gemma] Root ready: {gemma_root}")
362
 
363
  # ---- Insert block (LoRA downloads) between lines 268 and 269 ----
364
  # LoRA repo + download the requested LoRA adapters
@@ -396,6 +329,7 @@ print(f"Transition LoRA: {transition_lora_path}")
396
 
397
  print(f"Checkpoint: {checkpoint_path}")
398
  print(f"Spatial upsampler: {spatial_upsampler_path}")
 
399
 
400
  # Initialize pipeline WITH text encoder and optional audio support
401
  # ---- Replace block (pipeline init) lines 275-281 ----
 
267
  # Model repos
268
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
269
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
 
 
270
 
271
  # Download model checkpoints
272
  print("=" * 80)
 
291
  local_dir_use_symlinks=False,
292
  )
293
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
294
+ gemma_root = snapshot_download(repo_id=GEMMA_REPO)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  # ---- Insert block (LoRA downloads) between lines 268 and 269 ----
297
  # LoRA repo + download the requested LoRA adapters
 
329
 
330
  print(f"Checkpoint: {checkpoint_path}")
331
  print(f"Spatial upsampler: {spatial_upsampler_path}")
332
+ print(f"Gemma root: {gemma_root}")
333
 
334
  # Initialize pipeline WITH text encoder and optional audio support
335
  # ---- Replace block (pipeline init) lines 275-281 ----