Spaces:
Paused
Paused
Update aduc_framework/managers/wan_manager.py
Browse files
aduc_framework/managers/wan_manager.py
CHANGED
|
@@ -1,138 +1,174 @@
|
|
| 1 |
-
# aduc_framework/managers/
|
| 2 |
import torch
|
| 3 |
import logging
|
| 4 |
import yaml
|
| 5 |
from PIL import Image
|
| 6 |
-
import numpy as np
|
| 7 |
from typing import List, Optional
|
| 8 |
import sys
|
| 9 |
import os
|
| 10 |
|
| 11 |
# --- INÍCIO DA CORREÇÃO DE IMPORTAÇÃO ---
|
| 12 |
-
# Adiciona o diretório do Wan2.2 ao sys.path para
|
| 13 |
WAN_REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'Wan2.2'))
|
| 14 |
if WAN_REPO_PATH not in sys.path:
|
| 15 |
sys.path.insert(0, WAN_REPO_PATH)
|
| 16 |
-
logging.info(f"Adicionado '{WAN_REPO_PATH}' ao sys.path para
|
| 17 |
# --- FIM DA CORREÇÃO DE IMPORTAÇÃO ---
|
| 18 |
|
| 19 |
-
# Ferramentas da nossa arquitetura
|
| 20 |
from ..tools.hardware_manager import hardware_manager
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
from diffusers.
|
| 25 |
-
from diffusers.
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
logger = logging.getLogger(__name__)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
class
|
| 31 |
"""
|
| 32 |
-
Especialista
|
| 33 |
-
|
| 34 |
-
|
| 35 |
"""
|
| 36 |
def __init__(self):
|
| 37 |
self.device = None
|
| 38 |
-
self.
|
| 39 |
-
self.dtype = None
|
| 40 |
self.config = self._load_config()
|
| 41 |
-
|
| 42 |
if self.config:
|
| 43 |
gpus_required = self.config.get('gpus_required', 0)
|
| 44 |
if gpus_required > 0:
|
| 45 |
-
self.device = hardware_manager.allocate_gpus('
|
| 46 |
-
logger.info(f"
|
| 47 |
else:
|
| 48 |
-
|
| 49 |
-
logger.warning("VaeWanManager: Nenhuma GPU dedicada foi alocada.")
|
| 50 |
else:
|
| 51 |
-
logger.warning("Configuração para '
|
| 52 |
|
| 53 |
def _load_config(self):
|
| 54 |
-
"""Carrega a configuração específica deste manager."""
|
| 55 |
try:
|
| 56 |
-
with open("config.yaml", 'r'
|
| 57 |
-
return yaml.safe_load(f).get('specialists', {}).get('
|
| 58 |
except FileNotFoundError:
|
| 59 |
logger.error("config.yaml não encontrado.")
|
| 60 |
return None
|
| 61 |
|
| 62 |
def _lazy_init(self):
|
| 63 |
-
"""Carrega
|
| 64 |
-
if self.
|
| 65 |
return
|
| 66 |
if not self.device or not self.config:
|
| 67 |
-
raise RuntimeError("
|
| 68 |
|
| 69 |
-
logger.info(f"
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
try:
|
| 76 |
-
#
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
except Exception as e:
|
| 83 |
-
logger.error(f"
|
| 84 |
-
self.
|
| 85 |
raise e
|
| 86 |
|
| 87 |
-
def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor:
|
| 88 |
-
"""Converte uma imagem PIL para o formato de tensor 5D esperado pelo VAE de vídeo."""
|
| 89 |
-
from PIL import ImageOps
|
| 90 |
-
img = pil_image.convert("RGB")
|
| 91 |
-
processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
|
| 92 |
-
image_np = np.array(processed_img).astype(np.float32) / 255.0
|
| 93 |
-
|
| 94 |
-
# Converte para (B, C, H, W)
|
| 95 |
-
tensor_4d = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)
|
| 96 |
-
tensor_4d_normalized = (tensor_4d * 2.0) - 1.0
|
| 97 |
-
|
| 98 |
-
# Adiciona a dimensão de "frame" para criar um tensor 5D (B, C, F, H, W)
|
| 99 |
-
tensor_5d = tensor_4d_normalized.unsqueeze(2)
|
| 100 |
-
|
| 101 |
-
return tensor_5d
|
| 102 |
-
|
| 103 |
@torch.no_grad()
|
| 104 |
-
def
|
| 105 |
-
"""
|
| 106 |
self._lazy_init()
|
| 107 |
-
if not self.
|
| 108 |
-
raise RuntimeError("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
|
| 111 |
-
for img in pil_images:
|
| 112 |
-
# A função de pré-processamento agora retorna o tensor 5D correto
|
| 113 |
-
pixel_tensor_gpu = self._preprocess_pil_image(img, target_resolution).to(self.device, dtype=self.dtype)
|
| 114 |
-
|
| 115 |
-
encoder_output = self.vae.encode(pixel_tensor_gpu)
|
| 116 |
-
latents = retrieve_latents(encoder_output)
|
| 117 |
-
|
| 118 |
-
latents_list.append(latents.cpu())
|
| 119 |
-
|
| 120 |
-
return latents_list
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
"""Decodifica um tensor latente do Wan para o espaço de pixels."""
|
| 125 |
-
self._lazy_init()
|
| 126 |
-
if not self.vae:
|
| 127 |
-
raise RuntimeError("O VAE do WanManager não foi carregado.")
|
| 128 |
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
pixels = decode_output.sample
|
| 134 |
|
| 135 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# --- Instância Singleton ---
|
| 138 |
-
|
|
|
|
| 1 |
+
# aduc_framework/managers/wan_manager.py (Versão Definitiva Completa)
|
| 2 |
import torch
|
| 3 |
import logging
|
| 4 |
import yaml
|
| 5 |
from PIL import Image
|
|
|
|
| 6 |
from typing import List, Optional
|
| 7 |
import sys
|
| 8 |
import os
|
| 9 |
|
| 10 |
# --- INÍCIO DA CORREÇÃO DE IMPORTAÇÃO ---
|
| 11 |
+
# Adiciona o diretório do Wan2.2 ao sys.path para que o Python o encontre.
|
| 12 |
WAN_REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'Wan2.2'))
|
| 13 |
if WAN_REPO_PATH not in sys.path:
|
| 14 |
sys.path.insert(0, WAN_REPO_PATH)
|
| 15 |
+
logging.info(f"Adicionado '{WAN_REPO_PATH}' ao sys.path para importações do WanManager.")
|
| 16 |
# --- FIM DA CORREÇÃO DE IMPORTAÇÃO ---
|
| 17 |
|
| 18 |
+
# Ferramentas da nossa arquitetura ADUC
|
| 19 |
from ..tools.hardware_manager import hardware_manager
|
| 20 |
+
from ..tools.pipeline_patches import apply_aduc_patches
|
| 21 |
+
from ..types import LatentConditioningItem
|
| 22 |
|
| 23 |
+
# Especialistas e modelos necessários
|
| 24 |
+
from .vae_wan_manager import vae_wan_manager_singleton
|
| 25 |
+
#from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
|
| 26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 27 |
+
#from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
|
| 28 |
+
from transformers import CLIPVisionModel
|
| 29 |
|
|
|
|
| 30 |
|
| 31 |
+
from diffusers import WanImageToVideoPipeline
|
| 32 |
+
#from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 33 |
+
from diffusers.models import WanTransformer3DModel
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
|
| 37 |
+
class WanManager:
|
| 38 |
"""
|
| 39 |
+
Especialista ADUC completo e otimizado para geração de vídeo com Wan2.2.
|
| 40 |
+
Incorpora a fusão do LoRA Lightning para geração de alta velocidade (8-steps)
|
| 41 |
+
e patches customizados para controle temporal preciso.
|
| 42 |
"""
|
| 43 |
def __init__(self):
|
| 44 |
self.device = None
|
| 45 |
+
self.pipe: Optional[WanImageToVideoPipeline] = None
|
|
|
|
| 46 |
self.config = self._load_config()
|
| 47 |
+
|
| 48 |
if self.config:
|
| 49 |
gpus_required = self.config.get('gpus_required', 0)
|
| 50 |
if gpus_required > 0:
|
| 51 |
+
self.device = hardware_manager.allocate_gpus('WanManager', gpus_required)[0]
|
| 52 |
+
logger.info(f"WanManager (Lightning): GPU {self.device} reservada.")
|
| 53 |
else:
|
| 54 |
+
logger.warning("WanManager está desabilitado (gpus_required: 0).")
|
|
|
|
| 55 |
else:
|
| 56 |
+
logger.warning("Configuração para 'wan' não encontrada.")
|
| 57 |
|
| 58 |
def _load_config(self):
|
| 59 |
+
"""Carrega a configuração específica deste manager do arquivo YAML global."""
|
| 60 |
try:
|
| 61 |
+
with open("config.yaml", 'r') as f:
|
| 62 |
+
return yaml.safe_load(f).get('specialists', {}).get('wan', {})
|
| 63 |
except FileNotFoundError:
|
| 64 |
logger.error("config.yaml não encontrado.")
|
| 65 |
return None
|
| 66 |
|
| 67 |
def _lazy_init(self):
|
| 68 |
+
"""Carrega a pipeline, aplica otimizações, funde o LoRA e aplica nosso patch."""
|
| 69 |
+
if self.pipe is not None:
|
| 70 |
return
|
| 71 |
if not self.device or not self.config:
|
| 72 |
+
raise RuntimeError("WanManager não pode ser inicializado.")
|
| 73 |
|
| 74 |
+
logger.info(f"WAN MANAGER ({self.device}): Iniciando carregamento OTIMIZADO do Wan2.2...")
|
| 75 |
|
| 76 |
+
main_model_id = self.config.get("model_id")
|
| 77 |
+
opt_model_id = self.config.get("optimized_model_id")
|
| 78 |
+
lora_repo = self.config.get("lora_repo")
|
| 79 |
+
lora_filename = self.config.get("lora_filename")
|
| 80 |
+
torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
| 81 |
+
|
| 82 |
+
# Define os caminhos locais onde os modelos foram baixados
|
| 83 |
+
main_model_path = f"/app/models/{main_model_id}"
|
| 84 |
+
opt_model_path = f"/app/models/{opt_model_id}"
|
| 85 |
+
lora_path = f"/app/models/loras/{os.path.basename(lora_filename)}"
|
| 86 |
|
| 87 |
try:
|
| 88 |
+
# 1. Requisição do VAE dedicado
|
| 89 |
+
vae_wan_manager_singleton._lazy_init()
|
| 90 |
+
vae = vae_wan_manager_singleton.vae
|
| 91 |
+
if vae is None: raise RuntimeError("Falha ao obter o VAE do vae_wan_manager_singleton.")
|
| 92 |
+
|
| 93 |
+
# 2. Carregamento dos componentes dos caminhos locais
|
| 94 |
+
image_encoder = CLIPVisionModel.from_pretrained(main_model_path, subfolder="image_encoder")
|
| 95 |
+
transformer = WanTransformer3DModel.from_pretrained(opt_model_path, subfolder='transformer', torch_dtype=torch_dtype)
|
| 96 |
+
transformer_2 = WanTransformer3DModel.from_pretrained(opt_model_path, subfolder='transformer_2', torch_dtype=torch_dtype)
|
| 97 |
+
|
| 98 |
+
# 3. Montagem da pipeline base
|
| 99 |
+
self.pipe = WanImageToVideoPipeline.from_pretrained(main_model_path, vae=vae, image_encoder=image_encoder, transformer=transformer, transformer_2=transformer_2, torch_dtype=torch_dtype)
|
| 100 |
+
|
| 101 |
+
# 4. Ajuste do Scheduler
|
| 102 |
+
self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.pipe.scheduler.config, shift=32.0)
|
| 103 |
+
|
| 104 |
+
# 5. Fusão do LoRA Lightning
|
| 105 |
+
logger.info(f"WAN MANAGER ({self.device}): Carregando e fundindo LoRA Lightning de '{lora_path}'...")
|
| 106 |
+
self.pipe.load_lora_weights(os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name="lightx2v")
|
| 107 |
+
self.pipe.load_lora_weights(os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name="lightx2v_2", load_into_transformer_2=True)
|
| 108 |
+
self.pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
|
| 109 |
+
self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
|
| 110 |
+
self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
|
| 111 |
+
self.pipe.unload_lora_weights()
|
| 112 |
+
logger.info(f"WAN MANAGER ({self.device}): LoRA Lightning fundido.")
|
| 113 |
+
|
| 114 |
+
# 6. Aplicação do nosso patch ADUC
|
| 115 |
+
apply_aduc_patches()
|
| 116 |
+
|
| 117 |
+
# 7. Finalização e envio para a GPU
|
| 118 |
+
self.pipe.to(self.device)
|
| 119 |
+
logger.info(f"WAN MANAGER ({self.device}): Pipeline Wan2.2 OTIMIZADA, MODIFICADA e pronta na VRAM.")
|
| 120 |
+
|
| 121 |
except Exception as e:
|
| 122 |
+
logger.error(f"WAN MANAGER: Falha CRÍTICA ao carregar a pipeline: {e}", exc_info=True)
|
| 123 |
+
self.pipe = None
|
| 124 |
raise e
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
@torch.no_grad()
|
| 127 |
+
def generate_latent_fragment(self, **kwargs) -> tuple[torch.Tensor, None]:
|
| 128 |
+
"""Gera um fragmento de vídeo no espaço latente. A interface é IDÊNTICA à do LtxManager."""
|
| 129 |
self._lazy_init()
|
| 130 |
+
if not self.pipe:
|
| 131 |
+
raise RuntimeError("A pipeline do WanManager não está disponível.")
|
| 132 |
+
|
| 133 |
+
conditioning_items: List[LatentConditioningItem] = kwargs.get("conditioning_items_data", [])
|
| 134 |
+
if not conditioning_items:
|
| 135 |
+
raise ValueError("WanManager no modo ADUC requer 'conditioning_items_data'.")
|
| 136 |
+
|
| 137 |
+
pipeline_params = {
|
| 138 |
+
"prompt": kwargs.get("motion_prompt", ""),
|
| 139 |
+
"negative_prompt": kwargs.get("negative_prompt", "static, disfigured, low quality"),
|
| 140 |
+
"height": kwargs.get("height", self.config.get("default_height", 480)),
|
| 141 |
+
"width": kwargs.get("width", self.config.get("default_width", 832)),
|
| 142 |
+
"num_frames": kwargs.get("video_total_frames", self.config.get("default_frames", 81)),
|
| 143 |
+
"guidance_scale": kwargs.get("guidance_scale", self.config.get("guidance_scale", 1.0)),
|
| 144 |
+
"guidance_scale_2": kwargs.get("guidance_scale_2", self.config.get("guidance_scale_2", 1.0)),
|
| 145 |
+
"num_inference_steps": kwargs.get("num_inference_steps", self.config.get("inference_steps", 8)),
|
| 146 |
+
"generator": torch.Generator(device=self.device).manual_seed(int(torch.randint(0, 100000, (1,)).item())),
|
| 147 |
+
}
|
| 148 |
|
| 149 |
+
logger.info(f"WAN MANAGER (Lightning): Gerando fragmento com {pipeline_params['num_inference_steps']} passos.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
+
first_latent_for_encoding = conditioning_items[0].latent_tensor
|
| 152 |
+
pil_image_for_encoder = self._decode_latent_to_pil(first_latent_for_encoding)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
+
output = self.pipe(
|
| 155 |
+
image=pil_image_for_encoder,
|
| 156 |
+
conditioning_items=conditioning_items,
|
| 157 |
+
output_type='latent',
|
| 158 |
+
**pipeline_params
|
| 159 |
+
)
|
| 160 |
|
| 161 |
+
video_latents = output.frames
|
| 162 |
+
logger.info(f"WAN MANAGER (Lightning): Fragmento latente gerado. Shape: {video_latents.shape}")
|
|
|
|
| 163 |
|
| 164 |
+
return video_latents.cpu(), None
|
| 165 |
+
|
| 166 |
+
def _decode_latent_to_pil(self, latent_tensor: torch.Tensor) -> Image.Image:
|
| 167 |
+
"""Função auxiliar para decodificar um latente em uma imagem PIL usando o VAE dedicado."""
|
| 168 |
+
pixel_tensor = vae_wan_manager_singleton.decode(latent_tensor.unsqueeze(0))
|
| 169 |
+
pixel_tensor = (pixel_tensor / 2 + 0.5).clamp(0, 1)
|
| 170 |
+
numpy_image = (pixel_tensor.cpu().permute(0, 2, 3, 4, 1).squeeze(0).squeeze(0) * 255).byte().numpy()
|
| 171 |
+
return Image.fromarray(numpy_image)
|
| 172 |
|
| 173 |
# --- Instância Singleton ---
|
| 174 |
+
wan_manager_singleton = WanManager()
|