File size: 3,545 Bytes
0fe2a53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from typing import Tuple, Union

import cv2
import numpy as np
import torch
from torch import nn
from torchvision import transforms as T


class SRCNN(nn.Module):
    """SRCNN super-resolution network plus a tiled inference helper.

    The network is three unpadded convolutions (9x9 -> 1x1 -> 5x5), so an
    ``input_size`` x ``input_size`` patch is mapped to a smaller
    ``label_size`` x ``label_size`` patch; ``pad`` is the margin lost on each
    side. ``enhance`` upsamples an image bicubically and then refines it with
    the network one tile at a time.

    Args:
        input_channels: Channels of the network input.
        output_channels: Channels of the network output.
        input_size: Side length of the patches fed to the network.
        label_size: Side length of the patches the network produces. With
            the kernel sizes used here (9, 1, 5, no padding) the network
            shrinks each side by 12, so this must equal ``input_size - 12``.
        scale: Upscaling factor applied by ``enhance``.
        device: Stored as ``self.device``; defaults to CUDA when available,
            otherwise CPU. The constructor does not move the module there.
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=3,
        input_size=33,
        label_size=21,
        scale=2,
        device=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.label_size = label_size
        # Margin (per side) that the unpadded convolutions trim from a patch.
        self.pad = (self.input_size - self.label_size) // 2
        self.scale = scale
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, 9),
            nn.ReLU(),
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, output_channels, 5),
            nn.ReLU(),
        )
        self.transform = T.Compose(
            [T.ToTensor()]  # Scale between [0, 1]
        )

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolutional stack on ``x``.

        ``enhance`` calls this with a single-image batch of shape
        (1, C, input_size, input_size).
        """
        return self.model(x)

    @torch.no_grad()
    def pre_process(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Convert an image to a float tensor scaled to [0, 1].

        Tensors are only divided by 255 (no layout change), so they are
        assumed to be (C, H, W) with values in [0, 255] -- confirm with
        callers. Arrays go through ``torchvision``'s ``ToTensor``.
        """
        if torch.is_tensor(x):
            return x / 255.0
        return self.transform(x)

    @torch.no_grad()
    def post_process(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp ``x`` to [0, 1] and rescale it to [0, 255] (still float)."""
        return x.clip(0, 1) * 255.0

    @torch.no_grad()
    def _predict_tiles(self, in_tensor: torch.Tensor) -> torch.Tensor:
        """Run the network over a (C, H, W) tensor window by window.

        Windows of ``input_size`` are taken every ``label_size`` pixels, so
        the ``label_size`` predictions exactly partition the result. Element
        ``[:, 0, 0]`` of the result corresponds to pixel ``(pad, pad)`` of
        ``in_tensor``. Only pixels some window actually predicted are
        returned, so there is no zero-filled border.

        Raises:
            ValueError: If ``in_tensor`` is smaller than one input window.
            RuntimeError: If a prediction is not ``label_size`` square.
        """
        _, height, width = in_tensor.shape
        step, window = self.label_size, self.input_size
        n_rows = (height - window) // step + 1
        n_cols = (width - window) // step + 1
        if n_rows <= 0 or n_cols <= 0:
            raise ValueError(
                f"input of size {height}x{width} is smaller than one "
                f"{window}x{window} window"
            )

        # Feed crops to whichever device holds the weights, so inputs and
        # parameters always agree even if the module was never moved.
        device = next(self.model.parameters()).device

        rows = []
        for top in range(0, n_rows * step, step):
            tiles = []
            for left in range(0, n_cols * step, step):
                crop = in_tensor[:, top : top + window, left : left + window]
                pred = self.forward(crop.unsqueeze(0).to(device)).cpu().squeeze(0)
                if tuple(pred.shape[-2:]) != (step, step):
                    raise RuntimeError(
                        f"network produced {tuple(pred.shape[-2:])} tiles, "
                        f"expected ({step}, {step})"
                    )
                tiles.append(pred)
            rows.append(torch.cat(tiles, dim=2))  # join along width
        return torch.cat(rows, dim=1)  # stack rows along height

    @torch.no_grad()
    def enhance(
        self, image: np.ndarray, outscale: float = 2
    ) -> Tuple[np.ndarray, None]:
        """Super-resolve ``image`` and optionally rescale the result.

        Each side is rounded down to a multiple of ``label_size``, extended
        by ``input_size`` and multiplied by ``scale``. The image is resized
        bicubically to that size and then refined tile by tile with the
        network. The tile-covered region is returned, which is roughly (not
        exactly) ``scale`` times the input size. When ``outscale`` differs
        from ``scale``, the result is additionally resized by
        ``outscale / scale`` with OpenCV.

        Args:
            image: Image array accepted by ``cv2.resize`` and ``ToTensor``
                (presumably HWC uint8 -- confirm against callers).
            outscale: Desired overall scale factor relative to ``image``.

        Returns:
            A pair whose first element is the (H, W, C) uint8 result; the
            second element is always ``None``.
        """
        height, width = image.shape[:2]
        scaled_w = int(
            (width - width % self.label_size + self.input_size) * self.scale
        )
        scaled_h = int(
            (height - height % self.label_size + self.input_size) * self.scale
        )
        # resize the input image using bicubic interpolation
        scaled = cv2.resize(
            image, (scaled_w, scaled_h), interpolation=cv2.INTER_CUBIC
        )

        in_tensor = self.pre_process(scaled)  # (C, H, W)
        # Keep only the region the tiles actually wrote, so no unwritten
        # zero pixels end up as a black strip in the output.
        out_tensor = self.post_process(self._predict_tiles(in_tensor))
        output = out_tensor.permute(1, 2, 0).numpy()  # (C, H, W) to (H, W, C)
        output = np.clip(output, 0, 255).astype("uint8")

        # The network already applied ``self.scale``; resize only for the rest.
        if outscale != self.scale:
            interpolation = (
                cv2.INTER_AREA if outscale < self.scale else cv2.INTER_LANCZOS4
            )
            out_h, out_w = output.shape[:2]
            output = cv2.resize(
                output,
                (
                    int(out_w * outscale / self.scale),
                    int(out_h * outscale / self.scale),
                ),
                interpolation=interpolation,
            )

        return output, None