x2XcarleX2x commited on
Commit
3d73884
·
verified ·
1 Parent(s): 2ce5c5d

Update aduc_framework/managers/wan_manager.py

Browse files
Files changed (1) hide show
  1. aduc_framework/managers/wan_manager.py +119 -83
aduc_framework/managers/wan_manager.py CHANGED
@@ -1,138 +1,174 @@
1
- # aduc_framework/managers/vae_wan_manager.py (Versão Definitiva Completa)
2
  import torch
3
  import logging
4
  import yaml
5
  from PIL import Image
6
- import numpy as np
7
  from typing import List, Optional
8
  import sys
9
  import os
10
 
11
  # --- INÍCIO DA CORREÇÃO DE IMPORTAÇÃO ---
12
- # Adiciona o diretório do Wan2.2 ao sys.path para encontrar módulos customizados.
13
  WAN_REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'Wan2.2'))
14
  if WAN_REPO_PATH not in sys.path:
15
  sys.path.insert(0, WAN_REPO_PATH)
16
- logging.info(f"Adicionado '{WAN_REPO_PATH}' ao sys.path para o VaeWanManager.")
17
  # --- FIM DA CORREÇÃO DE IMPORTAÇÃO ---
18
 
19
- # Ferramentas da nossa arquitetura
20
  from ..tools.hardware_manager import hardware_manager
 
 
21
 
22
- # --- IMPORTAÇÕES CORRIGIDAS ---
23
- # Importa a classe de VAE customizada do Wan2.2 e a função auxiliar oficial.
24
- from diffusers.models import AutoencoderKLWan
25
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
 
 
26
 
27
- logger = logging.getLogger(__name__)
28
 
 
 
 
 
 
29
 
30
- class VaeWanManager:
31
  """
32
- Especialista VAE dedicado e persistente para a pipeline Wan2.2.
33
- Gerencia o ciclo de vida do AutoencoderKLWan em uma GPU dedicada,
34
- garantindo a tradução correta entre o espaço de pixels e o espaço latente.
35
  """
36
  def __init__(self):
37
  self.device = None
38
- self.vae: Optional[AutoencoderKLWan] = None
39
- self.dtype = None
40
  self.config = self._load_config()
41
-
42
  if self.config:
43
  gpus_required = self.config.get('gpus_required', 0)
44
  if gpus_required > 0:
45
- self.device = hardware_manager.allocate_gpus('VaeWanManager', gpus_required)[0]
46
- logger.info(f"VaeWanManager: GPU dedicada '{self.device}' reservada.")
47
  else:
48
- self.device = torch.device('cpu')
49
- logger.warning("VaeWanManager: Nenhuma GPU dedicada foi alocada.")
50
  else:
51
- logger.warning("Configuração para 'vae_wan' não encontrada em config.yaml.")
52
 
53
  def _load_config(self):
54
- """Carrega a configuração específica deste manager."""
55
  try:
56
- with open("config.yaml", 'r', encoding='utf-8') as f:
57
- return yaml.safe_load(f).get('specialists', {}).get('vae_wan', {})
58
  except FileNotFoundError:
59
  logger.error("config.yaml não encontrado.")
60
  return None
61
 
62
  def _lazy_init(self):
63
- """Carrega o modelo VAE do Wan2.2 para a VRAM no primeiro uso."""
64
- if self.vae is not None:
65
  return
66
  if not self.device or not self.config:
67
- raise RuntimeError("VaeWanManager não pode ser inicializado.")
68
 
69
- logger.info(f"VAE-WAN MANAGER ({self.device}): Carregando VAE do Wan2.2...")
70
 
71
- model_id = self.config.get("model_id")
72
- local_model_path = f"/app/models/{model_id}" # Carrega do nosso diretório local
73
- self.dtype = torch.float32 # VAEs são mais estáveis em FP32
 
 
 
 
 
 
 
74
 
75
  try:
