File size: 4,503 Bytes
6e601ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import torch.nn.functional as F
import random
import kornia
from torchvision.transforms.functional import adjust_brightness, adjust_contrast

from climategan.tutils import normalize, retrieve_sky_mask

try:
    from kornia.filters import filter2d
except ImportError:
    from kornia.filters import filter2D as filter2d


def increase_sky_mask(mask, p_w=0, p_h=0):
    """
    Increases sky mask in width and height by a given percentage
    (Purpose: when applying Gaussian blur, there are no artifacts of blue sky behind)

    The mask is grown by accumulating copies of itself shifted by
    1 .. n - 1 pixels in both directions along an axis, where
    n = int(p * axis_size), first along the width and then along the
    height. Every value >= 1 in the result is then set to 1.

    Args:
        mask (torch.Tensor): Sky mask whose last two dimensions are (H, W).
            Any number of leading dimensions is accepted, e.g. (H, W) or
            (B, 1, H, W). The input tensor is not modified.
        p_w (float): Percentage of mask width by which to increase
            the width of the sky region
        p_h (float): Percentage of mask height by which to increase
            the height of the sky region
    Returns:
        torch.Tensor: Sky mask increased given p_w and p_h. When both
            p_w and p_h are <= 0 the input tensor itself is returned
            (no copy is made).
    """

    if p_h <= 0 and p_w <= 0:
        return mask

    n_lines = int(p_h * mask.shape[-2])
    n_cols = int(p_w * mask.shape[-1])

    # Ellipsis indexing addresses the trailing (H, W) axes only, so the
    # function works for any tensor rank >= 2 instead of rank 4 only.
    temp_mask = mask.clone().detach()
    for i in range(1, n_cols):
        # Shift right, then left, by i columns and accumulate.
        temp_mask[..., i:] += mask[..., :-i]
        temp_mask[..., :-i] += mask[..., i:]

    new_mask = temp_mask.clone().detach()
    for i in range(1, n_lines):
        # Shift down, then up, by i rows and accumulate.
        new_mask[..., i:, :] += temp_mask[..., :-i, :]
        new_mask[..., :-i, :] += temp_mask[..., i:, :]

    # Overlapping shifted copies sum above 1; cap them back to 1.
    new_mask[new_mask >= 1] = 1

    return new_mask


def paste_filter(x, filter_, mask):
    """
    Blends a filter onto an image according to a mask

    The output is ``filter_ * mask + x * (1 - mask)``: a mask value of 1
    takes the filter, a mask value of 0 keeps the input, and values in
    between mix the two proportionally.
    Args:
        x (torch.Tensor): Input tensor, range must be [0, 255]
        filter_ (torch.Tensor): Filter, range must be [0, 255]
        mask (torch.Tensor): Mask, range must be [0, 1]
    Returns:
        torch.Tensor: New tensor with filter pasted on it
    Raises:
        AssertionError: if the three tensors do not have the same
            number of dimensions
    """
    ranks = {len(tensor.shape) for tensor in (x, filter_, mask)}
    assert len(ranks) == 1

    keep_weight = 1 - mask
    return filter_ * mask + x * keep_weight


