climategan / inferences.py
NimaBoscarino's picture
Refactor, polish (WIP)
4756ce1
raw
history blame
2.95 kB
import torch
from skimage.color import rgba2rgb
from skimage.transform import resize
import numpy as np
from climategan.trainer import Trainer
def uint8(array):
    """
    Cast an array to np.uint8.

    Only the dtype changes: values are neither rescaled nor clipped.

    Args:
        array (np.array): array to convert

    Returns:
        np.array(np.uint8): the same data with dtype np.uint8
    """
    converted = array.astype(dtype=np.uint8)
    return converted
def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and its smallest dimension
    is `to`, then crops this resized image in its center so that the output is
    `to x to` without aspect ratio distortion

    Args:
        img (np.array): image in the [0, 255] range whose first two axes are
            height and width (H x W x C, or H x W for a channel-less image)
        to (int, optional): side length of the square output. Defaults to 640.

    Returns:
        np.array: `to x to` (x C) float image in [0, 1]. The dtype is the one
            produced by dividing a uint8 array by 255.0 (float64), not float32.
    """
    # resize keeping aspect ratio: the smallest dimension becomes `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    # preserve_range keeps the values in [0, 255] instead of rescaling to [0, 1]
    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2

    # Index only the two spatial axes: same result for H x W x C inputs, and
    # H x W inputs no longer fail on a third, nonexistent axis.
    rc_img = r_img[top : top + to, left : left + to]

    return rc_img / 255.0
def to_m1_p1(img):
    """
    Map an image from the [0, 1] range to the [-1, +1] range.

    Args:
        img (np.array): numpy array of an image with values in [0, 1]

    Raises:
        ValueError: If the image's values are not all within [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    in_unit_range = img.min() >= 0 and img.max() <= 1
    if not in_unit_range:
        raise ValueError(
            f"Data range mismatch for image: ({img.min()}, {img.max()})"
        )

    as_float = img.astype(np.float32)
    return (as_float - 0.5) * 2
# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Thin inference wrapper around a resumed `Trainer`.

    Loads the trainer from a path once, then preprocesses a single image and
    returns the flood, wildfire and smog outputs for it.
    """

    def __init__(self, model_path) -> None:
        """
        Resume a trainer in inference mode.

        Args:
            model_path: path handed to `Trainer.resume_from_path`
        """
        # Inference only: turn off gradient tracking (this is a global torch
        # setting, not limited to this instance).
        torch.set_grad_enabled(False)

        # Side length of the square image fed to the trainer
        self.target_size = 640

        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )

    # Does all three inferences at the moment.
    def inference(self, orig_image):
        """
        Run the trainer on one image and return its three event outputs.

        Args:
            orig_image (np.array): image whose last axis holds the channels

        Returns:
            tuple: the "flood", "wildfire" and "smog" entries of the
                trainer's output, in that order
        """
        # the size returned alongside the data is not needed here
        batch, _ = self._preprocess_image(orig_image)
        batch = np.stack(batch)

        # Retrieve numpy events as a dict keyed by event name
        # NOTE(review): presumably each value is B x H x W x C -- confirm
        # against Trainer.infer_all
        events = self.trainer.infer_all(
            batch,
            numpy=True,
            bin_value=0.5,
        )

        return events["flood"], events["wildfire"], events["smog"]

    def _preprocess_image(self, img):
        """
        Prepare an image for the trainer: drop alpha if present, resize and
        center-crop to the target size, and rescale to [-1, 1].

        Args:
            img (np.array): image whose last axis holds the channels

        Returns:
            tuple: (preprocessed image, (target_size, target_size))
        """
        # rgba to rgb: anything that does not have exactly 3 channels is run
        # through rgba2rgb and brought back to a uint8 image
        if img.shape[-1] == 3:
            rgb = img
        else:
            rgb = uint8(rgba2rgb(img) * 255)

        # to self.target_size
        side = self.target_size
        cropped = resize_and_crop(rgb, side)
        new_size = (side, side)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        scaled = to_m1_p1(cropped)

        return scaled, new_size