x2XcarleX2x commited on
Commit
91ef7b6
·
verified ·
1 Parent(s): 80198c5

Update aduc_framework/utils/callbacks.py

Browse files
Files changed (1) hide show
  1. aduc_framework/utils/callbacks.py +30 -81
aduc_framework/utils/callbacks.py CHANGED
@@ -2,103 +2,52 @@
2
 
3
  import imageio
4
  import torch
5
- from PIL import Image, ImageDraw, ImageFont
6
  import numpy as np
7
- import math
8
- import os
9
 
10
  class DenoiseStepLogger:
11
  """
12
- Uma classe de callback que "espiona" o processo de denoising.
13
- Ela captura frames intermediários, pode salvá-los como um vídeo de processo
14
- e também pode criar uma grade de comparação visual de todas as etapas.
15
  """
16
- def __init__(self, pipe):
17
  self.pipe = pipe
18
- self.intermediate_frames = []
19
- # Mantém os tensores na CPU por padrão durante a inicialização.
20
- # Eles serão movidos para o dispositivo correto no momento do uso.
21
  self.latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1)
22
  self.latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1)
23
 
24
- ### INÍCIO DA SEÇÃO CORRIGIDA ###
25
- def decode_latents_to_pil(self, latents: torch.Tensor) -> Image.Image:
26
- """Decodifica um tensor de latents para uma única imagem PIL."""
27
-
28
- # Pega o dispositivo correto do tensor de entrada `latents`
29
  correct_device = latents.device
30
-
31
- # Move os tensores de média e desvio padrão para o mesmo dispositivo dos latents
32
- # antes de realizar a operação. Isso evita o erro de "device meta".
33
  latents_unscaled = latents / self.latents_std.to(correct_device) + self.latents_mean.to(correct_device)
34
 
35
  latents_unscaled = latents_unscaled.to(self.pipe.vae.dtype)
36
- decoded_video_tensor = self.pipe.vae.decode(latents_unscaled, return_dict=False)[0]
 
 
 
 
 
 
37
 
38
- frame_tensor = decoded_video_tensor[0, :, 0, :, :]
39
- frame_tensor = (frame_tensor / 2 + 0.5).clamp(0, 1)
40
- frame_np = frame_tensor.cpu().permute(1, 2, 0).float().numpy()
41
- pil_image = Image.fromarray((frame_np * 255).astype(np.uint8))
42
- return pil_image
43
- ### FIM DA SEÇÃO CORRIGIDA ###
44
 
45
  def __call__(self, pipe, step: int, timestep: int, callback_kwargs: dict):
46
  """
47
- Esta função é chamada pela pipeline da diffusers em cada passo de denoising.
48
  """
49
- print(f" -> Callback: Capturando frame do passo de denoising {step+1}...")
50
  latents = callback_kwargs["latents"]
