vict0rsch commited on
Commit
be66585
1 Parent(s): 0d2cc0b

fix single pil image handling

Browse files
Files changed (1) hide show
  1. climategan_wrapper.py +11 -1
climategan_wrapper.py CHANGED
@@ -169,6 +169,16 @@ class ClimateGAN:
169
  raise e
170
 
171
  def _preprocess_image(self, img):
 
 
 
 
 
 
 
 
 
 
172
  # rgba to rgb
173
  data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
174
 
@@ -322,7 +332,7 @@ class ClimateGAN:
322
 
323
  if pil_image is not None:
324
  print("Warning: `pil_image` has been provided, it will override `images`")
325
- images = self._preprocess_image(np.array(pil_image)[None, ...])
326
 
327
  # Retrieve numpy events as a dict {event: array[BxHxWxC]}
328
  outputs = self.trainer.infer_all(
 
169
  raise e
170
 
171
  def _preprocess_image(self, img):
172
+ """
173
+ Turns a HxWxC uint8 numpy array into a 640x640x3 float32 numpy array
174
+ in [-1, 1].
175
+
176
+ Args:
177
+ img (np.array): Image to resize crop and rescale
178
+
179
+ Returns:
180
+ np.array: Resized, cropped and rescaled image
181
+ """
182
  # rgba to rgb
183
  data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
184
 
 
332
 
333
  if pil_image is not None:
334
  print("Warning: `pil_image` has been provided, it will override `images`")
335
+ images = self._preprocess_image(np.array(pil_image))[None, ...]
336
 
337
  # Retrieve numpy events as a dict {event: array[BxHxWxC]}
338
  outputs = self.trainer.infer_all(