76
- # Carrega a classe CORRETA (AutoencoderKLWan) do caminho LOCAL.
77
- self.vae = AutoencoderKLWan.from_pretrained(
78
- local_model_path, subfolder="vae", torch_dtype=self.dtype
79
- ).to(self.device)
80
- self.vae.eval()
81
- logger.info(f"VAE-WAN MANAGER ({self.device}): VAE do Wan2.2 pronto e 'quente' na VRAM.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  except Exception as e:
83
- logger.error(f"VAE-WAN MANAGER: Falha CRÍTICA ao carregar o VAE: {e}", exc_info=True)
84
- self.vae = None
85
  raise e
86
 
87
- def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor:
88
- """Converte uma imagem PIL para o formato de tensor 5D esperado pelo VAE de vídeo."""
89
- from PIL import ImageOps
90
- img = pil_image.convert("RGB")
91
- processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
92
- image_np = np.array(processed_img).astype(np.float32) / 255.0
93
-
94
- # Converte para (B, C, H, W)
95
- tensor_4d = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)
96
- tensor_4d_normalized = (tensor_4d * 2.0) - 1.0
97
-
98
- # Adiciona a dimensão de "frame" para criar um tensor 5D (B, C, F, H, W)
99
- tensor_5d = tensor_4d_normalized.unsqueeze(2)
100
-
101
- return tensor_5d
102
-
103
  @torch.no_grad()
104
- def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]:
105
- """Codifica um lote de imagens PIL para o espaço latente do Wan."""
106
  self._lazy_init()
107
- if not self.vae:
108
- raise RuntimeError("O VAE do WanManager não foi carregado.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- latents_list = []
111
- for img in pil_images:
112
- # A função de pré-processamento agora retorna o tensor 5D correto
113
- pixel_tensor_gpu = self._preprocess_pil_image(img, target_resolution).to(self.device, dtype=self.dtype)
114
-
115
- encoder_output = self.vae.encode(pixel_tensor_gpu)
116
- latents = retrieve_latents(encoder_output)
117
-
118
- latents_list.append(latents.cpu())
119
-
120
- return latents_list
121
 
122
- @torch.no_grad()
123
- def decode(self, latent_tensor: torch.Tensor) -> torch.Tensor:
124
- """Decodifica um tensor latente do Wan para o espaço de pixels."""
125
- self._lazy_init()
126
- if not self.vae:
127
- raise RuntimeError("O VAE do WanManager não foi carregado.")
128
 
129
- latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
 
 
 
 
 
130
 
131
- # Acessa a saída através do atributo .sample
132
- decode_output = self.vae.decode(latent_tensor_gpu)
133
- pixels = decode_output.sample
134
 
135
- return pixels.cpu()
 
 
 
 
 
 
 
136
 
137
  # --- Instância Singleton ---
138
- vae_wan_manager_singleton = VaeWanManager()
 
1
+ # aduc_framework/managers/wan_manager.py (Versão Definitiva Completa)
2
  import torch
3
  import logging
4
  import yaml
5
  from PIL import Image
 
6
  from typing import List, Optional
7
  import sys
8
  import os
9
 
10
  # --- INÍCIO DA CORREÇÃO DE IMPORTAÇÃO ---
11
+ # Adiciona o diretório do Wan2.2 ao sys.path para que o Python o encontre.
12
  WAN_REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'Wan2.2'))
13
  if WAN_REPO_PATH not in sys.path:
14
  sys.path.insert(0, WAN_REPO_PATH)
15
+ logging.info(f"Adicionado '{WAN_REPO_PATH}' ao sys.path para importações do WanManager.")
16
  # --- FIM DA CORREÇÃO DE IMPORTAÇÃO ---
17
 
18
+ # Ferramentas da nossa arquitetura ADUC
19
  from ..tools.hardware_manager import hardware_manager
20
+ from ..tools.pipeline_patches import apply_aduc_patches
21
+ from ..types import LatentConditioningItem
22
 
23
+ # Especialistas e modelos necessários
24
+ from .vae_wan_manager import vae_wan_manager_singleton
25
+ #from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ #from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
28
+ from transformers import CLIPVisionModel
29
 
 
30
 
