handle single PIL Image
- app.py +1 -1
- climategan_wrapper.py +25 -2
app.py
CHANGED
@@ -35,7 +35,7 @@ def predict(cg: ClimateGAN, api_key):
         "Both": "both",
     }
 
-    output_dict = cg.infer_single(img_np, painters[painter])
+    output_dict = cg.infer_single(img_np, painters[painter], as_pil_image=True)
 
     input_image = output_dict["input"]
     masked_input = output_dict["masked_input"]
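
The app-side change just threads the new flag through. A minimal sketch of a call with the flag, assuming a hypothetical helper and input path (only `ClimateGAN`, `infer_single`, the "both" painter value, and the output keys come from this repo):

import numpy as np
from PIL import Image
from climategan_wrapper import ClimateGAN  # module changed in this commit

def flood_one(cg: ClimateGAN, path: str) -> dict:
    # Illustrative helper, not part of the repo: single-image inference with
    # the new flag so the Stable Diffusion painter receives a real PIL image.
    img_np = np.array(Image.open(path).convert("RGB"))
    # Output keys per the wrapper's concats list: "input", "masked_input",
    # "climategan_flood", "stable_flood", "stable_copy_flood".
    return cg.infer_single(img_np, "both", as_pil_image=True)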
climategan_wrapper.py
CHANGED
@@ -192,6 +192,7 @@ class ClimateGAN:
             "stable_flood",
             "stable_copy_flood",
         ],
+        as_pil_image=False,
     ):
         """
         Infers the image with the ClimateGAN model.
@@ -245,9 +246,18 @@ class ClimateGAN:
             if isinstance(orig_image, str)
             else orig_image
         )
+
+        pil_image = None
+        if as_pil_image:
+            pil_image = Image.fromarray(image_array)
+
         image = self._preprocess_image(image_array)
         output_dict = self.infer_preprocessed_batch(
-            image[None, ...],
+            images=image[None, ...],
+            painter=painter,
+            prompt=prompt,
+            concats=concats,
+            pil_image=pil_image,
         )
         return {k: v[0] for k, v in output_dict.items()}
 
@@ -263,6 +273,7 @@ class ClimateGAN:
             "stable_flood",
             "stable_copy_flood",
         ],
+        pil_image=None,
     ):
         """
         Infers ClimateGAN predictions on a batch of preprocessed images.
@@ -293,6 +304,8 @@ class ClimateGAN:
                 in a new `{original_stem}_concat` image written. Defaults to:
                 ["input", "masked_input", "climategan_flood", "stable_flood",
                 "stable_copy_flood"].
+            pil_image (PIL.Image, optional): The original PIL image. If provided,
+                will be used for a single inference (batch_size=1).
 
         Returns:
             dict: a dictionary containing the output images
@@ -307,6 +320,10 @@ class ClimateGAN:
         if painter == "stable_diffusion":
             ignore_event.add("flood")
 
+        if pil_image is not None:
+            print("Warning: `pil_image` has been provided, it will override `images`")
+            images = self._preprocess_image(np.array(pil_image)[None, ...])
+
         # Retrieve numpy events as a dict {event: array[BxHxWxC]}
         outputs = self.trainer.infer_all(
             images,
@@ -336,8 +353,14 @@ class ClimateGAN:
         mask = outputs["mask"].squeeze(1)
         input_images = (
             torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
+            if pil_image is None
+            else pil_image
+        )
+        input_mask = (
+            torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
+            if pil_image is None
+            else Image.fromarray(mask[0])
         )
-        input_mask = torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
         floods = self.sdip_pipeline(
             prompt=[prompt] * images.shape[0],
             image=input_images,