# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py  # noqa: E501
# thank you @NimaBoscarino

import os
import re
from pathlib import Path
from uuid import uuid4
from minydra import resolved_args
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize

from climategan.trainer import Trainer

CUDA = torch.cuda.is_available()


def concat_events(output_dict, events, i=None, axis=1):
    """
    Concatenates the data in `output_dict` stored under the keys listed in
    `events` along dimension `axis`. Keys of `events` which are missing from
    `output_dict` are silently skipped.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data :
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): If not None, only the `i`th item of each entry is
            concatenated. Defaults to None (whole entries are concatenated).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): The concatenated data, cast to uint8.
    """
    cs = [e for e in events if e in output_dict]
    if i is not None:
        return uint8(np.concatenate([output_dict[c][i] for c in cs], axis=axis))
    return uint8(np.concatenate([output_dict[c] for c in cs], axis=axis))


def clear(folder):
    """
    Deletes all the files *with* the inference separator "---" in their stem,
    i.e. files named like previously written inference outputs. Files without
    the separator are left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # materialize the listing before deleting entries from the directory
    for i in list(Path(folder).iterdir()):
        if i.is_file() and "---" in i.stem:
            i.unlink()


def uint8(array, rescale=False):
    """
    Converts an array to np.uint8.

    By default only the dtype is changed. If `rescale` is True, the values are
    first mapped to [0, 255]: data with negative values must be in [-1, 1] and
    is shifted to [0, 1]; data whose max is <= 1 is then multiplied by 255.

    Args:
        array (np.array): array to modify
        rescale (bool, optional): Whether to rescale the values to [0, 255]
            before casting. Defaults to False.

    Raises:
        ValueError: If `rescale` is True and the array has negative values
            but is not in [-1, 1].

    Returns:
        np.array(np.uint8): converted array
    """
    if rescale:
        if array.min() < 0:
            if array.min() >= -1 and array.max() <= 1:
                array = (array + 1) / 2
            else:
                raise ValueError(
                    f"Data range mismatch for image: ({array.min()}, {array.max()})"
                )
        if array.max() <= 1:
            array = array * 255
    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest dimensions
    is `to`, then crops this resized image in its center so that the output is
    `to x to` without aspect ratio distortion

    Args:
        img (np.array): np.uint8 255 image (HxWxC)
        to (int, optional): Side of the square output, in pixels.
            Defaults to 640.

    Returns:
        np.array: [0, 1] float image
    """
    # resize keeping aspect ratio: smallest dim is `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]

    top = (H - to) // 2
    left = (W - to) // 2

    rc_img = r_img[top : top + to, left : left + to, :]

    return rc_img / 255.0