31
+ from diffusers import WanImageToVideoPipeline
32
+ #from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.models import WanTransformer3DModel
34
+
35
+ logger = logging.getLogger(__name__)
36
 
37
+ class WanManager:
38
  """
39
+ Especialista ADUC completo e otimizado para geração de vídeo com Wan2.2.
40
+ Incorpora a fusão do LoRA Lightning para geração de alta velocidade (8-steps)
41
+ e patches customizados para controle temporal preciso.
42
  """
43
  def __init__(self):
44
  self.device = None
45
+ self.pipe: Optional[WanImageToVideoPipeline] = None
 
46
  self.config = self._load_config()
47
+
48
  if self.config:
49
  gpus_required = self.config.get('gpus_required', 0)
50
  if gpus_required > 0:
51
+ self.device = hardware_manager.allocate_gpus('WanManager', gpus_required)[0]
52
+ logger.info(f"WanManager (Lightning): GPU {self.device} reservada.")
53
  else:
54
+ logger.warning("WanManager está desabilitado (gpus_required: 0).")
 
55
  else:
56
+ logger.warning("Configuração para 'wan' não encontrada.")
57
 
58
  def _load_config(self):
59
+ """Carrega a configuração específica deste manager do arquivo YAML global."""
60
  try:
61
+ with open("config.yaml", 'r') as f:
62
+ return yaml.safe_load(f).get('specialists', {}).get('wan', {})
63
  except FileNotFoundError:
64
  logger.error("config.yaml não encontrado.")
65
  return None
66
 
67
  def _lazy_init(self):
68
+ """Carrega a pipeline, aplica otimizações, funde o LoRA e aplica nosso patch."""
69
+ if self.pipe is not None:
70
  return
71
  if not self.device or not self.config:
72
+ raise RuntimeError("WanManager não pode ser inicializado.")
73
 
74
+ logger.info(f"WAN MANAGER ({self.device}): Iniciando carregamento OTIMIZADO do Wan2.2...")
75
 
76
+ main_model_id = self.config.get("model_id")
77
+ opt_model_id = self.config.get("optimized_model_id")
78
+ lora_repo = self.config.get("lora_repo")
79
+ lora_filename = self.config.get("lora_filename")
80
+ torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
81
+
82
+ # Define os caminhos locais onde os modelos foram baixados
83
+ main_model_path = f"/app/models/{main_model_id}"
84
+ opt_model_path = f"/app/models/{opt_model_id}"
85
+ lora_path = f"/app/models/loras/{os.path.basename(lora_filename)}"
86
 
87
  try:
88
+ # 1. Requisição do VAE dedicado
89
+ vae_wan_manager_singleton._lazy_init()
90
+ vae = vae_wan_manager_singleton.vae
91
+ if vae is None: raise RuntimeError("Falha ao obter o VAE do vae_wan_manager_singleton.")
92
+
93
+ # 2. Carregamento dos componentes dos caminhos locais
94
+ image_encoder = CLIPVisionModel.from_pretrained(main_model_path, subfolder="image_encoder")
95
+ transformer = WanTransformer3DModel.from_pretrained(opt_model_path, subfolder='transformer', torch_dtype=torch_dtype)
96
+ transformer_2 = WanTransformer3DModel.from_pretrained(opt_model_path, subfolder='transformer_2', torch_dtype=torch_dtype)
97
+
98
+ # 3. Montagem da pipeline base
99
+ self.pipe = WanImageToVideoPipeline.from_pretrained(main_model_path, vae=vae, image_encoder=image_encoder, transformer=transformer, transformer_2=transformer_2, torch_dtype=torch_dtype)
100
+
101
+ # 4. Ajuste do Scheduler
102
+ self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.pipe.scheduler.config, shift=32.0)
103
+
104
+ # 5. Fusão do LoRA Lightning
105
+ logger.info(f"WAN MANAGER ({self.device}): Carregando e fundindo LoRA Lightning de '{lora_path}'...")
106
+ self.pipe.load_lora_weights(os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name="lightx2v")
107
+ self.pipe.load_lora_weights(os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name="lightx2v_2", load_into_transformer_2=True)
108
+ self.pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
109
+ self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
110
+ self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
111
+ self.pipe.unload_lora_weights()
112
+ logger.info(f"WAN MANAGER ({self.device}): LoRA Lightning fundido.")
113
+
114
+ # 6. Aplicação do nosso patch ADUC
115
+ apply_aduc_patches()
116
+
117
+ # 7. Finalização e envio para a GPU
118
+ self.pipe.to(self.device)
119
+ logger.info(f"WAN MANAGER ({self.device}): Pipeline Wan2.2 OTIMIZADA, MODIFICADA e pronta na VRAM.")
120
+
121
  except Exception as e:
