import numpy as np
import torch
from skimage.color import rgba2rgb
from skimage.transform import resize

from climategan.trainer import Trainer


def uint8(array):
    """
    convert an array to np.uint8 (does not rescale or anything else than changing
    dtype)

    Args:
        array (np.array): array to modify

    Returns:
        np.array(np.uint8): converted array
    """
    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and its smallest dimension
    is `to`, then crops this resized image in its center so that the output is
    `to x to` without aspect ratio distortion.

    Args:
        img (np.array): np.uint8 [0, 255] image, either HxW or HxWxC
        to (int, optional): side length of the square output. Defaults to 640.

    Returns:
        np.array: `to x to` (x C) image in [0, 1]; float64, since it is the
            result of dividing a uint8 array by 255.0
    """
    # resize keeping aspect ratio: smallest dim becomes `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center; `...` keeps any trailing channel axis so that both
    # HxW and HxWxC inputs are supported
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2
    rc_img = r_img[top : top + to, left : left + to, ...]

    return rc_img / 255.0


def to_m1_p1(img):
    """
    rescales a [0, 1] image to [-1, +1]

    Args:
        img (np.array): float numpy array of an image in [0, 1]

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")


# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN `Trainer` restored from disk, producing
    the flood, wildfire and smog outputs for a single image.
    """

    def __init__(self, model_path) -> None:
        """
        Args:
            model_path: path handed to `Trainer.resume_from_path` to restore the
                model (presumably a run / checkpoint directory — confirm).
        """
        # NOTE: PyTorch grad mode is thread-local, so this only affects the
        # thread constructing the model; `inference` also disables it locally.
        torch.set_grad_enabled(False)
        self.target_size = 640
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )

    # Does all three inferences at the moment.
    def inference(self, orig_image):
        """
        Run flood, wildfire and smog inference on one image.

        Args:
            orig_image (np.array): np.uint8 image, RGB (HxWx3), RGBA (HxWx4) or
                grayscale (HxW / HxWx1)

        Returns:
            tuple(np.array, np.array, np.array): the squeezed `flood`,
                `wildfire` and `smog` outputs of `Trainer.infer_all`
        """
        image = self._preprocess_image(orig_image)

        # Grad mode is thread-local: a server may call this from a worker thread
        # where the `set_grad_enabled(False)` done in __init__ does not apply,
        # so disable autograd explicitly around the forward passes.
        with torch.no_grad():
            # Retrieve numpy events as a dict {event: array[BxHxWxC]}
            outputs = self.trainer.infer_all(
                image,
                numpy=True,
                bin_value=0.5,
            )

        return (
            outputs["flood"].squeeze(),
            outputs["wildfire"].squeeze(),
            outputs["smog"].squeeze(),
        )

    def _preprocess_image(self, img):
        """
        Convert an image to the model input: RGB, square center crop of side
        `self.target_size`, float32 values in [-1, 1].

        Args:
            img (np.array): np.uint8 image, HxW, HxWx1, HxWx3 or HxWx4

        Returns:
            np.array(np.float32): target_size x target_size x 3 array in [-1, 1]
        """
        # grayscale to rgb: replicate the single channel (previously these
        # inputs fell through to rgba2rgb and raised)
        if img.ndim == 2:
            img = np.stack((img,) * 3, axis=-1)
        elif img.shape[-1] == 1:
            img = np.repeat(img, 3, axis=-1)

        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
        # to self.target_size
        data = resize_and_crop(data, self.target_size)
        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data