vict0rsch commited on
Commit
490814b
·
1 Parent(s): e2d1ef3

handle single PIL Image

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. climategan_wrapper.py +25 -2
app.py CHANGED
@@ -35,7 +35,7 @@ def predict(cg: ClimateGAN, api_key):
35
  "Both": "both",
36
  }
37
 
38
- output_dict = cg.infer_single(img_np, painters[painter])
39
 
40
  input_image = output_dict["input"]
41
  masked_input = output_dict["masked_input"]
 
35
  "Both": "both",
36
  }
37
 
38
+ output_dict = cg.infer_single(img_np, painters[painter], as_pil_image=True)
39
 
40
  input_image = output_dict["input"]
41
  masked_input = output_dict["masked_input"]
climategan_wrapper.py CHANGED
@@ -192,6 +192,7 @@ class ClimateGAN:
192
  "stable_flood",
193
  "stable_copy_flood",
194
  ],
 
195
  ):
196
  """
197
  Infers the image with the ClimateGAN model.
@@ -245,9 +246,18 @@ class ClimateGAN:
245
  if isinstance(orig_image, str)
246
  else orig_image
247
  )
 
 
 
 
 
248
  image = self._preprocess_image(image_array)
249
  output_dict = self.infer_preprocessed_batch(
250
- image[None, ...], painter, prompt, concats
 
 
 
 
251
  )
252
  return {k: v[0] for k, v in output_dict.items()}
253
 
@@ -263,6 +273,7 @@ class ClimateGAN:
263
  "stable_flood",
264
  "stable_copy_flood",
265
  ],
 
266
  ):
267
  """
268
  Infers ClimateGAN predictions on a batch of preprocessed images.
@@ -293,6 +304,8 @@ class ClimateGAN:
293
  in a new `{original_stem}_concat` image written. Defaults to:
294
  ["input", "masked_input", "climategan_flood", "stable_flood",
295
  "stable_copy_flood"].
 
 
296
 
297
  Returns:
298
  dict: a dictionary containing the output images
@@ -307,6 +320,10 @@ class ClimateGAN:
307
  if painter == "stable_diffusion":
308
  ignore_event.add("flood")
309
 
 
 
 
 
310
  # Retrieve numpy events as a dict {event: array[BxHxWxC]}
311
  outputs = self.trainer.infer_all(
312
  images,
@@ -336,8 +353,14 @@ class ClimateGAN:
336
  mask = outputs["mask"].squeeze(1)
337
  input_images = (
338
  torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
 
 
 
 
 
 
 
339
  )
340
- input_mask = torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
341
  floods = self.sdip_pipeline(
342
  prompt=[prompt] * images.shape[0],
343
  image=input_images,
 
192
  "stable_flood",
193
  "stable_copy_flood",
194
  ],
195
+ as_pil_image=False,
196
  ):
197
  """
198
  Infers the image with the ClimateGAN model.
 
246
  if isinstance(orig_image, str)
247
  else orig_image
248
  )
249
+
250
+ pil_image = None
251
+ if as_pil_image:
252
+ pil_image = Image.fromarray(image_array)
253
+
254
  image = self._preprocess_image(image_array)
255
  output_dict = self.infer_preprocessed_batch(
256
+ images=image[None, ...],
257
+ painter=painter,
258
+ prompt=prompt,
259
+ concats=concats,
260
+ pil_image=pil_image,
261
  )
262
  return {k: v[0] for k, v in output_dict.items()}
263
 
 
273
  "stable_flood",
274
  "stable_copy_flood",
275
  ],
276
+ pil_image=None,
277
  ):
278
  """
279
  Infers ClimateGAN predictions on a batch of preprocessed images.
 
304
  in a new `{original_stem}_concat` image written. Defaults to:
305
  ["input", "masked_input", "climategan_flood", "stable_flood",
306
  "stable_copy_flood"].
307
+ pil_image (PIL.Image, optional): The original PIL image. If provided,
308
+ will be used for a single inference (batch_size=1)
309
 
310
  Returns:
311
  dict: a dictionary containing the output images
 
320
  if painter == "stable_diffusion":
321
  ignore_event.add("flood")
322
 
323
+ if pil_image is not None:
324
+ print("Warning: `pil_image` has been provided, it will override `images`")
325
+ images = self._preprocess_image(np.array(pil_image)[None, ...])
326
+
327
  # Retrieve numpy events as a dict {event: array[BxHxWxC]}
328
  outputs = self.trainer.infer_all(
329
  images,
 
353
  mask = outputs["mask"].squeeze(1)
354
  input_images = (
355
  torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
356
+ if pil_image is None
357
+ else pil_image
358
+ )
359
+ input_mask = (
360
+ torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
361
+ if pil_image is None
362
+ else Image.fromarray(mask[0])
363
  )
 
364
  floods = self.sdip_pipeline(
365
  prompt=[prompt] * images.shape[0],
366
  image=input_images,