climategan / inferences.py
NimaBoscarino's picture
Fix image output dims
c557eb7
raw history blame
No virus
2.92 kB
import torch
from skimage.color import rgba2rgb
from skimage.transform import resize
import numpy as np
from climategan.trainer import Trainer
def uint8(array):
    """
    Cast an array to np.uint8.

    Only the dtype is changed: values are neither rescaled nor clipped
    beforehand, so the caller is responsible for passing data that already
    fits the 0-255 range.

    Args:
        array (np.array): array to convert

    Returns:
        np.array(np.uint8): new array holding the converted values
    """
    target_dtype = np.uint8
    return array.astype(target_dtype)
def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest dimension
    is `to`, then crops this resized image in its center so that the output is
    `to x to` without aspect ratio distortion

    Args:
        img (np.array): image with values in the 0-255 range, either HxWxC
            or HxW (any trailing axes are carried through untouched)
        to (int, optional): side length of the square output. Defaults to 640.

    Returns:
        np.array: `to x to` float image in [0, 1]. The dtype is float64 (the
            result of dividing a uint8 array by 255.0); `to_m1_p1` performs
            the cast to float32 further down the pipeline.
    """
    # resize keeping aspect ratio: the smallest dim becomes `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    # preserve_range keeps the 0-255 scale so the uint8 cast below is meaningful
    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2
    # Only the two spatial axes are indexed: identical to the former
    # `[..., ..., :]` for HxWxC input, and it no longer fails on HxW input.
    rc_img = r_img[top : top + to, left : left + to]

    return rc_img / 255.0
def to_m1_p1(img):
    """
    Map an image from the [0, 1] range to the [-1, +1] range.

    Args:
        img (np.array): numpy array of an image with values in [0, 1]

    Raises:
        ValueError: If any value of the image lies outside [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    lo, hi = img.min(), img.max()

    # Written as a negated "in range" test on purpose: NaN bounds fail both
    # comparisons and must therefore end up in the error branch.
    if not (lo >= 0 and hi <= 1):
        raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")

    # Center on zero first, then stretch to the full [-1, +1] span
    centered = img.astype(np.float32) - 0.5
    return centered * 2
# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a `Trainer` restored from a saved model path.

    Attributes are set in `__init__`:
        target_size: side length (in pixels) images are resized/cropped to
        trainer: object returned by `Trainer.resume_from_path`
    """

    def __init__(self, model_path) -> None:
        """
        Restore the trainer from `model_path` in inference mode.

        Note that this disables gradient computation through
        `torch.set_grad_enabled(False)`, which is not scoped to this instance.

        Args:
            model_path: location handed to `Trainer.resume_from_path`
        """
        torch.set_grad_enabled(False)
        self.target_size = 640
        self.trainer = Trainer.resume_from_path(
            model_path, setup=True, inference=True, new_exp=None
        )

    # Does all three inferences at the moment.
    def inference(self, orig_image):
        """
        Run the flood, wildfire and smog inferences on a single image.

        Args:
            orig_image (np.array): input image, RGB or RGBA (see
                `_preprocess_image`)

        Returns:
            tuple: (flood, wildfire, smog) outputs, each one squeezed to
                drop its singleton dimensions
        """
        model_input = self._preprocess_image(orig_image)

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        events = self.trainer.infer_all(model_input, numpy=True, bin_value=0.5)

        return tuple(events[name].squeeze() for name in ("flood", "wildfire", "smog"))

    def _preprocess_image(self, img):
        """
        Turn a raw image into the array passed to the trainer.

        Args:
            img (np.array): image whose last axis holds the channels

        Returns:
            np.array(np.float32): `target_size x target_size` image in [-1, +1]
        """
        # rgba to rgb: anything whose last axis is not 3 goes through rgba2rgb
        if img.shape[-1] == 3:
            rgb = img
        else:
            rgb = uint8(rgba2rgb(img) * 255)

        # to self.target_size
        square = resize_and_crop(rgb, self.target_size)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        return to_m1_p1(square)