vict0rsch commited on
Commit
cd31093
1 Parent(s): 9a7001e

add stable diffusion painter for gradio app

Browse files
Files changed (5) hide show
  1. app.py +1 -1
  2. climategan/trainer.py +22 -17
  3. climategan/utils.py +2 -2
  4. climategan_wrapper.py +476 -0
  5. inferences.py +0 -108
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  import googlemaps
7
  from skimage import io
8
  from urllib import parse
9
- from inferences import ClimateGAN
10
 
11
 
12
  def predict(api_key):
 
6
  import googlemaps
7
  from skimage import io
8
  from urllib import parse
9
+ from climategan_wrapper import ClimateGAN
10
 
11
 
12
  def predict(api_key):
climategan/trainer.py CHANGED
@@ -223,7 +223,7 @@ class Trainer:
223
  bin_value=-1,
224
  half=False,
225
  xla=False,
226
- cloudy=False,
227
  auto_resize_640=False,
228
  ignore_event=set(),
229
  return_masks=False,
@@ -308,24 +308,29 @@ class Trainer:
308
  if xla:
309
  xm.mark_step()
310
 
 
 
311
  if numpy:
312
  with Timer(store=stores.get("numpy", [])):
313
- # normalize to 0-1
314
- flood = normalize(flood).cpu()
315
- smog = normalize(smog).cpu()
316
- wildfire = normalize(wildfire).cpu()
317
-
318
- # convert to numpy
319
- flood = flood.permute(0, 2, 3, 1).numpy()
320
- smog = smog.permute(0, 2, 3, 1).numpy()
321
- wildfire = wildfire.permute(0, 2, 3, 1).numpy()
322
-
323
- # convert to 0-255 uint8
324
- flood = (flood * 255).astype(np.uint8)
325
- smog = (smog * 255).astype(np.uint8)
326
- wildfire = (wildfire * 255).astype(np.uint8)
327
-
328
- output_data = {"flood": flood, "wildfire": wildfire, "smog": smog}
 
 
 
329
  if return_masks:
330
  output_data["mask"] = (
331
  ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
 
223
  bin_value=-1,
224
  half=False,
225
  xla=False,
226
+ cloudy=True,
227
  auto_resize_640=False,
228
  ignore_event=set(),
229
  return_masks=False,
 
308
  if xla:
309
  xm.mark_step()
310
 
311
+ output_data = {}
312
+
313
  if numpy:
314
  with Timer(store=stores.get("numpy", [])):
315
+ if "flood" not in ignore_event:
316
+ # normalize to 0-1
317
+ flood = normalize(flood).cpu()
318
+ # convert to numpy
319
+ flood = flood.permute(0, 2, 3, 1).numpy()
320
+ # convert to 0-255 uint8
321
+ flood = (flood * 255).astype(np.uint8)
322
+ output_data["flood"] = flood
323
+ if "wildfire" not in ignore_event:
324
+ wildfire = normalize(wildfire).cpu()
325
+ wildfire = wildfire.permute(0, 2, 3, 1).numpy()
326
+ wildfire = (wildfire * 255).astype(np.uint8)
327
+ output_data["wildfire"] = wildfire
328
+ if "smog" not in ignore_event:
329
+ smog = normalize(smog).cpu()
330
+ smog = smog.permute(0, 2, 3, 1).numpy()
331
+ smog = (smog * 255).astype(np.uint8)
332
+ output_data["smog"] = smog
333
+
334
  if return_masks:
335
  output_data["mask"] = (
336
  ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
climategan/utils.py CHANGED
@@ -922,9 +922,9 @@ class Timer:
922
  self.store = store
923
  self.precision = precision
924
  self.ignore = ignore
925
- self.cuda = cuda or torch.cuda.is_available()
926
 
927
- if cuda:
928
  self._start_event = torch.cuda.Event(enable_timing=True)
929
  self._end_event = torch.cuda.Event(enable_timing=True)
930
 
 
922
  self.store = store
923
  self.precision = precision
924
  self.ignore = ignore
925
+ self.cuda = cuda if cuda is not None else torch.cuda.is_available()
926
 
927
+ if self.cuda:
928
  self._start_event = torch.cuda.Event(enable_timing=True)
929
  self._end_event = torch.cuda.Event(enable_timing=True)
930
 
climategan_wrapper.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
2
+ # thank you @NimaBoscarino
3
+
4
+ import re
5
+ from pathlib import Path
6
+ from uuid import uuid4
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers import StableDiffusionInpaintPipeline
11
+ from PIL import Image
12
+ from skimage.color import rgba2rgb
13
+ from skimage.transform import resize
14
+
15
+ from climategan.trainer import Trainer
16
+
17
+
18
+ def concat_events(output_dict, events, i=None, axis=1):
19
+ """
20
+ Concatenates the `i`th data in `output_dict` according to the keys listed
21
+ in `events` on dimension `axis`.
22
+
23
+ Args:
24
+ output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
25
+ events to their corresponding data :
26
+ {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
27
+ events (list[str]): output_dict's keys to concatenate.
28
+ axis (int, optional): Concatenation axis. Defaults to 1.
29
+ """
30
+ cs = [e for e in events if e in output_dict]
31
+ if i is not None:
32
+ return uint8(np.concatenate([output_dict[c][i] for c in cs], axis=axis))
33
+ return uint8(np.concatenate([output_dict[c] for c in cs], axis=axis))
34
+
35
+
36
+ def clear(folder):
37
+ """
38
+ Deletes all the images without the inference separator "---" in their name.
39
+
40
+ Args:
41
+ folder (Union[str, Path]): The folder to clear.
42
+ """
43
+ for i in list(Path(folder).iterdir()):
44
+ if i.is_file() and "---" in i.stem:
45
+ i.unlink()
46
+
47
+
48
+ def uint8(array, rescale=False):
49
+ """
50
+ convert an array to np.uint8 (does not rescale or anything else than changing dtype)
51
+ Args:
52
+ array (np.array): array to modify
53
+ Returns:
54
+ np.array(np.uint8): converted array
55
+ """
56
+ if rescale:
57
+ if array.min() < 0:
58
+ if array.min() >= -1 and array.max() <= 1:
59
+ array = (array + 1) / 2
60
+ else:
61
+ raise ValueError(
62
+ f"Data range mismatch for image: ({array.min()}, {array.max()})"
63
+ )
64
+ if array.max() <= 1:
65
+ array = array * 255
66
+ return array.astype(np.uint8)
67
+
68
+
69
+ def resize_and_crop(img, to=640):
70
+ """
71
+ Resizes an image so that it keeps the aspect ratio and the smallest dimensions
72
+ is `to`, then crops this resized image in its center so that the output is `to x to`
73
+ without aspect ratio distortion
74
+ Args:
75
+ img (np.array): np.uint8 255 image
76
+ Returns:
77
+ np.array: [0, 1] np.float32 image
78
+ """
79
+ # resize keeping aspect ratio: smallest dim is 640
80
+ h, w = img.shape[:2]
81
+ if h < w:
82
+ size = (to, int(to * w / h))
83
+ else:
84
+ size = (int(to * h / w), to)
85
+
86
+ r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
87
+ r_img = uint8(r_img)
88
+
89
+ # crop in the center
90
+ H, W = r_img.shape[:2]
91
+
92
+ top = (H - to) // 2
93
+ left = (W - to) // 2
94
+
95
+ rc_img = r_img[top : top + to, left : left + to, :]
96
+
97
+ return rc_img / 255.0
98
+
99
+
100
+ def to_m1_p1(img):
101
+ """
102
+ rescales a [0, 1] image to [-1, +1]
103
+ Args:
104
+ img (np.array): float32 numpy array of an image in [0, 1]
105
+ i (int): Index of the image being rescaled
106
+ Raises:
107
+ ValueError: If the image is not in [0, 1]
108
+ Returns:
109
+ np.array(np.float32): array in [-1, +1]
110
+ """
111
+ if img.min() >= 0 and img.max() <= 1:
112
+ return (img.astype(np.float32) - 0.5) * 2
113
+ raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
114
+
115
+
116
+ # No need to do any timing in this, since it's just for the HF Space
117
+ class ClimateGAN:
118
+ def __init__(self, model_path) -> None:
119
+ """
120
+ A wrapper for the ClimateGAN model that you can use to generate
121
+ events from images or folders containing images.
122
+
123
+ Args:
124
+ model_path (Union[str, Path]): Where to load the Masker from
125
+ """
126
+ torch.set_grad_enabled(False)
127
+ self.target_size = 640
128
+ self.trainer = Trainer.resume_from_path(
129
+ model_path,
130
+ setup=True,
131
+ inference=True,
132
+ new_exp=None,
133
+ )
134
+ self.trainer.G.half()
135
+ self._stable_diffusion_is_setup = False
136
+
137
+ def _setup_stable_diffusion(self):
138
+ """
139
+ Sets up the stable diffusion pipeline for in-painting.
140
+ Make sure you have accepted the license on the model's card
141
+ https://huggingface.co/CompVis/stable-diffusion-v1-4
142
+ """
143
+ try:
144
+ self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
145
+ "runwayml/stable-diffusion-inpainting",
146
+ revision="fp16",
147
+ torch_dtype=torch.float16,
148
+ safety_checker=None,
149
+ ).to(self.trainer.device)
150
+ self._stable_diffusion_is_setup = True
151
+ except Exception as e:
152
+ print(
153
+ "\nCould not load stable diffusion model. "
154
+ + "Please make sure you have accepted the license on the model's"
155
+ + " card https://huggingface.co/CompVis/stable-diffusion-v1-4\n"
156
+ )
157
+ raise e
158
+
159
+ def _preprocess_image(self, img):
160
+ # rgba to rgb
161
+ data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
162
+
163
+ # to args.target_size
164
+ data = resize_and_crop(data, self.target_size)
165
+
166
+ # resize() produces [0, 1] images, rescale to [-1, 1]
167
+ data = to_m1_p1(data)
168
+ return data
169
+
170
+ # Does all three inferences at the moment.
171
+ def infer_single(
172
+ self,
173
+ orig_image,
174
+ painter="both",
175
+ prompt="An HD picture of a street with dirty water after a heavy flood",
176
+ concats=[
177
+ "input",
178
+ "masked_input",
179
+ "climategan_flood",
180
+ "stable_flood",
181
+ "stable_copy_flood",
182
+ ],
183
+ ):
184
+ """
185
+ Infers the image with the ClimateGAN model.
186
+ Importantly (and unlike self.infer_preprocessed_batch), the image is
187
+ pre-processed by self._preprocess_image before going through the networks.
188
+
189
+ Output dict contains the following keys:
190
+ - "input": The input image
191
+ - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
192
+ - "masked_input": The input image with the mask applied
193
+ - "climategan_flood": The flooded image generated by ClimateGAN's Painter
194
+ on the masked input (only if "painter" is "climategan" or "both").
195
+ - "stable_flood": The flooded image in-painted by the stable diffusion model
196
+ from the mask and the input image (only if "painter" is "stable_diffusion"
197
+ or "both").
198
+ - "stable_copy_flood": The flooded image in-painted by the stable diffusion
199
+ model with its original context pasted back in:
200
+ y = m * flooded + (1-m) * input
201
+ (only if "painter" is "stable_diffusion" or "both").
202
+
203
+ Args:
204
+ orig_image (Union[str, np.array]): image to infer on. Can be a path to
205
+ an image which will be read.
206
+ painter (str, optional): Which painter to use: "climategan",
207
+ "stable_diffusion" or "both". Defaults to "both".
208
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
209
+ to "An HD picture of a street with dirty water after a heavy flood".
210
+ concats (list, optional): List of keys in `output` to concatenate together
211
+ in a new `{original_stem}_concat` image written. Defaults to:
212
+ ["input", "masked_input", "climategan_flood", "stable_flood",
213
+ "stable_copy_flood"].
214
+
215
+ Returns:
216
+ dict: a dictionary containing the output images {k: HxWxC}. C is omitted
217
+ for masks (HxW).
218
+ """
219
+ image_array = (
220
+ np.array(Image.open(orig_image))
221
+ if isinstance(orig_image, str)
222
+ else orig_image
223
+ )
224
+ image = self._preprocess_image(image_array)
225
+ output_dict = self.infer_preprocessed_batch(
226
+ image[None, ...], painter, prompt, concats
227
+ )
228
+ return {k: v[0] for k, v in output_dict.items()}
229
+
230
+ def infer_preprocessed_batch(
231
+ self,
232
+ images,
233
+ painter="both",
234
+ prompt="An HD picture of a street with dirty water after a heavy flood",
235
+ concats=[
236
+ "input",
237
+ "masked_input",
238
+ "climategan_flood",
239
+ "stable_flood",
240
+ "stable_copy_flood",
241
+ ],
242
+ ):
243
+ """
244
+ Infers ClimateGAN predictions on a batch of preprocessed images.
245
+ It assumes that each image in the batch has been preprocessed with
246
+ self._preprocess_image().
247
+
248
+ Output dict contains the following keys:
249
+ - "input": The input image
250
+ - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
251
+ - "masked_input": The input image with the mask applied
252
+ - "climategan_flood": The flooded image generated by ClimateGAN's Painter
253
+ on the masked input (only if "painter" is "climategan" or "both").
254
+ - "stable_flood": The flooded image in-painted by the stable diffusion model
255
+ from the mask and the input image (only if "painter" is "stable_diffusion"
256
+ or "both").
257
+ - "stable_copy_flood": The flooded image in-painted by the stable diffusion
258
+ model with its original context pasted back in:
259
+ y = m * flooded + (1-m) * input
260
+ (only if "painter" is "stable_diffusion" or "both").
261
+
262
+ Args:
263
+ images (np.array): A batch of input images BxHxWx3
264
+ painter (str, optional): Which painter to use: "climategan",
265
+ "stable_diffusion" or "both". Defaults to "both".
266
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
267
+ to "An HD picture of a street with dirty water after a heavy flood".
268
+ concats (list, optional): List of keys in `output` to concatenate together
269
+ in a new `{original_stem}_concat` image written. Defaults to:
270
+ ["input", "masked_input", "climategan_flood", "stable_flood",
271
+ "stable_copy_flood"].
272
+
273
+ Returns:
274
+ dict: a dictionary containing the output images
275
+ """
276
+ assert painter in [
277
+ "both",
278
+ "stable_diffusion",
279
+ "climategan",
280
+ ], f"Unknown painter: {painter}"
281
+
282
+ ignore_event = set()
283
+ if painter == "climategan":
284
+ ignore_event.add("flood")
285
+
286
+ # Retrieve numpy events as a dict {event: array[BxHxWxC]}
287
+ outputs = self.trainer.infer_all(
288
+ images,
289
+ numpy=True,
290
+ bin_value=0.5,
291
+ half=True,
292
+ ignore_event=ignore_event,
293
+ return_masks=True,
294
+ )
295
+
296
+ outputs["input"] = uint8(images, True)
297
+ # from Bx1xHxW to BxHxWx1
298
+ outputs["masked_input"] = outputs["input"] * (
299
+ outputs["mask"].squeeze(1)[..., None] == 0
300
+ )
301
+
302
+ if painter in {"both", "climategan"}:
303
+ outputs["climategan_flood"] = outputs.pop("flood")
304
+ else:
305
+ del outputs["flood"]
306
+
307
+ if painter != "climategan":
308
+ if not self._stable_diffusion_is_setup:
309
+ print("Setting up stable diffusion in-painting pipeline")
310
+ self._setup_stable_diffusion()
311
+
312
+ mask = outputs["mask"].squeeze(1)
313
+ input_images = (
314
+ torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
315
+ )
316
+ input_mask = torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
317
+ floods = self.sdip_pipeline(
318
+ prompt=[prompt] * images.shape[0],
319
+ image=input_images,
320
+ mask_image=input_mask,
321
+ height=640,
322
+ width=640,
323
+ num_inference_steps=50,
324
+ )
325
+
326
+ bin_mask = mask[..., None] > 0
327
+ flood = np.stack([np.array(i) for i in floods.images])
328
+ copy_flood = flood * bin_mask + uint8(images, True) * (1 - bin_mask)
329
+ outputs["stable_flood"] = flood
330
+ outputs["stable_copy_flood"] = copy_flood
331
+
332
+ if concats:
333
+ outputs["concat"] = concat_events(outputs, concats, axis=2)
334
+
335
+ return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}
336
+
337
+ def infer_folder(
338
+ self,
339
+ folder_path,
340
+ painter="both",
341
+ prompt="An HD picture of a street with dirty water after a heavy flood",
342
+ batch_size=4,
343
+ concats=[
344
+ "input",
345
+ "masked_input",
346
+ "climategan_flood",
347
+ "stable_flood",
348
+ "stable_copy_flood",
349
+ ],
350
+ write=True,
351
+ overwrite=False,
352
+ ):
353
+ """
354
+ Infers the images in a folder with the ClimateGAN model, batching images for
355
+ inference according to the batch_size.
356
+
357
+ Images must end in .jpg, .jpeg or .png (not case-sensitive).
358
+ Images must not contain the separator ("---") in their name.
359
+
360
+ Images will be written to disk in the same folder as the input images, with
361
+ a name that depends on its data, potentially the prompt and a random
362
+ identifier in case multiple inferences are run in the folder.
363
+
364
+ Output dict contains the following keys:
365
+ - "input": The input image
366
+ - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
367
+ - "masked_input": The input image with the mask applied
368
+ - "climategan_flood": The flooded image generated by ClimateGAN's Painter
369
+ on the masked input (only if "painter" is "climategan" or "both").
370
+ - "stable_flood": The flooded image in-painted by the stable diffusion model
371
+ from the mask and the input image (only if "painter" is "stable_diffusion"
372
+ or "both").
373
+ - "stable_copy_flood": The flooded image in-painted by the stable diffusion
374
+ model with its original context pasted back in:
375
+ y = m * flooded + (1-m) * input
376
+ (only if "painter" is "stable_diffusion" or "both").
377
+
378
+ Args:
379
+ folder_path (Union[str, Path]): Where to read images from.
380
+ painter (str, optional): Which painter to use: "climategan",
381
+ "stable_diffusion" or "both". Defaults to "both".
382
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
383
+ to "An HD picture of a street with dirty water after a heavy flood".
384
+ batch_size (int, optional): Size of inference batches. Defaults to 4.
385
+ concats (list, optional): List of keys in `output` to concatenate together
386
+ in a new `{original_stem}_concat` image written. Defaults to:
387
+ ["input", "masked_input", "climategan_flood", "stable_flood",
388
+ "stable_copy_flood"].
389
+ write (bool, optional): Whether or not to write the outputs to the input
390
+ folder.Defaults to True.
391
+ overwrite (Union[bool, str], optional): Whether to overwrite the images or
392
+ not. If a string is provided, it will be included in the name.
393
+ Defaults to False.
394
+
395
+ Returns:
396
+ dict: a dictionary containing the output images
397
+ """
398
+ folder_path = Path(folder_path).expanduser().resolve()
399
+ assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
400
+ assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"
401
+ im_paths = [
402
+ p
403
+ for p in folder_path.iterdir()
404
+ if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
405
+ ]
406
+ assert im_paths, f"No images found in {str(folder_path)}"
407
+ ims = [self._preprocess_image(np.array(Image.open(p))) for p in im_paths]
408
+ batches = [
409
+ np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
410
+ ]
411
+ inferences = [
412
+ self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
413
+ ]
414
+
415
+ outputs = {
416
+ k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
417
+ }
418
+
419
+ if write:
420
+ self.write(outputs, im_paths, painter, overwrite, prompt)
421
+
422
+ return outputs
423
+
424
+ def write(
425
+ self,
426
+ outputs,
427
+ im_paths,
428
+ painter="both",
429
+ overwrite=False,
430
+ prompt="",
431
+ ):
432
+ """
433
+ Writes the outputs of the inference to disk, in the input folder.
434
+
435
+ Images will be named like:
436
+ f"{original_stem}---{overwrite_prefix}_{painter_type}_{output_type}.{suffix}"
437
+ `painter_type` is either "climategan" or f"stable_diffusion_{prompt}"
438
+
439
+ Args:
440
+ outputs (_type_): The inference procedure's output dict.
441
+ im_paths (list[Path]): The list of input images paths.
442
+ painter (str, optional): Which painter was used. Defaults to "both".
443
+ overwrite (bool, optional): Whether to overwrite the images or not.
444
+ If a string is provided, it will be included in the name.
445
+ If False, a random identifier will be added to the name.
446
+ Defaults to False.
447
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
448
+ to "".
449
+ """
450
+ prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()
451
+ overwrite_prefix = ""
452
+ if not overwrite:
453
+ overwrite_prefix = str(uuid4())[:8]
454
+ print("Writing events with prefix", overwrite_prefix)
455
+ else:
456
+ if isinstance(overwrite, str):
457
+ overwrite_prefix = overwrite
458
+ print("Writing events with prefix", overwrite_prefix)
459
+
460
+ # for each image, for each event/data type
461
+ for i, im_path in enumerate(im_paths):
462
+ for event, ims in outputs.items():
463
+ painter_prefix = ""
464
+ if painter == "climategan" and event == "flood":
465
+ painter_prefix = "climategan"
466
+ elif (
467
+ painter in {"stable_diffusion", "both"} and event == "stable_flood"
468
+ ):
469
+ painter_prefix = f"_stable_{prompt}"
470
+ elif painter == "both" and event == "climategan_flood":
471
+ painter_prefix = ""
472
+
473
+ im = ims[i]
474
+ im = Image.fromarray(uint8(im))
475
+ imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
476
+ im.save(im_path.parent / (imstem + im_path.suffix))
inferences.py DELETED
@@ -1,108 +0,0 @@
1
- # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
2
- # thank you @NimaBoscarino
3
-
4
- import torch
5
- from skimage.color import rgba2rgb
6
- from skimage.transform import resize
7
- import numpy as np
8
-
9
- from climategan.trainer import Trainer
10
-
11
-
12
- def uint8(array):
13
- """
14
- convert an array to np.uint8 (does not rescale or anything else than changing dtype)
15
- Args:
16
- array (np.array): array to modify
17
- Returns:
18
- np.array(np.uint8): converted array
19
- """
20
- return array.astype(np.uint8)
21
-
22
-
23
- def resize_and_crop(img, to=640):
24
- """
25
- Resizes an image so that it keeps the aspect ratio and the smallest dimensions
26
- is `to`, then crops this resized image in its center so that the output is `to x to`
27
- without aspect ratio distortion
28
- Args:
29
- img (np.array): np.uint8 255 image
30
- Returns:
31
- np.array: [0, 1] np.float32 image
32
- """
33
- # resize keeping aspect ratio: smallest dim is 640
34
- h, w = img.shape[:2]
35
- if h < w:
36
- size = (to, int(to * w / h))
37
- else:
38
- size = (int(to * h / w), to)
39
-
40
- r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
41
- r_img = uint8(r_img)
42
-
43
- # crop in the center
44
- H, W = r_img.shape[:2]
45
-
46
- top = (H - to) // 2
47
- left = (W - to) // 2
48
-
49
- rc_img = r_img[top : top + to, left : left + to, :]
50
-
51
- return rc_img / 255.0
52
-
53
-
54
- def to_m1_p1(img):
55
- """
56
- rescales a [0, 1] image to [-1, +1]
57
- Args:
58
- img (np.array): float32 numpy array of an image in [0, 1]
59
- i (int): Index of the image being rescaled
60
- Raises:
61
- ValueError: If the image is not in [0, 1]
62
- Returns:
63
- np.array(np.float32): array in [-1, +1]
64
- """
65
- if img.min() >= 0 and img.max() <= 1:
66
- return (img.astype(np.float32) - 0.5) * 2
67
- raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
68
-
69
-
70
- # No need to do any timing in this, since it's just for the HF Space
71
- class ClimateGAN:
72
- def __init__(self, model_path) -> None:
73
- torch.set_grad_enabled(False)
74
- self.target_size = 640
75
- self.trainer = Trainer.resume_from_path(
76
- model_path,
77
- setup=True,
78
- inference=True,
79
- new_exp=None,
80
- )
81
-
82
- # Does all three inferences at the moment.
83
- def inference(self, orig_image):
84
- image = self._preprocess_image(orig_image)
85
-
86
- # Retrieve numpy events as a dict {event: array[BxHxWxC]}
87
- outputs = self.trainer.infer_all(
88
- image,
89
- numpy=True,
90
- bin_value=0.5,
91
- )
92
-
93
- return (
94
- outputs["flood"].squeeze(),
95
- outputs["wildfire"].squeeze(),
96
- outputs["smog"].squeeze(),
97
- )
98
-
99
- def _preprocess_image(self, img):
100
- # rgba to rgb
101
- data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
102
-
103
- # to args.target_size
104
- data = resize_and_crop(data, self.target_size)
105
-
106
- # resize() produces [0, 1] images, rescale to [-1, 1]
107
- data = to_m1_p1(data)
108
- return data