122
+ logger.error(f"WAN MANAGER: Falha CRÍTICA ao carregar a pipeline: {e}", exc_info=True)
123
+ self.pipe = None
124
  raise e
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  @torch.no_grad()
127
+ def generate_latent_fragment(self, **kwargs) -> tuple[torch.Tensor, None]:
128
+ """Gera um fragmento de vídeo no espaço latente. A interface é IDÊNTICA à do LtxManager."""
129
  self._lazy_init()
130
+ if not self.pipe:
131
+ raise RuntimeError("A pipeline do WanManager não está disponível.")
132
+
133
+ conditioning_items: List[LatentConditioningItem] = kwargs.get("conditioning_items_data", [])
134
+ if not conditioning_items:
135
+ raise ValueError("WanManager no modo ADUC requer 'conditioning_items_data'.")
136
+
137
+ pipeline_params = {
138
+ "prompt": kwargs.get("motion_prompt", ""),
139
+ "negative_prompt": kwargs.get("negative_prompt", "static, disfigured, low quality"),
140
+ "height": kwargs.get("height", self.config.get("default_height", 480)),
141
+ "width": kwargs.get("width", self.config.get("default_width", 832)),
142
+ "num_frames": kwargs.get("video_total_frames", self.config.get("default_frames", 81)),
143
+ "guidance_scale": kwargs.get("guidance_scale", self.config.get("guidance_scale", 1.0)),
144
+ "guidance_scale_2": kwargs.get("guidance_scale_2", self.config.get("guidance_scale_2", 1.0)),
145
+ "num_inference_steps": kwargs.get("num_inference_steps", self.config.get("inference_steps", 8)),
146
+ "generator": torch.Generator(device=self.device).manual_seed(int(torch.randint(0, 100000, (1,)).item())),
147
+ }
148
 
149
+ logger.info(f"WAN MANAGER (Lightning): Gerando fragmento com {pipeline_params['num_inference_steps']} passos.")
 
 
 
 
 
 
 
 
 
 
150
 
151
+ first_latent_for_encoding = conditioning_items[0].latent_tensor
152
+ pil_image_for_encoder = self._decode_latent_to_pil(first_latent_for_encoding)
 
 
 
 
153
 
154
+ output = self.pipe(
155
+ image=pil_image_for_encoder,
156
+ conditioning_items=conditioning_items,
157
+ output_type='latent',
158
+ **pipeline_params
159
+ )
160
 
161
+ video_latents = output.frames
162
+ logger.info(f"WAN MANAGER (Lightning): Fragmento latente gerado. Shape: {video_latents.shape}")
 
163
 
164
+ return video_latents.cpu(), None
165
+
166
+ def _decode_latent_to_pil(self, latent_tensor: torch.Tensor) -> Image.Image:
167
+ """Função auxiliar para decodificar um latente em uma imagem PIL usando o VAE dedicado."""
168
+ pixel_tensor = vae_wan_manager_singleton.decode(latent_tensor.unsqueeze(0))
169
+ pixel_tensor = (pixel_tensor / 2 + 0.5).clamp(0, 1)
170
+ numpy_image = (pixel_tensor.cpu().permute(0, 2, 3, 4, 1).squeeze(0).squeeze(0) * 255).byte().numpy()
171
+ return Image.fromarray(numpy_image)
172
 
173
  # --- Instância Singleton ---
174
+ wan_manager_singleton = WanManager()