def to_m1_p1(img):
    """
    rescales a [0, 1] image to [-1, +1]

    Args:
        img (np.array): float numpy array of an image in [0, 1]

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")


# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random arrays instead of real
                inferences. Defaults to False.
        """
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        if CUDA:
            self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting.
        Make sure you have accepted the license on the model's card
        https://huggingface.co/runwayml/stable-diffusion-inpainting

        Does nothing in dev mode. The Hugging Face token is read from the
        `HF_AUTH_TOKEN` environment variable.
        """
        if self.dev_mode:
            return

        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16" if CUDA else "main",
                torch_dtype=torch.float16 if CUDA else torch.float32,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
            self._stable_diffusion_is_setup = True
        except Exception:
            # the message points to the card of the model actually loaded above
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + " card https://huggingface.co/runwayml/stable-diffusion-inpainting\n"
            )
            # bare raise keeps the original traceback intact
            raise

    def _preprocess_image(self, img):
        """
        Turns a HxWxC uint8 numpy array into a
        `target_size` x `target_size` x 3 (640 by default) float32 numpy array
        in [-1, 1].

        Args:
            img (np.array): Image to resize crop and rescale. Anything whose
                last dimension is not 3 is handed to `rgba2rgb`.

        Returns:
            np.array: Resized, cropped and rescaled image
        """
        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # to self.target_size
        data = resize_and_crop(data, self.target_size)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data

    # Does all three inferences at the moment.
    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=[
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ],
        as_pil_image=False,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is
            "stable_diffusion" or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        In dev mode, no network is run: the returned dict holds `orig_image`
        under "input" and random integer arrays for every other key.

        Args:
            orig_image (Union[str, np.array]): image to infer on. Can be a path to
                an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new "concat" entry of the output dict. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            as_pil_image (bool, optional): If True, the image is also forwarded
                to `infer_preprocessed_batch` as a PIL image (`pil_image`
                argument). Defaults to False.

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            return {
                "input": orig_image,
                "mask": np.random.randint(0, 255, (640, 640)),
                "masked_input": np.random.randint(0, 255, (640, 640, 3)),
                "climategan_flood": np.random.randint(0, 255, (640, 640, 3)),
                "stable_flood": np.random.randint(0, 255, (640, 640, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (640, 640, 3)),
                "concat": np.random.randint(0, 255, (640, 640 * 5, 3)),
                "smog": np.random.randint(0, 255, (640, 640, 3)),
                "wildfire": np.random.randint(0, 255, (640, 640, 3)),
                "depth": np.random.randint(0, 255, (640, 640, 1)),
                "segmentation": np.random.randint(0, 255, (640, 640, 3)),
            }

        image_array = (
            np.array(Image.open(orig_image))
            if isinstance(orig_image, str)
            else orig_image
        )

        pil_image = None
        if as_pil_image:
            pil_image = Image.fromarray(image_array)

        print("Preprocessing image")
        image = self._preprocess_image(image_array)
        output_dict = self.infer_preprocessed_batch(
            images=image[None, ...],
            painter=painter,
            prompt=prompt,
            concats=concats,
            pil_image=pil_image,
        )
        print("Inference done")
        # batch of size 1: unwrap the batch dimension of every output
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=[
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ],
        pil_image=None,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is
            "stable_diffusion" or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The entries listed in `concats` concatenated along the
            width (only if `concats` is not empty).

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new "concat" entry of the output dict. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            pil_image (PIL.Image, optional): The original PIL image. If provided,
                will be used for a single inference (batch_size=1) and
                overrides `images`.

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        ignore_event = set()
        if painter == "stable_diffusion":
            # ClimateGAN's painted flood is not needed in this case
            ignore_event.add("flood")

        if pil_image is not None:
            print("Warning: `pil_image` has been provided, it will override `images`")
            images = self._preprocess_image(np.array(pil_image))[None, ...]
            # back from [-1, 1] floats to a uint8 PIL image of the processed size
            pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        print("Inferring ClimateGAN events")
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=CUDA,
            ignore_event=ignore_event,
            return_intermediates=True,
        )

        outputs["input"] = uint8(images, True)
        # from Bx1xHxW to BxHxWx1
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # "flood" was listed in `ignore_event`, so it may be absent from
            # `outputs`: drop it only if it is there instead of raising KeyError
            outputs.pop("flood", None)

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
                if pil_image is None
                else pil_image
            )
            input_mask = (
                torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
                if pil_image is None
                else Image.fromarray(mask[0])
            )
            print("Inferring stable diffusion in-painting for 50 steps")
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=self.target_size,
                width=self.target_size,
                num_inference_steps=50,
            )
            print("Stable diffusion in-painting done")

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # paste the original context back outside of the mask
            copy_flood = flood * bin_mask + uint8(images, True) * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            print("Concatenating flood images")
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # drop the singleton dimension of Bx1xHxW outputs
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        batch_size=4,
        concats=[
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ],
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on its data, potentially the prompt and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is
            "stable_diffusion" or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (list, optional): List of keys in `output` to concatenate together
                in a new `{original_stem}_concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Returns:
            dict: a dictionary containing the output images, as flat lists
                ordered like the images found in the folder
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        im_paths = [
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
        ]
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = [self._preprocess_image(np.array(Image.open(p))) for p in im_paths]
        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # flatten the per-batch outputs into one list per key
        outputs = {
            k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}{painter_prefix}_{event}{suffix}"
        where `painter_prefix` is f"_stable_{prompt}" for the "stable_flood"
        event (prompt lower-cased and stripped of non-alphanumeric characters)
        and `suffix` is the input image's extension.

        Args:
            outputs (dict): The inference procedure's output dict, mapping
                each event to a sequence of images indexed like `im_paths`.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the
                images or not. If a string is provided, it will be included
                in the name. If False, a random identifier will be added to
                the name. Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "".
        """
        prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()
        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        else:
            if isinstance(overwrite, str):
                overwrite_prefix = overwrite
                print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                painter_prefix = ""
                # NOTE: infer_preprocessed_batch stores ClimateGAN's flood under
                # "climategan_flood", so this first branch only applies to
                # `outputs` dicts which still hold a raw "flood" key.
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif (
                    painter in {"stable_diffusion", "both"} and event == "stable_flood"
                ):
                    painter_prefix = f"_stable_{prompt}"
                elif painter == "both" and event == "climategan_flood":
                    painter_prefix = ""

                im = ims[i]
                im = Image.fromarray(uint8(im))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))


