File size: 2,920 Bytes
4756ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c557eb7
4756ce1
 
 
 
 
 
 
 
c557eb7
 
 
 
 
4756ce1
 
 
 
 
 
 
 
 
 
c557eb7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from skimage.color import rgba2rgb
from skimage.transform import resize
import numpy as np

from climategan.trainer import Trainer


def uint8(array):
    """
    Cast `array` to np.uint8.

    Only the dtype is changed: values are neither rescaled nor clipped, and
    any fractional part is dropped by numpy's float-to-integer cast.

    Args:
        array (np.array): array to cast

    Returns:
        np.array(np.uint8): cast copy of `array`
    """
    target_dtype = np.uint8
    converted = array.astype(target_dtype)
    return converted

def resize_and_crop(img, to=640):
    """
    Resize an image so that it keeps its aspect ratio and its smallest
    dimension becomes `to`, then crop the center of the resized image so that
    the output is `to x to` without aspect-ratio distortion.

    Args:
        img (np.array): np.uint8 image with values in [0, 255], shaped (H, W)
            or (H, W, C)
        to (int, optional): side length of the square output. Defaults to 640.

    Returns:
        np.array: np.float64 image with values in [0, 1], shaped (to, to) or
            (to, to, C)
    """
    # resize keeping aspect ratio: the smallest dim becomes `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    # preserve_range keeps values in the input's [0, 255] range so that the
    # uint8 cast below does not need any rescaling
    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2

    # index only the two spatial axes so both (H, W) and (H, W, C) inputs work
    rc_img = r_img[top : top + to, left : left + to]

    # uint8 / float -> float64 in [0, 1]
    return rc_img / 255.0

def to_m1_p1(img):
    """
    Rescale a [0, 1] image to [-1, +1].

    Args:
        img (np.array): numpy array of an image with values in [0, 1]

    Raises:
        ValueError: If any value of the image lies outside [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    # compute the range once; it is used both for the check and the error
    lo, hi = img.min(), img.max()
    if lo >= 0 and hi <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({lo}, {hi})")

# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN `Trainer` restored from disk.

    Produces the flood, wildfire and smog outputs for a single input image.
    """

    def __init__(self, model_path) -> None:
        """
        Restore the trainer from `model_path` in inference mode.

        Args:
            model_path: path forwarded to `Trainer.resume_from_path`
        """
        # Gradients are never needed here. PyTorch grad mode is thread-local,
        # so this only affects the thread that builds the model; `inference`
        # therefore disables gradients again for its own call.
        torch.set_grad_enabled(False)
        self.target_size = 640
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )

    # Does all three inferences at the moment.
    def inference(self, orig_image):
        """
        Run the flood, wildfire and smog inferences on one image.

        Args:
            orig_image (np.array): RGB (..., 3) or RGBA image; presumably
                uint8 in [0, 255] — TODO confirm against callers

        Returns:
            tuple: (flood, wildfire, smog) arrays returned by
            `trainer.infer_all`, each squeezed to drop singleton dimensions
        """
        image = self._preprocess_image(orig_image)

        # The grad-mode switch in __init__ does not carry over to other
        # threads (e.g. a web server's worker pool), so disable autograd
        # explicitly for this call.
        with torch.no_grad():
            # Retrieve numpy events as a dict {event: array[BxHxWxC]}
            outputs = self.trainer.infer_all(
                image,
                numpy=True,
                bin_value=0.5,
            )

        return (
            outputs['flood'].squeeze(),
            outputs['wildfire'].squeeze(),
            outputs['smog'].squeeze(),
        )

    def _preprocess_image(self, img):
        """
        Turn an input image into the array fed to `trainer.infer_all`.

        Args:
            img (np.array): (H, W, 3) RGB or (H, W, 4) RGBA image

        Returns:
            np.array(np.float32): (target_size, target_size, 3) image in [-1, 1]
        """
        # rgba to rgb: rgba2rgb returns floats in [0, 1], so scale back to
        # the uint8 [0, 255] range that resize_and_crop expects
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # resize and center-crop to a self.target_size square
        data = resize_and_crop(data, self.target_size)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data