x2XcarleX2x commited on
Commit
0374496
·
verified ·
1 Parent(s): 87ffc1a

Update aduc_framework/utils/callbacks.py

Browse files
Files changed (1) hide show
  1. aduc_framework/utils/callbacks.py +35 -17
aduc_framework/utils/callbacks.py CHANGED
@@ -1,38 +1,56 @@
1
  # aduc_framework/utils/callbacks.py
2
 
 
3
  import torch
4
  import numpy as np
 
 
5
 
6
  class DenoiseStepLogger:
7
  """
8
- Callback simplificado que apenas decodifica latentes para um tensor de vídeo em pixels (como um array NumPy).
9
- A lógica de salvar em arquivo foi movida para o manager para melhor controle do fluxo com 'yield'.
10
  """
11
- def __init__(self, pipe):
12
  self.pipe = pipe
13
- # Mantém os tensores na CPU por padrão. Eles serão movidos para o dispositivo
14
- # correto no momento do uso para garantir compatibilidade com `device_map="auto"`.
 
15
  self.latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1)
16
  self.latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1)
17
 
18
- def decode_latents_to_video_tensor(self, latents: torch.Tensor) -> np.ndarray:
19
  """
20
- Decodifica um tensor latente 5D para um array numpy de frames de vídeo.
21
-
22
- Args:
23
- latents (torch.Tensor): O tensor 5D vindo do processo de denoising.
24
-
25
- Returns:
26
- np.ndarray: Um array NumPy representando o lote de frames de vídeo (B, F, H, W, C).
27
  """
28
- # Garante que as operações aconteçam no mesmo dispositivo que os latentes
29
  correct_device = latents.device
30
  latents_unscaled = latents / self.latents_std.to(correct_device) + self.latents_mean.to(correct_device)
31
 
32
  latents_unscaled = latents_unscaled.to(self.pipe.vae.dtype)
33
  video_tensor = self.pipe.vae.decode(latents_unscaled, return_dict=False)[0]
34
 
35
- # O post-processador da pipeline converte o tensor de vídeo para um array NumPy
36
- video_np = self.pipe.video_processor.postprocess_video(video=video_tensor, output_type="np")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- return video_np
 
1
  # aduc_framework/utils/callbacks.py
2
 
3
+ import imageio
4
  import torch
5
  import numpy as np
6
+ import tempfile
7
+ from diffusers.utils.export_utils import export_to_video
8
 
9
  class DenoiseStepLogger:
10
  """
11
+ Callback que, em cada passo do denoising, decodifica a sequência de vídeo
12
+ inteira e a salva como um clipe MP4 individual.
13
  """
14
+ def __init__(self, pipe, fps=8):
15
  self.pipe = pipe
16
+ self.fps = fps
17
+ # Armazena os caminhos para os vídeos gerados em cada passo
18
+ self.step_video_paths = []
19
  self.latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1)
20
  self.latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1)
21
 
22
+ def decode_latents_to_video_path(self, latents: torch.Tensor, step: int) -> str:
23
  """
24
+ Decodifica um tensor latente 5D, salva o vídeo resultante em um arquivo
25
+ temporário e retorna o caminho para esse arquivo.
 
 
 
 
 
26
  """
 
27
  correct_device = latents.device
28
  latents_unscaled = latents / self.latents_std.to(correct_device) + self.latents_mean.to(correct_device)
29
 
30
  latents_unscaled = latents_unscaled.to(self.pipe.vae.dtype)
31
  video_tensor = self.pipe.vae.decode(latents_unscaled, return_dict=False)[0]
32
 
33
+ # O resultado é um lote de frames de vídeo
34
+ frames = self.pipe.video_processor.postprocess_video(video=video_tensor, output_type="np")
35
+
36
+ with tempfile.NamedTemporaryFile(suffix=f"_step_{step+1}.mp4", delete=False) as tmp:
37
+ video_path = tmp.name
38
+
39
+ export_to_video(frames[0], video_path, fps=self.fps)
40
+ return video_path
41
+
42
+ def __call__(self, pipe, step: int, timestep: int, callback_kwargs: dict):
43
+ """
44
+ Chamado pela pipeline a cada passo.
45
+ """
46
+ print(f" -> Callback: Decodificando vídeo completo do passo de denoising {step+1}...")
47
+ latents = callback_kwargs["latents"]
48
+
49
+ # Gera o vídeo para o passo atual e armazena seu caminho
50
+ video_path = self.decode_latents_to_video_path(latents, step)
51
+ self.step_video_paths.append(video_path)
52
+
53
+ yield None, None, self.denoising_step_videos
54
+
55
 
56
+ return callback_kwargs