if __name__ == "__main__":
    print("Run `$ python climategan_wrapper.py help` for usage instructions\n")

    # parse arguments
    args = resolved_args(
        defaults={
            "input_folder": None,
            "output_folder": None,
            "painter": "both",
            "help": False,
        }
    )

    # print help
    if args.help:
        print(
            "Usage: python inference.py input_folder=/path/to/folder\n"
            + "By default inferences will be stored in the input folder.\n"
            + "Add `output_folder=/path/to/folder` for a different output folder.\n"
            + "By default, both ClimateGAN and Stable Diffusion will be used."
            + "Change this by adding `painter=climategan` or"
            + " `painter=stable_diffusion`.\n"
            + "Make sure you have agreed to the terms of use for the models."
            + "In particular, visit SD's model card to agree to the terms of use:"
            + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
        )
        # asking for help should not load the models nor run inferences
        raise SystemExit(0)

    # print args
    args.pretty_print()

    # check painter type before paying for any model loading
    assert args.painter in {
        "climategan",
        "stable_diffusion",
        "both",
    }, (
        f"Unknown painter {args.painter}. "
        + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
    )

    # an input folder is mandatory: fail with a clear message
    if args.input_folder is None:
        raise SystemExit(
            "Missing argument: please provide `input_folder=/path/to/folder`"
        )

    # resolve input folder path
    in_path = Path(args.input_folder).expanduser().resolve()
    assert in_path.exists(), f"Folder {str(in_path)} does not exist"

    # output is input if not specified
    if args.output_folder is None:
        out_path = in_path
    else:
        out_path = Path(args.output_folder).expanduser().resolve()
        out_path.mkdir(parents=True, exist_ok=True)

    # find images in input folder
    im_paths = [
        p
        for p in in_path.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
    ]
    assert im_paths, f"No images found in {str(in_path)}"

    print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")

    # load models
    cg = ClimateGAN("models/climategan")

    # load SD pipeline if need be
    if args.painter != "climategan":
        cg._setup_stable_diffusion()

    # infer and write
    for i, im_path in enumerate(im_paths):
        print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
        outs = cg.infer_single(
            np.array(Image.open(im_path)),
            args.painter,
            as_pil_image=True,
            concats=[
                "input",
                "masked_input",
                "climategan_flood",
                "stable_copy_flood",
            ],
        )
        for k, v in outs.items():
            name = f"{im_path.stem}---{k}{im_path.suffix}"
            im = Image.fromarray(uint8(v))
            im.save(out_path / name)
        print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")