def add_fire(x, seg_preds, fire_opts):
    """
    Transforms input tensor given wildfires event

    Steps: rescale the input with ``normalize(x, 0, 255)``, shift the
    colour channels, raise contrast and lower brightness, build a blurred
    and enlarged sky mask from the segmentation predictions, blend a
    flat colour filter over the masked area, then darken once more.

    Args:
        x (torch.Tensor): Input tensor. It is indexed as
            (batch, channel, H, W) with at least 3 channels
            (presumably RGB order -- confirm against callers).
        seg_preds (torch.Tensor): Semantic segmentation predictions for
            input tensor, passed to ``retrieve_sky_mask``
        fire_opts (dict-like): Options read with ``.get``:
            - "crop_bottom_sky_mask": if truthy, the bottom third of the
              sky mask is zeroed before resizing
            - "kernel_size" (default 301): side of the square Gaussian
              kernel used to blur the sky mask
            - "kernel_sigma" (default 150.5): sigma of that kernel,
              used for both axes
    Returns:
        torch.Tensor: Wildfire version of input tensor, as float values
            on the 0-255 scale. The first pixel ([0, 0]) is forced to 255
            and the last pixel ([-1, -1]) to 0 in every channel.
    """
    # NOTE(review): the in-place edits below act on normalize()'s output;
    # whether that aliases `x` depends on normalize -- confirm.
    wildfire_tens = normalize(x, 0, 255)

    # Warm the image
    wildfire_tens[:, 2, :, :] -= 20
    wildfire_tens[:, 1, :, :] -= 10
    wildfire_tens[:, 0, :, :] += 40
    wildfire_tens.clamp_(0, 255)
    wildfire_tens = wildfire_tens.to(torch.uint8)

    # Darken the picture and increase contrast
    wildfire_tens = adjust_contrast(wildfire_tens, contrast_factor=1.5)
    wildfire_tens = adjust_brightness(wildfire_tens, brightness_factor=0.73)

    # Add a channel axis so the mask can be resized with F.interpolate
    sky_mask = retrieve_sky_mask(seg_preds).unsqueeze(1)

    if fire_opts.get("crop_bottom_sky_mask"):
        # Discard any "sky" detected in the bottom third of the image
        i = 2 * sky_mask.shape[-2] // 3
        sky_mask[..., i:, :] = 0

    # Bring the mask to the image resolution
    sky_mask = F.interpolate(
        sky_mask.to(torch.float),
        (wildfire_tens.shape[-2], wildfire_tens.shape[-1]),
    )
    sky_mask = increase_sky_mask(sky_mask, 0.18, 0.18)

    # Same size and sigma on both axes: read each option once
    size = fire_opts.get("kernel_size", 301)
    std = fire_opts.get("kernel_sigma", 150.5)
    kernel_size = (size, size)
    sigma = (std, std)
    border_type = "reflect"
    kernel = kornia.filters.kernels.get_gaussian_kernel2d(kernel_size, sigma)
    # Depending on the kornia version the kernel comes back as (kH, kW) or
    # already batched as (B, kH, kW); only add the batch axis when missing.
    if kernel.dim() == 2:
        kernel = torch.unsqueeze(kernel, dim=0)
    kernel = kernel.to(x.device)
    # Blur the mask to smooth the transition between sky and foreground
    sky_mask = filter2d(sky_mask, kernel, border_type)

    # Flat colour filter: channel 0 = 255, channel 1 random in [100, 150],
    # channel 2 = 0
    filter_ = torch.ones(wildfire_tens.shape, device=x.device)
    filter_[:, 0, :, :] = 255
    filter_[:, 1, :, :] = random.randint(100, 150)
    filter_[:, 2, :, :] = 0

    # Blend the filter over the sky with opacity 200 / 255
    wildfire_tens = paste_tensor(wildfire_tens, filter_, sky_mask, 200)

    wildfire_tens = adjust_brightness(wildfire_tens.to(torch.uint8), 0.8)
    wildfire_tens = wildfire_tens.to(torch.float)

    # dummy pixels to fool scaling and preserve range
    wildfire_tens[:, :, 0, 0] = 255.0
    wildfire_tens[:, :, -1, -1] = 0.0

    return wildfire_tens


def paste_tensor(source, filter_, mask, transparency):
    """
    Blends a filter over a source given a mask and a global opacity

    The mask is first scaled by ``transparency / 255.0``; the result is
    used as the per-element weight of ``filter_``, and one minus that
    weight as the weight of ``source``.
    Args:
        source (torch.Tensor): Tensor to paste onto
        filter_ (torch.Tensor): Tensor to paste
        mask (torch.Tensor): Per-element blending weights
        transparency (float): Opacity of the filter on a 0-255 scale
            (255 applies the mask at full strength, 0 returns the
            source values)
    Returns:
        torch.Tensor: Blended tensor
    """
    opacity = transparency / 255.0
    alpha = opacity * mask
    return alpha * filter_ + (1.0 - alpha) * source