X-HighVoltage-X committed on
Commit
c22d93c
·
verified ·
1 Parent(s): 1faa9e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -98,41 +98,49 @@ def inpaint(
98
  if preserve_unmasked:
99
  np_img = np.array(image).astype(np.float32) / 255.0
100
  img_t = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).to(pipe.device)
 
101
  img_t = (img_t * 2 - 1).to(dtype=pipe.vae.dtype)
102
 
103
  np_mask = np.array(mask).astype(np.float32) / 255.0
104
  mask_t = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0).to(pipe.device)
 
105
 
106
  with torch.no_grad():
107
  latents_orig = pipe.vae.encode(img_t).latent_dist.sample()
108
- scaling = getattr(pipe.vae.config, "scaling_factor", getattr(pipe, "vae_scale_factor", None))
109
- if scaling is not None:
110
- latents_orig = latents_orig * scaling
111
 
112
  # Ajustar máscara al tamaño de los latentes
113
- mask_t = torch.nn.functional.interpolate(
114
- mask_t, size=latents_orig.shape[-2:], mode="nearest"
115
- )
116
 
117
  def callback_on_step_end(pipe_self, i, t, callback_kwargs):
118
  latents = callback_kwargs.get("latents", None)
119
  if latents is not None:
 
 
 
 
 
120
  # Ajustar dinámicamente los tamaños al del tensor actual
121
- if mask_t.shape[-2:] != latents.shape[-2:]:
122
- resized_mask = torch.nn.functional.interpolate(mask_t, size=latents.shape[-2:], mode="nearest")
 
 
 
123
  else:
124
  resized_mask = mask_t
125
 
126
- if latents_orig.shape[-2:] != latents.shape[-2:]:
127
- resized_latents_orig = torch.nn.functional.interpolate(latents_orig, size=latents.shape[-2:], mode="nearest")
128
  else:
129
  resized_latents_orig = latents_orig
130
 
 
131
  latents = latents * resized_mask + resized_latents_orig * (1 - resized_mask)
132
  callback_kwargs["latents"] = latents
133
 
134
- # 🔍 Solo para depuración
135
- print(f"[Callback] step={i}, keys={list(callback_kwargs.keys())}, latents={latents.shape if latents is not None else None}")
136
  return callback_kwargs
137
 
138
  callback_on_step_end_tensor_inputs = ["latents"]
 
98
  if preserve_unmasked:
99
  np_img = np.array(image).astype(np.float32) / 255.0
100
  img_t = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).to(pipe.device)
101
+ img_t = F.interpolate(img_t, size=(height, width), mode='bilinear', align_corners=False)
102
  img_t = (img_t * 2 - 1).to(dtype=pipe.vae.dtype)
103
 
104
  np_mask = np.array(mask).astype(np.float32) / 255.0
105
  mask_t = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0).to(pipe.device)
106
+ mask_t = F.interpolate(mask_t, size=(height, width), mode='nearest')
107
 
108
  with torch.no_grad():
109
  latents_orig = pipe.vae.encode(img_t).latent_dist.sample()
110
+ scaling = getattr(pipe.vae.config, "scaling_factor", getattr(pipe, "vae_scale_factor", 0.13025))
111
+ latents_orig = latents_orig * scaling
 
112
 
113
  # Ajustar máscara al tamaño de los latentes
114
+ latent_height = latents_orig.shape[2]
115
+ latent_width = latents_orig.shape[3]
116
+ mask_t = F.interpolate(mask_t, size=(latent_height, latent_width), mode="nearest")
117
 
118
  def callback_on_step_end(pipe_self, i, t, callback_kwargs):
119
  latents = callback_kwargs.get("latents", None)
120
  if latents is not None:
121
+ # Verificar que tengamos 4 dimensiones [batch, channels, height, width]
122
+ if latents.dim() != 4:
123
+ print(f"⚠️ Warning: latents has {latents.dim()} dimensions, expected 4")
124
+ return callback_kwargs
125
+
126
  # Ajustar dinámicamente los tamaños al del tensor actual
127
+ current_height = latents.shape[2]
128
+ current_width = latents.shape[3]
129
+
130
+ if mask_t.shape[-2:] != (current_height, current_width):
131
+ resized_mask = F.interpolate(mask_t, size=(current_height, current_width), mode="nearest")
132
  else:
133
  resized_mask = mask_t
134
 
135
+ if latents_orig.shape[-2:] != (current_height, current_width):
136
+ resized_latents_orig = F.interpolate(latents_orig, size=(current_height, current_width), mode="nearest")
137
  else:
138
  resized_latents_orig = latents_orig
139
 
140
+ # Mezclar solo en las áreas no enmascaradas
141
  latents = latents * resized_mask + resized_latents_orig * (1 - resized_mask)
142
  callback_kwargs["latents"] = latents
143
 
 
 
144
  return callback_kwargs
145
 
146
  callback_on_step_end_tensor_inputs = ["latents"]