eeuuia commited on
Commit
7edcb31
·
verified ·
1 Parent(s): df14d1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -12
app.py CHANGED
@@ -15,6 +15,9 @@ import cv2
15
  import shutil
16
  import glob
17
  from pathlib import Path
 
 
 
18
 
19
  import warnings
20
  import logging
@@ -34,33 +37,54 @@ FPS = 24
34
  dtype = torch.bfloat16
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
 
 
 
37
  base_model_repo = "Lightricks/LTX-Video"
38
- print(f"Carregando a arquitetura completa da pipeline de {base_model_repo}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  pipeline = LTXConditionPipeline.from_pretrained(
40
  base_model_repo,
 
41
  torch_dtype=dtype,
42
  cache_dir=os.getenv("HF_HOME_CACHE"),
43
  token=os.getenv("HF_TOKEN"),
44
  )
45
 
46
- # 2. Definir a URL para o arquivo de pesos FP8 que contém apenas o TRANSFORMER.
47
- fp8_transformer_weights_url = "https://huggingface.co/Lightricks/LTX-Video/ltxv-13b-0.9.8-distilled-fp8.safetensors"
48
- print(f"Sobrescrevendo pesos do Transformer com o arquivo FP8 de: {fp8_transformer_weights_url}")
49
-
50
- pipeline.load_lora_weights(fp8_transformer_weights_url, from_diffusers=True)
51
-
52
  print("Carregando upsampler...")
53
  pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
54
  "Lightricks/ltxv-spatial-upscaler-0.9.7",
55
  cache_dir=os.getenv("HF_HOME_CACHE"),
56
- vae=pipeline.vae,
57
  torch_dtype=dtype
58
  )
59
 
60
- print("Movendo modelos para o dispositivo...")
61
- pipeline.to(device)
62
- pipe_upsample.to(device)
63
- pipeline.vae.enable_tiling()
 
 
 
 
 
 
 
 
64
 
65
  current_dir = Path(__file__).parent
66
 
 
15
  import shutil
16
  import glob
17
  from pathlib import Path
18
+ from diffusers.models.modeling_utils import AutoModel
19
+ from diffusers.models.group_offload import apply_group_offloading
20
+
21
 
22
  import warnings
23
  import logging
 
37
# -----------------------------------------------------------------------------
# Model setup: load the LTX-Video transformer with FP8 layerwise casting,
# build the condition pipeline around it, attach the spatial upsampler, and
# enable group offloading to reduce peak VRAM usage.
# -----------------------------------------------------------------------------

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Base repository hosting the full LTX-Video pipeline.
base_model_repo = "Lightricks/LTX-Video"

# 2. Load the transformer on its own so FP8 layerwise casting can be applied
#    before the pipeline wraps it.
#    NOTE(review): cache_dir/token are passed here too, matching the pipeline
#    load below — the original omitted them, which breaks gated-repo access
#    and splits the download cache.
print("Carregando Transformer para otimização FP8...")
transformer = AutoModel.from_pretrained(
    base_model_repo,
    subfolder="transformer",
    torch_dtype=dtype,
    cache_dir=os.getenv("HF_HOME_CACHE"),
    token=os.getenv("HF_TOKEN"),
)

# Store weights in FP8, compute in bf16 (effective only on FP8-capable hardware).
print("Habilitando layerwise casting para FP8...")
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=dtype
)

# 3. Build the full pipeline, injecting the already-optimized transformer.
print(f"Carregando a arquitetura da pipeline de {base_model_repo}...")
pipeline = LTXConditionPipeline.from_pretrained(
    base_model_repo,
    transformer=transformer,  # inject the FP8-cast transformer
    torch_dtype=dtype,
    cache_dir=os.getenv("HF_HOME_CACHE"),
    token=os.getenv("HF_TOKEN"),
)

# 4. The latent upsampler lives in its own repository; share the pipeline's
#    VAE so both stages decode with identical weights.
print("Carregando upsampler...")
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    "Lightricks/ltxv-spatial-upscaler-0.9.7",
    cache_dir=os.getenv("HF_HOME_CACHE"),
    vae=pipeline.vae,
    torch_dtype=dtype,
)

# 5. Group offloading trades PCIe traffic for VRAM. It is only meaningful —
#    and with use_stream=True only *safe* — when a CUDA device exists, so
#    guard it instead of unconditionally building torch.device("cuda").
#    NOTE(review): the import path used above for apply_group_offloading
#    (diffusers.models.group_offload) differs from the documented
#    `from diffusers.hooks import apply_group_offloading` — confirm it
#    resolves against the pinned diffusers version.
if torch.cuda.is_available():
    print("Aplicando otimizações de group-offloading para economizar VRAM...")
    onload_device = torch.device("cuda")
    offload_device = torch.device("cpu")
    # The transformer exposes a built-in helper for group offloading.
    pipeline.transformer.enable_group_offload(
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type="leaf_level",
        use_stream=True,
    )
    # Other components use the functional API; offload_device is passed
    # explicitly for consistency with the transformer call above.
    apply_group_offloading(
        pipeline.text_encoder,
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type="block_level",
        num_blocks_per_group=2,
    )
    apply_group_offloading(
        pipeline.vae,
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type="leaf_level",
    )
    # NOTE(review): pipe_upsample is neither offloaded nor moved to CUDA
    # here (same as the original); confirm where the upsampler is expected
    # to run at inference time.
else:
    # CPU-only host: offloading hooks would only add overhead; keep both
    # pipelines on the CPU fallback device computed above.
    pipeline.to(device)
    pipe_upsample.to(device)

current_dir = Path(__file__).parent
90