51
- pil_frame = self.decode_latents_to_pil(latents)
52
- self.intermediate_frames.append(pil_frame)
53
- return callback_kwargs
54
-
55
- def save_as_video(self, output_path: str, fps: int = 5):
56
- """Salva os frames intermediários capturados como um vídeo MP4."""
57
- if not self.intermediate_frames:
58
- print(" -> Callback: Nenhum frame intermediário para salvar como vídeo.")
59
- return
60
- print(f" -> Callback: Codificando {len(self.intermediate_frames)} frames em vídeo em '{output_path}'...")
61
- writer = imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8, pixelformat='yuv420p')
62
- for frame in self.intermediate_frames:
63
- writer.append_data(np.array(frame))
64
- writer.close()
65
- print(" -> Callback: Vídeo de depuração salvo com sucesso.")
66
-
67
- def create_steps_grid(self) -> Image.Image:
68
- """
69
- Organiza todos os frames intermediários capturados em uma única imagem de grade para comparação.
70
- """
71
- if not self.intermediate_frames:
72
- print(" -> Callback: Nenhum frame intermediário para criar a grade.")
73
- return None
74
- print(f" -> Callback: Criando grade de comparação com {len(self.intermediate_frames)} etapas...")
75
- num_images = len(self.intermediate_frames)
76
- cols = math.ceil(math.sqrt(num_images))
77
- rows = math.ceil(num_images / cols)
78
- frame_w, frame_h = self.intermediate_frames[0].size
79
- grid_w, grid_h = frame_w * cols, frame_h * rows
80
- grid_image = Image.new('RGB', (grid_w, grid_h), (20, 20, 20))
81
- draw = ImageDraw.Draw(grid_image)
82
- try:
83
- font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
84
- if not os.path.exists(font_path):
85
- font_path = "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf"
86
- font = ImageFont.truetype(font_path, size=32)
87
- except IOError:
88
- print(" -> Callback WARNING: Fonte não encontrada. Usando fonte padrão.")
89
- font = ImageFont.load_default()
90
- for i, frame in enumerate(self.intermediate_frames):
91
- x, y = (i % cols) * frame_w, (i // cols) * frame_h
92
- grid_image.paste(frame, (x, y))
93
- text = f"Passo {i+1}"
94
- text_origin = (x + 10, y + 10)
95
- try:
96
- text_bbox = draw.textbbox(text_origin, text, font=font)
97
- except AttributeError:
98
- text_w, text_h = draw.textsize(text, font=font)
99
- text_bbox = (text_origin[0], text_origin[1], text_origin[0] + text_w, text_origin[1] + text_h)
100
- rect_coords = (text_bbox[0] - 5, text_bbox[1] - 5, text_bbox[2] + 5, text_bbox[3] + 5)
101
- draw.rectangle(rect_coords, fill=(0, 0, 0, 180))
102
- draw.text(text_origin, text, font=font, fill=(255, 255, 255))
103
- print(" -> Callback: Grade de comparação criada com sucesso.")
104
- return grid_image
 
2
 
3
  import imageio
4
  import torch
 
5
  import numpy as np
6
+ import tempfile
7
+ from diffusers.utils.export_utils import export_to_video
8
 
9
  class DenoiseStepLogger:
10
  """
11
+ Callback que, em cada passo do denoising, decodifica a sequência de vídeo
12
+ inteira e a salva como um clipe MP4 individual.
 
13
  """
14
+ def __init__(self, pipe, fps=8):
15
  self.pipe = pipe
16
+ self.fps = fps
17
+ # Armazena os caminhos para os vídeos gerados em cada passo
18
+ self.step_video_paths = []
19
  self.latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1)
20
  self.latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1)
21
 
22
+ def decode_latents_to_video_path(self, latents: torch.Tensor, step: int) -> str:
23
+ """
24
+ Decodifica um tensor latente 5D, salva o vídeo resultante em um arquivo
25
+ temporário e retorna o caminho para esse arquivo.
26
+ """
27
  correct_device = latents.device
 
 
 
28
  latents_unscaled = latents / self.latents_std.to(correct_device) + self.latents_mean.to(correct_device)
29
 
30
  latents_unscaled = latents_unscaled.to(self.pipe.vae.dtype)
31
+ video_tensor = self.pipe.vae.decode(latents_unscaled, return_dict=False)[0]
32
+
33
+ # O resultado já é um lote de frames de vídeo
34
+ frames = self.pipe.video_processor.postprocess_video(video=video_tensor, output_type="np")
35
+
36
+ with tempfile.NamedTemporaryFile(suffix=f"_step_{step+1}.mp4", delete=False) as tmp:
37
+ video_path = tmp.name
38
 
39
+ export_to_video(frames[0], video_path, fps=self.fps)
40
+ return video_path
 
 
 
 
41
 
42
  def __call__(self, pipe, step: int, timestep: int, callback_kwargs: dict):
43
  """
44
+ Chamado pela pipeline a cada passo.
45
  """
46
+ print(f" -> Callback: Decodificando vídeo completo do passo de denoising {step+1}...")
47
  latents = callback_kwargs["latents"]
48
+
49
+ # Gera o vídeo para o passo atual e armazena seu caminho
50
+ video_path = self.decode_latents_to_video_path(latents, step)
51
+ self.step_video_paths.append(video_path)
52
+
53
+ return callback_kwargs