x2XcarleX2x commited on
Commit
348f29f
·
verified ·
1 Parent(s): 2d3e403

Update aduc_framework/utils/callbacks.py

Browse files
Files changed (1) hide show
  1. aduc_framework/utils/callbacks.py +20 -41
aduc_framework/utils/callbacks.py CHANGED
@@ -16,40 +16,40 @@ class DenoiseStepLogger:
16
  def __init__(self, pipe):
17
  self.pipe = pipe
18
  self.intermediate_frames = []
19
- self.latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(pipe.device)
20
- self.latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(pipe.device)
 
 
21
 
 
22
  def decode_latents_to_pil(self, latents: torch.Tensor) -> Image.Image:
23
  """Decodifica um tensor de latents para uma única imagem PIL."""
24
- latents = latents.to(self.pipe.vae.dtype)
25
- latents = latents / self.latents_std + self.latents_mean
26
- decoded_video_tensor = self.pipe.vae.decode(latents, return_dict=False)[0]
27
 
28
- # Pega o primeiro frame do lote de vídeo decodificado
29
- frame_tensor = decoded_video_tensor[0, :, 0, :, :]
30
 
31
- # Normaliza o tensor de [ -1, 1] para [0, 1]
32
- frame_tensor = (frame_tensor / 2 + 0.5).clamp(0, 1)
 
 
 
 
33
 
34
- # Converte para array NumPy e depois para imagem PIL
 
35
  frame_np = frame_tensor.cpu().permute(1, 2, 0).float().numpy()
36
  pil_image = Image.fromarray((frame_np * 255).astype(np.uint8))
37
  return pil_image
 
38
 
39
  def __call__(self, pipe, step: int, timestep: int, callback_kwargs: dict):
40
  """
41
  Esta função é chamada pela pipeline da diffusers em cada passo de denoising.
42
- A assinatura está corrigida para aceitar os 5 argumentos padrão.
43
  """
44
  print(f" -> Callback: Capturando frame do passo de denoising {step+1}...")
45
-
46
- # Extrai o tensor de latents do dicionário `callback_kwargs`
47
  latents = callback_kwargs["latents"]
48
-
49
  pil_frame = self.decode_latents_to_pil(latents)
50
  self.intermediate_frames.append(pil_frame)
51
-
52
- # É uma boa prática retornar o dicionário para a pipeline
53
  return callback_kwargs
54
 
55
  def save_as_video(self, output_path: str, fps: int = 5):
@@ -57,14 +57,10 @@ class DenoiseStepLogger:
57
  if not self.intermediate_frames:
58
  print(" -> Callback: Nenhum frame intermediário para salvar como vídeo.")
59
  return
60
-
61
  print(f" -> Callback: Codificando {len(self.intermediate_frames)} frames em vídeo em '{output_path}'...")
62
- # Usa um codec de alta compatibilidade e boa qualidade
63
  writer = imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8, pixelformat='yuv420p')
64
-
65
  for frame in self.intermediate_frames:
66
  writer.append_data(np.array(frame))
67
-
68
  writer.close()
69
  print(" -> Callback: Vídeo de depuração salvo com sucesso.")
70
 
@@ -75,51 +71,34 @@ class DenoiseStepLogger:
75
  if not self.intermediate_frames:
76
  print(" -> Callback: Nenhum frame intermediário para criar a grade.")
77
  return None
78
-
79
  print(f" -> Callback: Criando grade de comparação com {len(self.intermediate_frames)} etapas...")
80
-
81
- # Calcula um layout de grade agradável (o mais quadrado possível)
82
  num_images = len(self.intermediate_frames)
83
  cols = math.ceil(math.sqrt(num_images))
84
  rows = math.ceil(num_images / cols)
85
-
86
  frame_w, frame_h = self.intermediate_frames[0].size
87
- grid_w = frame_w * cols
88
- grid_h = frame_h * rows
89
-
90
  grid_image = Image.new('RGB', (grid_w, grid_h), (20, 20, 20))
91
  draw = ImageDraw.Draw(grid_image)
92
-
93
- # Tenta carregar uma fonte, usa uma padrão se falhar
94
  try:
95
- # Em muitos sistemas Linux/Docker, esta fonte estará disponível
96
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
97
  if not os.path.exists(font_path):
98
- # Fallback para um caminho comum em contêineres Debian
99
  font_path = "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf"
100
  font = ImageFont.truetype(font_path, size=32)
101
  except IOError:
102
  print(" -> Callback WARNING: Fonte não encontrada. Usando fonte padrão.")
103
  font = ImageFont.load_default()
104
-
105
- # Cola cada frame na grade e desenha a legenda
106
  for i, frame in enumerate(self.intermediate_frames):
