Update app.py
Browse files
app.py
CHANGED
|
@@ -267,8 +267,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 267 |
# Model repos
|
| 268 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 269 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 270 |
-
GEMMA_ABLITERATED_REPO = "Sikaworld1990/gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition-Ltx-2"
|
| 271 |
-
GEMMA_ABLITERATED_FILE = "gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition.safetensors"
|
| 272 |
|
| 273 |
# Download model checkpoints
|
| 274 |
print("=" * 80)
|
|
@@ -293,72 +291,7 @@ checkpoint_path = hf_hub_download(
|
|
| 293 |
local_dir_use_symlinks=False,
|
| 294 |
)
|
| 295 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
|
| 296 |
-
|
| 297 |
-
print("[Gemma] Setting up abliterated Gemma text encoder...")
|
| 298 |
-
MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
|
| 299 |
-
gemma_root = "/tmp/abliterated_gemma"
|
| 300 |
-
os.makedirs(gemma_root, exist_ok=True)
|
| 301 |
-
|
| 302 |
-
gemma_official_dir = snapshot_download(
|
| 303 |
-
repo_id=GEMMA_REPO,
|
| 304 |
-
ignore_patterns=["*.safetensors", "*.safetensors.index.json"],
|
| 305 |
-
)
|
| 306 |
-
|
| 307 |
-
for fname in os.listdir(gemma_official_dir):
|
| 308 |
-
src = os.path.join(gemma_official_dir, fname)
|
| 309 |
-
dst = os.path.join(gemma_root, fname)
|
| 310 |
-
if os.path.isfile(src) and not fname.endswith(".safetensors") and fname != "model.safetensors.index.json":
|
| 311 |
-
if not os.path.exists(dst):
|
| 312 |
-
os.symlink(src, dst)
|
| 313 |
-
|
| 314 |
-
if os.path.exists(MERGED_WEIGHTS):
|
| 315 |
-
print("[Gemma] Using cached merged weights")
|
| 316 |
-
else:
|
| 317 |
-
abliterated_weights_path = hf_hub_download(
|
| 318 |
-
repo_id=GEMMA_ABLITERATED_REPO,
|
| 319 |
-
filename=GEMMA_ABLITERATED_FILE,
|
| 320 |
-
)
|
| 321 |
-
index_path = hf_hub_download(
|
| 322 |
-
repo_id=GEMMA_REPO,
|
| 323 |
-
filename="model.safetensors.index.json"
|
| 324 |
-
)
|
| 325 |
-
with open(index_path) as f:
|
| 326 |
-
weight_index = json.load(f)
|
| 327 |
-
|
| 328 |
-
vision_keys = {}
|
| 329 |
-
for key, shard in weight_index["weight_map"].items():
|
| 330 |
-
if "vision_tower" in key or "multi_modal_projector" in key:
|
| 331 |
-
vision_keys[key] = shard
|
| 332 |
-
needed_shards = set(vision_keys.values())
|
| 333 |
-
|
| 334 |
-
shard_paths = {}
|
| 335 |
-
for shard_name in needed_shards:
|
| 336 |
-
shard_paths[shard_name] = hf_hub_download(
|
| 337 |
-
repo_id=GEMMA_REPO,
|
| 338 |
-
filename=shard_name
|
| 339 |
-
)
|
| 340 |
-
|
| 341 |
-
_fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
|
| 342 |
-
raw = load_file(abliterated_weights_path)
|
| 343 |
-
merged = {}
|
| 344 |
-
for key, tensor in raw.items():
|
| 345 |
-
t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
|
| 346 |
-
merged[f"language_model.{key}"] = t
|
| 347 |
-
del raw
|
| 348 |
-
|
| 349 |
-
for key, shard_name in vision_keys.items():
|
| 350 |
-
with safe_open(shard_paths[shard_name], framework="pt") as f:
|
| 351 |
-
merged[key] = f.get_tensor(key)
|
| 352 |
-
|
| 353 |
-
save_file(merged, MERGED_WEIGHTS)
|
| 354 |
-
del merged
|
| 355 |
-
gc.collect()
|
| 356 |
-
|
| 357 |
-
weight_link = os.path.join(gemma_root, "model.safetensors")
|
| 358 |
-
if os.path.exists(weight_link):
|
| 359 |
-
os.remove(weight_link)
|
| 360 |
-
os.symlink(MERGED_WEIGHTS, weight_link)
|
| 361 |
-
print(f"[Gemma] Root ready: {gemma_root}")
|
| 362 |
|
| 363 |
# ---- Insert block (LoRA downloads) between lines 268 and 269 ----
|
| 364 |
# LoRA repo + download the requested LoRA adapters
|
|
@@ -396,6 +329,7 @@ print(f"Transition LoRA: {transition_lora_path}")
|
|
| 396 |
|
| 397 |
print(f"Checkpoint: {checkpoint_path}")
|
| 398 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
|
|
|
| 399 |
|
| 400 |
# Initialize pipeline WITH text encoder and optional audio support
|
| 401 |
# ---- Replace block (pipeline init) lines 275-281 ----
|
|
|
|
| 267 |
# Model repos
|
| 268 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 269 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
|
|
|
|
|
|
| 270 |
|
| 271 |
# Download model checkpoints
|
| 272 |
print("=" * 80)
|
|
|
|
| 291 |
local_dir_use_symlinks=False,
|
| 292 |
)
|
| 293 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
|
| 294 |
+
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
# ---- Insert block (LoRA downloads) between lines 268 and 269 ----
|
| 297 |
# LoRA repo + download the requested LoRA adapters
|
|
|
|
| 329 |
|
| 330 |
print(f"Checkpoint: {checkpoint_path}")
|
| 331 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 332 |
+
print(f"Gemma root: {gemma_root}")
|
| 333 |
|
| 334 |
# Initialize pipeline WITH text encoder and optional audio support
|
| 335 |
# ---- Replace block (pipeline init) lines 275-281 ----
|