x2XcarleX2x commited on
Commit
def1a3e
·
verified ·
1 Parent(s): f93bb97

Update aduc_framework/managers/wan_manager.py

Browse files
aduc_framework/managers/wan_manager.py CHANGED
@@ -3,7 +3,7 @@
3
  import os
4
  import tempfile
5
  import random
6
- from typing import List, Any
7
 
8
  import numpy as np
9
  import torch
@@ -17,13 +17,13 @@ from diffusers.utils.export_utils import export_to_video
17
 
18
  class WanManager:
19
  """
20
- Serviço responsável por:
21
- - Carregar a pipeline Wan I2V com dois transformadores (alto/baixo ruído).
22
- - Aplicar e fundir LoRA Lightning para geração rápida (8 passos).
23
- - Processar imagens e gerar vídeo a partir de uma lista images_condition_items.
24
  """
25
 
26
- # Constantes espelhadas da UI
27
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
28
 
29
  MAX_DIMENSION = 832
@@ -44,7 +44,7 @@ class WanManager:
44
  def __init__(self) -> None:
45
  print("Loading models into memory. This may take a few minutes...")
46
 
47
- # Carrega a pipeline principal com dois transformadores
48
  self.pipe = WanImageToVideoPipeline.from_pretrained(
49
  self.MODEL_ID,
50
  transformer=WanTransformer3DModel.from_pretrained(
@@ -62,12 +62,12 @@ class WanManager:
62
  torch_dtype=torch.bfloat16,
63
  )
64
 
65
- # Scheduler FlowMatch Euler com shift igual ao do app
66
  self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
67
  self.pipe.scheduler.config, shift=32.0
68
  )
69
 
70
- # Fusão do LoRA Lightning (dois adaptadores: transformer e transformer_2)
71
  print("Applying 8-step Lightning LoRA...")
72
  try:
73
  self.pipe.load_lora_weights(
@@ -88,7 +88,7 @@ class WanManager:
88
  self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
89
  self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
90
 
91
- # Após a fusão, descarta adaptadores da memória
92
  self.pipe.unload_lora_weights()
93
  print("Lightning LoRA successfully fused. Model is ready for fast 8-step generation.")
94
  except Exception as e:
@@ -96,15 +96,9 @@ class WanManager:
96
 
97
  print("All models loaded. Service is ready.")
98
 
99
- # ===== Utilidades de imagem (espelho da UI) =====
100
 
101
  def process_image_for_video(self, image: Image.Image) -> Image.Image:
102
- """
103
- Reamostra a imagem respeitando:
104
- - Mín/Máx dimensões
105
- - Múltiplo de 16
106
- - Caso quadrada, força SQUARE_SIZE
107
- """
108
  width, height = image.size
109
  if width == height:
110
  return image.resize((self.SQUARE_SIZE, self.SQUARE_SIZE), Image.Resampling.LANCZOS)
@@ -124,7 +118,7 @@ class WanManager:
124
  new_width *= scale
125
  new_height *= scale
126
 
127
- # Múltiplo e mínimos finais
128
  final_width = int(round(new_width / self.DIMENSION_MULTIPLE) * self.DIMENSION_MULTIPLE)
129
  final_height = int(round(new_height / self.DIMENSION_MULTIPLE) * self.DIMENSION_MULTIPLE)
130
 
@@ -134,9 +128,6 @@ class WanManager:
134
  return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
135
 
136
  def resize_and_crop_to_match(self, target_image: Image.Image, reference_image: Image.Image) -> Image.Image:
137
- """
138
- Redimensiona e faz center-crop para igualar (W,H) da imagem de referência.
139
- """
140
  ref_width, ref_height = reference_image.size
141
  target_width, target_height = target_image.size
142
  scale = max(ref_width / target_width, ref_height / target_height)
@@ -145,63 +136,55 @@ class WanManager:
145
  left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
146
  return resized.crop((left, top, left + ref_width, top + ref_height))
147
 
148
- # ===== API principal =====
149
 
150
  def generate_video_from_conditions(
151
  self,
152
  images_condition_items: List[List[Any]], # [[patch(Image), frame(int|str), peso(float)], ...]
153
  prompt: str,
154
- negative_prompt: str,
155
  duration_seconds: float,
156
  steps: int,
157
  guidance_scale: float,
158
  guidance_scale_2: float,
159
  seed: int,
160
  randomize_seed: bool,
 
161
  ):
162
  """
163
- Usos atuais:
164
- - Usa SOMENTE o primeiro item como imagem inicial (image)
165
- - Usa SOMENTE o último item como last_image (endpoint)
166
- - Mantém todo o restante do contrato da pipeline i2v
167
  """
168
-
169
  if not images_condition_items or len(images_condition_items) < 2:
170
- raise ValueError("Forneça ao menos dois itens em images_condition_items (início e fim).")
171
 
172
  first_item = images_condition_items[0]
173
  last_item = images_condition_items[-1]
174
 
175
- # Estrutura: [patch, frame, peso]; por ora só o patch é utilizado.
176
  start_image = first_item[0]
177
  end_image = last_item[0]
178
  if start_image is None or end_image is None:
179
  raise ValueError("As imagens inicial e final não podem ser vazias.")
 
 
180
 
181
- if not isinstance(start_image, Image.Image):
182
- raise TypeError("O 'patch' do primeiro item deve ser uma PIL.Image.")
183
- if not isinstance(end_image, Image.Image):
184
- raise TypeError("O 'patch' do último item deve ser uma PIL.Image.")
185
-
186
- # Pré-processamento idêntico ao da UI
187
  processed_start = self.process_image_for_video(start_image)
188
  processed_end = self.resize_and_crop_to_match(end_image, processed_start)
189
  target_height, target_width = processed_start.height, processed_start.width
190
 
191
- # Frames do vídeo
192
  num_frames = int(round(duration_seconds * self.FIXED_FPS))
193
  num_frames = int(np.clip(num_frames, self.MIN_FRAMES_MODEL, self.MAX_FRAMES_MODEL))
194
 
195
- # Semente
196
  current_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else int(seed)
197
  generator = torch.Generator().manual_seed(current_seed)
198
 
199
- # Chamada direta da pipeline (image/last_image)
200
  result = self.pipe(
201
  image=processed_start,
202
  last_image=processed_end,
203
  prompt=prompt,
204
- negative_prompt=negative_prompt,
205
  height=target_height,
206
  width=target_width,
207
  num_frames=num_frames,
@@ -209,11 +192,11 @@ class WanManager:
209
  guidance_scale_2=float(guidance_scale_2),
210
  num_inference_steps=int(steps),
211
  generator=generator,
 
212
  )
213
 
214
  frames = result.frames[0]
215
 
216
- # Exporta para vídeo temporário
217
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
218
  video_path = tmp.name
219
  export_to_video(frames, video_path, fps=self.FIXED_FPS)
 
3
  import os
4
  import tempfile
5
  import random
6
+ from typing import List, Any, Optional, Union
7
 
8
  import numpy as np
9
  import torch
 
17
 
18
  class WanManager:
19
  """
20
+ Serviço que encapsula:
21
+ - Carregamento da pipeline Wan I2V com dois transformadores (alto/baixo ruído).
22
+ - Fusão da LoRA Lightning para 8 passos rápidos.
23
+ - Pré-processamento de imagens e geração de vídeo a partir de images_condition_items.
24
  """
25
 
26
+ # Constantes alinhadas ao app
27
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
28
 
29
  MAX_DIMENSION = 832
 
44
  def __init__(self) -> None:
45
  print("Loading models into memory. This may take a few minutes...")
46
 
47
+ # Pipeline com dois transformadores (bf16 + device_map='auto')
48
  self.pipe = WanImageToVideoPipeline.from_pretrained(
49
  self.MODEL_ID,
50
  transformer=WanTransformer3DModel.from_pretrained(
 
62
  torch_dtype=torch.bfloat16,
63
  )
64
 
65
+ # Scheduler FlowMatch Euler (shift = 32.0)
66
  self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
67
  self.pipe.scheduler.config, shift=32.0
68
  )
69
 
70
+ # Fusão da LoRA Lightning (dois adaptadores, um por transformer)
71
  print("Applying 8-step Lightning LoRA...")
72
  try:
73
  self.pipe.load_lora_weights(
 
88
  self.pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
89
  self.pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
90
 
91
+ # Libera adaptadores após a fusão
92
  self.pipe.unload_lora_weights()
93
  print("Lightning LoRA successfully fused. Model is ready for fast 8-step generation.")
94
  except Exception as e:
 
96
 
97
  print("All models loaded. Service is ready.")
98
 
99
+ # ============ Utilidades de imagem ============
100
 
101
  def process_image_for_video(self, image: Image.Image) -> Image.Image:
 
 
 
 
 
 
102
  width, height = image.size
103
  if width == height:
104
  return image.resize((self.SQUARE_SIZE, self.SQUARE_SIZE), Image.Resampling.LANCZOS)
 
118
  new_width *= scale
119
  new_height *= scale
120
 
121
+ # Múltiplo de 16 e mínimos finais
122
  final_width = int(round(new_width / self.DIMENSION_MULTIPLE) * self.DIMENSION_MULTIPLE)
123
  final_height = int(round(new_height / self.DIMENSION_MULTIPLE) * self.DIMENSION_MULTIPLE)
124
 
 
128
  return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
129
 
130
  def resize_and_crop_to_match(self, target_image: Image.Image, reference_image: Image.Image) -> Image.Image:
 
 
 
131
  ref_width, ref_height = reference_image.size
132
  target_width, target_height = target_image.size
133
  scale = max(ref_width / target_width, ref_height / target_height)
 
136
  left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
137
  return resized.crop((left, top, left + ref_width, top + ref_height))
138
 
139
+ # ============ API principal ============
140
 
141
  def generate_video_from_conditions(
142
  self,
143
  images_condition_items: List[List[Any]], # [[patch(Image), frame(int|str), peso(float)], ...]
144
  prompt: str,
145
+ negative_prompt: Optional[str],
146
  duration_seconds: float,
147
  steps: int,
148
  guidance_scale: float,
149
  guidance_scale_2: float,
150
  seed: int,
151
  randomize_seed: bool,
152
+ output_type: str = "np",
153
  ):
154
  """
155
+ Usa apenas:
156
+ - Primeiro item como imagem inicial (image)
157
+ - Último item como last_image (endpoint)
158
+ Mantém todo o restante do contrato i2v.
159
  """
 
160
  if not images_condition_items or len(images_condition_items) < 2:
161
+ raise ValueError("Forneça ao menos dois itens (início e fim).")
162
 
163
  first_item = images_condition_items[0]
164
  last_item = images_condition_items[-1]
165
 
 
166
  start_image = first_item[0]
167
  end_image = last_item[0]
168
  if start_image is None or end_image is None:
169
  raise ValueError("As imagens inicial e final não podem ser vazias.")
170
+ if not isinstance(start_image, Image.Image) or not isinstance(end_image, Image.Image):
171
+ raise TypeError("Os 'patches' devem ser PIL.Image.")
172
 
 
 
 
 
 
 
173
  processed_start = self.process_image_for_video(start_image)
174
  processed_end = self.resize_and_crop_to_match(end_image, processed_start)
175
  target_height, target_width = processed_start.height, processed_start.width
176
 
 
177
  num_frames = int(round(duration_seconds * self.FIXED_FPS))
178
  num_frames = int(np.clip(num_frames, self.MIN_FRAMES_MODEL, self.MAX_FRAMES_MODEL))
179
 
 
180
  current_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else int(seed)
181
  generator = torch.Generator().manual_seed(current_seed)
182
 
 
183
  result = self.pipe(
184
  image=processed_start,
185
  last_image=processed_end,
186
  prompt=prompt,
187
+ negative_prompt=negative_prompt if negative_prompt is not None else self.default_negative_prompt,
188
  height=target_height,
189
  width=target_width,
190
  num_frames=num_frames,
 
192
  guidance_scale_2=float(guidance_scale_2),
193
  num_inference_steps=int(steps),
194
  generator=generator,
195
+ output_type=output_type,
196
  )
197
 
198
  frames = result.frames[0]
199
 
 
200
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
201
  video_path = tmp.name
202
  export_to_video(frames, video_path, fps=self.FIXED_FPS)