namnguyen2103's picture
VDSR added
a522864 verified
raw
history blame
2.52 kB
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor
from PIL import Image
import os
from math import sqrt
import torch.nn.functional as F
# Block: a single VDSR layer consisting of a 3x3 convolution followed by a ReLU.
class Block(nn.Module):
    """One body layer of VDSR: a 64->64 3x3 convolution followed by ReLU.

    The convolution uses stride 1 and padding 1 with no bias, so the spatial
    size of the feature map is preserved.
    """

    def __init__(self):
        super().__init__()
        # Attribute names `conv` / `relu` are part of the saved state; keep them.
        self.conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        return self.relu(features)
class VDSR(nn.Module):
    """VDSR super-resolution network with a global residual connection.

    An input 3x3 conv maps ``in_channels`` to 64 feature maps, ``num_blocks``
    stacked ``Block`` layers (3x3 conv + ReLU) process them, and an output
    3x3 conv maps back to ``out_channels``. The result is added to the
    network input, so input and output must be broadcast-compatible (in
    practice ``in_channels == out_channels``).
    """

    def __init__(self, in_channels=3, out_channels=3, num_blocks=18):
        super(VDSR, self).__init__()
        self.residual_layer = self.make_layer(Block, num_blocks)
        self.input = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        # He-normal init with fan computed as k_h * k_w * out_channels (fan-out),
        # applied to every conv including the ones inside the Blocks.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))

    def make_layer(self, block, num_layers):
        """Return an ``nn.Sequential`` of ``num_layers`` fresh ``block()`` instances."""
        layers = []
        for _ in range(num_layers):
            layers.append(block())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv stack and add its output back onto the input ``x``."""
        residual = x
        out = self.relu(self.input(x))
        out = self.residual_layer(out)
        out = self.output(out)
        out = torch.add(residual, out)
        return out

    @staticmethod
    def _to_uint8_array(t):
        """Convert a (1, C, H, W) float tensor in [0, 1] to a uint8 numpy image.

        Values are clamped to [0, 1], scaled by 255 and *rounded* (truncating
        would bias every pixel downward, e.g. 0.999 -> 254 instead of 255).
        Returns an (H, W, C) array, or (H, W) when C == 1 because
        ``Image.fromarray`` does not accept a trailing channel axis of size 1.
        """
        img = t.detach().squeeze(0).clamp(0, 1).mul(255).round().to(torch.uint8)
        img = img.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        if img.shape[-1] == 1:
            img = img.squeeze(-1)
        return img.contiguous().cpu().numpy()

    def inference(self, x, scale_factor=4):
        """Super-resolve a PIL image.

        The image is bicubically upscaled by ``scale_factor`` and then refined
        by the network. The module's previous train/eval mode is restored
        afterwards.

        Args:
            x: PIL image. If the model expects 3 channels it is converted to
                "RGB", or to "L" for a 1-channel model, when its mode differs
                (e.g. an RGBA or palette PNG).
            scale_factor: factor for the bicubic pre-upscaling (default 4,
                the value that used to be hard-coded).

        Returns:
            PIL image about ``scale_factor`` times larger in each dimension
            (``F.interpolate`` floors the output size).
        """
        expected_mode = {1: 'L', 3: 'RGB'}.get(self.input.in_channels)
        if expected_mode is not None and x.mode != expected_mode:
            x = x.convert(expected_mode)
        # Run on whatever device the weights live on (CPU or GPU).
        device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                t = ToTensor()(x).unsqueeze(0).to(device)
                # Bicubic interpolation can overshoot the input range; clamp back to [0, 1].
                t = F.interpolate(t, scale_factor=scale_factor, mode='bicubic', align_corners=False).clamp(0, 1)
                t = self.forward(t)
            return Image.fromarray(self._to_uint8_array(t))
        finally:
            self.train(was_training)
if __name__ == '__main__':
    current_dir = os.path.dirname(os.path.realpath(__file__))
    # NOTE(review): this unpickles a whole model object, which can execute
    # arbitrary code, so only load trusted checkpoints. On torch >= 2.6 the default
    # weights_only=True rejects full-model pickles; consider shipping a
    # state_dict and loading it into VDSR() instead.
    model = torch.load(os.path.join(current_dir, 'vdsr_checkpoint.pth'), map_location=torch.device('cpu'))
    model.eval()
    # The demo image path is relative to the current working directory,
    # while the checkpoint path above is relative to this file.
    # inference() already runs under torch.no_grad().
    with Image.open('images/demo.png') as input_image:
        output_image = model.inference(input_image)
        print(input_image.size, output_image.size)