Update app.py

app.py CHANGED
@@ -98,41 +98,49 @@ def inpaint(
     if preserve_unmasked:
         np_img = np.array(image).astype(np.float32) / 255.0
         img_t = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).to(pipe.device)
+        img_t = F.interpolate(img_t, size=(height, width), mode='bilinear', align_corners=False)
         img_t = (img_t * 2 - 1).to(dtype=pipe.vae.dtype)
 
         np_mask = np.array(mask).astype(np.float32) / 255.0
         mask_t = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0).to(pipe.device)
+        mask_t = F.interpolate(mask_t, size=(height, width), mode='nearest')
 
         with torch.no_grad():
             latents_orig = pipe.vae.encode(img_t).latent_dist.sample()
-            scaling = getattr(pipe.vae.config, "scaling_factor", getattr(pipe, "vae_scale_factor",
-
-            latents_orig = latents_orig * scaling
+            scaling = getattr(pipe.vae.config, "scaling_factor", getattr(pipe, "vae_scale_factor", 0.13025))
+            latents_orig = latents_orig * scaling
 
         # Resize the mask to the size of the latents
-
-
-        )
+        latent_height = latents_orig.shape[2]
+        latent_width = latents_orig.shape[3]
+        mask_t = F.interpolate(mask_t, size=(latent_height, latent_width), mode="nearest")
 
         def callback_on_step_end(pipe_self, i, t, callback_kwargs):
             latents = callback_kwargs.get("latents", None)
             if latents is not None:
+                # Check that we have 4 dimensions [batch, channels, height, width]
+                if latents.dim() != 4:
+                    print(f"⚠️ Warning: latents has {latents.dim()} dimensions, expected 4")
+                    return callback_kwargs
+
                 # Dynamically adjust the sizes to those of the current tensor
-
-
+                current_height = latents.shape[2]
+                current_width = latents.shape[3]
+
+                if mask_t.shape[-2:] != (current_height, current_width):
+                    resized_mask = F.interpolate(mask_t, size=(current_height, current_width), mode="nearest")
                 else:
                     resized_mask = mask_t
 
-                if latents_orig.shape[-2:] !=
-                    resized_latents_orig =
+                if latents_orig.shape[-2:] != (current_height, current_width):
+                    resized_latents_orig = F.interpolate(latents_orig, size=(current_height, current_width), mode="nearest")
                 else:
                     resized_latents_orig = latents_orig
 
+                # Blend only in the unmasked areas
                 latents = latents * resized_mask + resized_latents_orig * (1 - resized_mask)
                 callback_kwargs["latents"] = latents
 
-                # 🔍 Debug only
-                print(f"[Callback] step={i}, keys={list(callback_kwargs.keys())}, latents={latents.shape if latents is not None else None}")
             return callback_kwargs
 
         callback_on_step_end_tensor_inputs = ["latents"]
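The callback and the callback_on_step_end_tensor_inputs list defined at the end of the hunk are presumably forwarded to the diffusers pipeline call, which lies outside this hunk. Below is a minimal sketch of that wiring, assuming it runs inside inpaint() where pipe, image, mask, height, and width already exist; the prompt, num_inference_steps, and guidance_scale arguments are illustrative assumptions, not values taken from app.py.

# Hypothetical call site (not shown in the diff). Because "latents" is listed
# in callback_on_step_end_tensor_inputs, diffusers passes the current latents
# into callback_on_step_end after every denoising step and replaces them with
# whatever the returned callback_kwargs contain. The callback would only be
# passed when preserve_unmasked is set, since it is defined inside that branch.
result = pipe(
    prompt=prompt,                             # assumed argument of inpaint()
    image=image,
    mask_image=mask,
    height=height,
    width=width,
    num_inference_steps=num_inference_steps,   # assumed argument of inpaint()
    guidance_scale=guidance_scale,             # assumed argument of inpaint()
    callback_on_step_end=callback_on_step_end,
    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
output_image = result.images[0]

With this in place, every step re-imposes latents * resized_mask + resized_latents_orig * (1 - resized_mask): where the mask is 0 the latents are reset to the VAE encoding of the original image, and only the white (masked) region is actually regenerated. The 0.13025 fallback for the VAE scaling factor matches the SDXL VAE, which suggests the Space targets an SDXL inpainting checkpoint, though the model id is not visible in this hunk.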