107
- x = (i % cols) * frame_w
108
- y = (i // cols) * frame_h
109
  grid_image.paste(frame, (x, y))
110
-
111
  text = f"Passo {i+1}"
112
  text_origin = (x + 10, y + 10)
113
-
114
  try:
115
  text_bbox = draw.textbbox(text_origin, text, font=font)
116
- except AttributeError: # Fallback para Pillow < 9.2.0
117
  text_w, text_h = draw.textsize(text, font=font)
118
  text_bbox = (text_origin[0], text_origin[1], text_origin[0] + text_w, text_origin[1] + text_h)
119
-
120
  rect_coords = (text_bbox[0] - 5, text_bbox[1] - 5, text_bbox[2] + 5, text_bbox[3] + 5)
121
- draw.rectangle(rect_coords, fill=(0, 0, 0, 180)) # Fundo preto semi-transparente
122
  draw.text(text_origin, text, font=font, fill=(255, 255, 255))
123
-
124
  print(" -> Callback: Grade de comparação criada com sucesso.")
125
  return grid_image
 
16
  def __init__(self, pipe):
17
  self.pipe = pipe
18
  self.intermediate_frames = []
19
+ # Mantém os tensores na CPU por padrão durante a inicialização.
20
+ # Eles serão movidos para o dispositivo correto no momento do uso.
21
+ self.latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, pipe.vae.config.z_dim, 1, 1, 1)
22
+ self.latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1)
23
 
24
+ ### INÍCIO DA SEÇÃO CORRIGIDA ###
25
  def decode_latents_to_pil(self, latents: torch.Tensor) -> Image.Image:
26
  """Decodifica um tensor de latents para uma única imagem PIL."""
 
 
 
27
 
28
+ # Pega o dispositivo correto do tensor de entrada `latents`
29
+ correct_device = latents.device
30
 
31
+ # Move os tensores de média e desvio padrão para o mesmo dispositivo dos latents
32
+ # antes de realizar a operação. Isso evita o erro de "device meta".
33
+ latents_unscaled = latents / self.latents_std.to(correct_device) + self.latents_mean.to(correct_device)
34
+
35
+ latents_unscaled = latents_unscaled.to(self.pipe.vae.dtype)
36
+ decoded_video_tensor = self.pipe.vae.decode(latents_unscaled, return_dict=False)[0]
37
 
38
+ frame_tensor = decoded_video_tensor[0, :, 0, :, :]
39
+ frame_tensor = (frame_tensor / 2 + 0.5).clamp(0, 1)
40
  frame_np = frame_tensor.cpu().permute(1, 2, 0).float().numpy()
41
  pil_image = Image.fromarray((frame_np * 255).astype(np.uint8))
42
  return pil_image
43
+ ### FIM DA SEÇÃO CORRIGIDA ###
44
 
45
  def __call__(self, pipe, step: int, timestep: int, callback_kwargs: dict):
46
  """
47
  Esta função é chamada pela pipeline da diffusers em cada passo de denoising.
 
48
  """
49
  print(f" -> Callback: Capturando frame do passo de denoising {step+1}...")
 
 
50
  latents = callback_kwargs["latents"]
 
51
  pil_frame = self.decode_latents_to_pil(latents)
52
  self.intermediate_frames.append(pil_frame)
 
 
53
  return callback_kwargs
54
 
55
  def save_as_video(self, output_path: str, fps: int = 5):
 
57
  if not self.intermediate_frames:
58
  print(" -> Callback: Nenhum frame intermediário para salvar como vídeo.")
59
  return
 
60
  print(f" -> Callback: Codificando {len(self.intermediate_frames)} frames em vídeo em '{output_path}'...")
 
61
  writer = imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8, pixelformat='yuv420p')
 
62
  for frame in self.intermediate_frames:
63
  writer.append_data(np.array(frame))
 
64
  writer.close()
65
  print(" -> Callback: Vídeo de depuração salvo com sucesso.")
66
 
 
71
  if not self.intermediate_frames:
72
  print(" -> Callback: Nenhum frame intermediário para criar a grade.")
73
  return None
 
74
  print(f" -> Callback: Criando grade de comparação com {len(self.intermediate_frames)} etapas...")
 
 
75
  num_images = len(self.intermediate_frames)
76
  cols = math.ceil(math.sqrt(num_images))
77
  rows = math.ceil(num_images / cols)
 
78
  frame_w, frame_h = self.intermediate_frames[0].size
79
+ grid_w, grid_h = frame_w * cols, frame_h * rows
 
 
80
  grid_image = Image.new('RGB', (grid_w, grid_h), (20, 20, 20))
81
  draw = ImageDraw.Draw(grid_image)
 
 
82
  try:
 
83
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
84
  if not os.path.exists(font_path):
 
85
  font_path = "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf"
86
  font = ImageFont.truetype(font_path, size=32)
87
  except IOError:
88
  print(" -> Callback WARNING: Fonte não encontrada. Usando fonte padrão.")
89
  font = ImageFont.load_default()
 
 
90
  for i, frame in enumerate(self.intermediate_frames):
91
+ x, y = (i % cols) * frame_w, (i // cols) * frame_h
 
92
  grid_image.paste(frame, (x, y))
 
93
  text = f"Passo {i+1}"
94
  text_origin = (x + 10, y + 10)
 
95
  try:
96
  text_bbox = draw.textbbox(text_origin, text, font=font)
97
+ except AttributeError:
98
  text_w, text_h = draw.textsize(text, font=font)
99
  text_bbox = (text_origin[0], text_origin[1], text_origin[0] + text_w, text_origin[1] + text_h)
 
100
  rect_coords = (text_bbox[0] - 5, text_bbox[1] - 5, text_bbox[2] + 5, text_bbox[3] + 5)
101
+ draw.rectangle(rect_coords, fill=(0, 0, 0, 180))
102
  draw.text(text_origin, text, font=font, fill=(255, 255, 255))
 
103
  print(" -> Callback: Grade de comparação criada com sucesso.")
104
  return grid_image