File size: 2,523 Bytes
a522864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor
from PIL import Image
import os
from math import sqrt
import torch.nn.functional as F

# Block: one 3x3 convolution (64 -> 64 channels, no bias) followed by an in-place ReLU.
class Block(nn.Module):
    """One VDSR body layer: a 3x3 convolution followed by an in-place ReLU.

    The convolution maps 64 -> 64 channels with stride 1, padding 1 and no
    bias, so the spatial size of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys ("conv.weight"); keep them stable.
        self.conv = nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        return self.relu(features)

class VDSR(nn.Module):
    """VDSR super-resolution network with a global residual connection.

    ``forward`` computes ``x + output(residual_layer(relu(input(x))))``:
    a 3x3 conv lifts ``in_channels`` to 64 feature maps, ``num_blocks``
    ``Block`` layers refine them, and a 3x3 conv projects back to
    ``out_channels``. All convs use padding 1, so the spatial size is
    unchanged.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the output conv. The
            result is added to the input, so the two shapes must be
            broadcast-compatible (normally ``out_channels == in_channels``).
        num_blocks: number of intermediate ``Block`` layers.
    """

    def __init__(self, in_channels=3, out_channels=3, num_blocks=18):
        super(VDSR, self).__init__()
        # Attribute names (residual_layer / input / output) are the state_dict
        # key prefixes and are read by forward(); keep them stable.
        self.residual_layer = self.make_layer(Block, num_blocks)
        self.input = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # He-normal initialisation for every conv, including the ones inside the
        # Blocks (self.modules() recurses): std = sqrt(2 / fan_out), where
        # fan_out = k_h * k_w * out_channels.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))

    def make_layer(self, block, num_layers):
        """Return an ``nn.Sequential`` of ``num_layers`` fresh ``block()`` instances."""
        return nn.Sequential(*[block() for _ in range(num_layers)])

    def forward(self, x):
        """Predict a residual for ``x`` and add it back onto ``x``."""
        residual = x
        # The in-place ReLU acts on the conv output, never on x itself.
        out = self.relu(self.input(x))
        out = self.residual_layer(out)
        out = self.output(out)
        return torch.add(residual, out)

    def inference(self, x, scale_factor=4):
        """Super-resolve a PIL image.

        The image is upscaled by bicubic interpolation and then refined by the
        network. Side effect: the model is switched to eval mode and left there.

        Args:
            x: input ``PIL.Image``. When the input conv expects 1 or 3
                channels and ``x.mode`` is not ``'L'`` / ``'RGB'`` respectively
                (e.g. an RGBA or palette PNG), the image is converted first;
                otherwise ``ToTensor`` would produce the wrong channel count.
            scale_factor: bicubic upscaling factor applied before the network.
                Defaults to 4, the previously hard-coded value.

        Returns:
            A new 8-bit ``PIL.Image`` (mode ``'L'`` for single-channel output).
        """
        expected_mode = {1: 'L', 3: 'RGB'}.get(self.input.in_channels)
        if expected_mode is not None and x.mode != expected_mode:
            x = x.convert(expected_mode)

        # Run on the device that holds the weights instead of assuming CPU.
        device = next(self.parameters()).device
        self.eval()
        with torch.no_grad():
            t = ToTensor()(x).unsqueeze(0).to(device)
            t = F.interpolate(t, scale_factor=scale_factor, mode='bicubic', align_corners=False).clamp(0, 1)
            t = self(t).clamp(0, 1)
            # Round instead of truncating: v/255*255 in float32 can land just
            # below an integer (e.g. 127.99999), which a plain cast floors.
            pixels = t.squeeze(0).permute(1, 2, 0).mul(255).round().to(torch.uint8).cpu().numpy()
        if pixels.shape[2] == 1:
            # Image.fromarray expects an (H, W) array for single-channel data.
            pixels = pixels[:, :, 0]
        return Image.fromarray(pixels)

if __name__ == '__main__':
    import inspect

    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoint_path = os.path.join(current_dir, 'vdsr_checkpoint.pth')

    # The checkpoint holds a whole pickled nn.Module (it is used directly as a
    # model below), not a state_dict. torch >= 2.6 defaults to
    # torch.load(weights_only=True), which refuses to unpickle it, so opt out
    # explicitly where the parameter exists (torch >= 1.13).
    # SECURITY: weights_only=False executes arbitrary pickle code; only load
    # checkpoints from a trusted source.
    load_kwargs = {'map_location': torch.device('cpu')}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(checkpoint_path, **load_kwargs)
    model.eval()

    # NOTE(review): unlike the checkpoint, the demo image path is resolved
    # against the current working directory — confirm this is intended.
    # The context manager closes the underlying file once inference is done.
    with torch.no_grad(), Image.open('images/demo.png') as input_image:
        output_image = model.inference(input_image)
    print(input_image.size, output_image.size)