vict0rsch commited on
Commit
5e61b3b
1 Parent(s): 7e7950a

add inference script features

Browse files
Files changed (1) hide show
  1. climategan_wrapper.py +82 -1
climategan_wrapper.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import re
6
  from pathlib import Path
7
  from uuid import uuid4
8
-
9
  import numpy as np
10
  import torch
11
  from diffusers import StableDiffusionInpaintPipeline
@@ -541,3 +541,84 @@ class ClimateGAN:
541
  im = Image.fromarray(uint8(im))
542
  imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
543
  im.save(im_path.parent / (imstem + im_path.suffix))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import re
6
  from pathlib import Path
7
  from uuid import uuid4
8
+ from minydra import resolved_args
9
  import numpy as np
10
  import torch
11
  from diffusers import StableDiffusionInpaintPipeline
 
541
  im = Image.fromarray(uint8(im))
542
  imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
543
  im.save(im_path.parent / (imstem + im_path.suffix))
544
+
545
+
546
+ if __name__ == "__main__":
547
+ print("Run `$ python climategan_wrapper.py help` for usage instructions\n")
548
+
549
+ # parse arguments
550
+ args = resolved_args(
551
+ defaults={
552
+ "input_folder": None,
553
+ "output_folder": None,
554
+ "painter": "both",
555
+ "help": False,
556
+ }
557
+ )
558
+
559
+ # print help
560
+ if args.help:
561
+ print(
562
+ "Usage: python inference.py input_folder=/path/to/folder\n"
563
+ + "By default inferences will be stored in the input folder.\n"
564
+ + "Add `output_folder=/path/to/folder` for a different output folder.\n"
565
+ + "By default, both ClimateGAN and Stable Diffusion will be used."
566
+ + "Change this by adding `painter=climategan` or"
567
+ + " `painter=stable_diffusion`.\n"
568
+ + "Make sure you have agreed to the terms of use for the models."
569
+ + "In particular, visit SD's model card to agree to the terms of use:"
570
+ + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
571
+ )
572
+ # print args
573
+ args.pretty_print()
574
+
575
+ # load models
576
+ cg = ClimateGAN("models/climategan")
577
+
578
+ # check painter type
579
+ assert args.painter in {"climategan", "stable_diffusion", "both",}, (
580
+ f"Unknown painter {args.painter}. "
581
+ + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
582
+ )
583
+
584
+ # load SD pipeline if need be
585
+ if args.painter != "climate_gan":
586
+ cg._setup_stable_diffusion()
587
+
588
+ # resolve input folder path
589
+ in_path = Path(args.input_folder).expanduser().resolve()
590
+ assert in_path.exists(), f"Folder {str(in_path)} does not exist"
591
+
592
+ # output is input if not specified
593
+ if args.output_folder is None:
594
+ out_path = in_path
595
+
596
+ # find images in input folder
597
+ im_paths = [
598
+ p
599
+ for p in in_path.iterdir()
600
+ if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
601
+ ]
602
+ assert im_paths, f"No images found in {str(im_paths)}"
603
+
604
+ print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")
605
+
606
+ # infer and write
607
+ for i, im_path in enumerate(im_paths):
608
+ print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
609
+ outs = cg.infer_single(
610
+ np.array(Image.open(im_path)),
611
+ args.painter,
612
+ as_pil_image=True,
613
+ concats=[
614
+ "input",
615
+ "masked_input",
616
+ "climategan_flood",
617
+ "stable_copy_flood",
618
+ ],
619
+ )
620
+ for k, v in outs.items():
621
+ name = f"{im_path.stem}---{k}{im_path.suffix}"
622
+ im = Image.fromarray(uint8(v))
623
+ im.save(out_path / name)
624
+ print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")