X-HighVoltage-X commited on
Commit
1faa9e5
verified
1 Parent(s): 8a5db41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -117,9 +117,22 @@ def inpaint(
117
  def callback_on_step_end(pipe_self, i, t, callback_kwargs):
118
  latents = callback_kwargs.get("latents", None)
119
  if latents is not None:
120
- latents = latents * mask_t + latents_orig * (1 - mask_t)
 
 
 
 
 
 
 
 
 
 
 
121
  callback_kwargs["latents"] = latents
122
- print("callback_kwargsssssssssssssss", callback_kwargs.keys())
 
 
123
  return callback_kwargs
124
 
125
  callback_on_step_end_tensor_inputs = ["latents"]
 
117
  def callback_on_step_end(pipe_self, i, t, callback_kwargs):
118
  latents = callback_kwargs.get("latents", None)
119
  if latents is not None:
120
+ # Ajustar din谩micamente los tama帽os al del tensor actual
121
+ if mask_t.shape[-2:] != latents.shape[-2:]:
122
+ resized_mask = torch.nn.functional.interpolate(mask_t, size=latents.shape[-2:], mode="nearest")
123
+ else:
124
+ resized_mask = mask_t
125
+
126
+ if latents_orig.shape[-2:] != latents.shape[-2:]:
127
+ resized_latents_orig = torch.nn.functional.interpolate(latents_orig, size=latents.shape[-2:], mode="nearest")
128
+ else:
129
+ resized_latents_orig = latents_orig
130
+
131
+ latents = latents * resized_mask + resized_latents_orig * (1 - resized_mask)
132
  callback_kwargs["latents"] = latents
133
+
134
+ # 馃攳 Solo para depuraci贸n
135
+ print(f"[Callback] step={i}, keys={list(callback_kwargs.keys())}, latents={latents.shape if latents is not None else None}")
136
  return callback_kwargs
137
 
138
  callback_on_step_end_tensor_inputs = ["latents"]