#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-07-12 20:35:28

import math
from torch import nn
import torch.nn.functional as F


def _log2_exact(value, name):
    """Return ``k`` such that ``value == 2 ** k`` with ``k >= 0``.

    Replaces ``int(math.log(value, 2))``, which silently truncates factors that
    are not powers of two (e.g. 3 -> 1) and depends on floating-point rounding.

    Raises:
        ValueError: if ``value`` is not a power of 2 that is >= 1.
    """
    if value < 1:
        raise ValueError(f"{name} must be a power of 2 (>= 1), got {value!r}")
    exponent = round(math.log2(value))
    if 2 ** exponent != value:
        raise ValueError(f"{name} must be a power of 2 (>= 1), got {value!r}")
    return exponent


def _make_body(num_chns, depth):
    """Return ``depth - 1`` pairs of (5x5 same-padding Conv2d, LeakyReLU(0.2)) as a Sequential."""
    layers = []
    for _ in range(depth - 1):
        layers.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)


class SRCNN(nn.Module):
    """Plain convolutional super-resolution network that upsamples by ``sf``.

    Layout: a 5x5 head conv; ``depth - 1`` (5x5 conv + LeakyReLU) body layers;
    ``log2(sf)`` upsampling stages, each a 3x3 conv to ``4 * num_chns`` channels
    + LeakyReLU + PixelShuffle(2); and a final 5x5 conv to ``out_chns``.

    Args:
        in_chns: number of input channels.
        out_chns: number of output channels; defaults to ``in_chns``.
        num_chns: width of the hidden feature maps.
        depth: the body contains ``depth - 1`` conv layers.
        sf: upscaling factor; must be a power of 2 (1, 2, 4, 8, ...).

    Raises:
        ValueError: if ``sf`` is not a power of 2 >= 1.
    """

    def __init__(self, in_chns, out_chns=None, num_chns=64, depth=8, sf=4):
        super().__init__()
        self.sf = sf
        num_upsample = _log2_exact(sf, "sf")
        out_chns = in_chns if out_chns is None else out_chns

        self.head = nn.Conv2d(in_chns, num_chns, kernel_size=5, padding=2)
        self.body = _make_body(num_chns, depth)

        tail = []
        for _ in range(num_upsample):
            # Expand to 4x channels so PixelShuffle(2) returns to num_chns at 2x resolution.
            tail.append(nn.Conv2d(num_chns, num_chns * 4, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
        tail.append(nn.Conv2d(num_chns, out_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Map ``x`` of shape (N, in_chns, H, W) to (N, out_chns, H * sf, W * sf)."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y


class SRCNNFSR(nn.Module):
    """SRCNN variant whose body runs at a reduced spatial resolution.

    The head applies ``log2(down_scale_factor)`` stages of PixelUnshuffle(2)
    + conv to shrink H and W. The body runs ``depth - 1`` (5x5 conv +
    LeakyReLU) layers. The tail applies the same number of (3x3 conv +
    LeakyReLU + PixelShuffle(2)) stages, then a final 5x5 conv back to
    ``in_chns``. The output has the same spatial size as the input.

    Each tail stage keeps the channel count through its conv and then
    PixelShuffle(2) divides it by 4. ``num_chns`` must therefore be divisible
    by ``4 ** log2(down_scale_factor)``; the historical minimum of 4 is kept
    when ``down_scale_factor == 1``.

    Args:
        in_chns: number of input (and output) channels.
        down_scale_factor: power of 2 >= 1. Input H and W must be divisible
            by it, as required by PixelUnshuffle.
        num_chns: hidden width; see the divisibility requirement above.
        depth: the body contains ``depth - 1`` conv layers.
        sf: stored as ``self.sf``; not used to build any layer in this class.

    Raises:
        ValueError: if ``down_scale_factor`` is not a power of 2 >= 1, if
            ``num_chns`` violates the divisibility requirement, or if
            ``down_scale_factor == 1`` while ``in_chns != num_chns``. In that
            last case the head is empty, so the body would see the wrong
            channel count.
    """

    def __init__(self, in_chns, down_scale_factor=2, num_chns=64, depth=8, sf=4):
        super().__init__()
        self.sf = sf
        num_levels = _log2_exact(down_scale_factor, "down_scale_factor")

        # Every tail stage divides the channel count by 4, so num_chns must
        # survive num_levels such divisions (at least one, as before).
        divisor = 4 ** max(num_levels, 1)
        if num_chns % divisor != 0:
            raise ValueError(
                f"num_chns must be divisible by {divisor} for "
                f"down_scale_factor={down_scale_factor}, got {num_chns}"
            )
        if num_levels == 0 and in_chns != num_chns:
            raise ValueError(
                "down_scale_factor=1 builds an empty head, so in_chns must equal "
                f"num_chns; got in_chns={in_chns}, num_chns={num_chns}"
            )

        head = []
        in_chns_shuffle = in_chns * 4  # channel count after the first PixelUnshuffle(2)
        for level in range(num_levels):
            head.append(nn.PixelUnshuffle(2))
            head.append(nn.Conv2d(in_chns_shuffle, num_chns, kernel_size=3, padding=1))
            if level + 1 < num_levels:
                # Shrink to num_chns // 4 so the next PixelUnshuffle(2) yields num_chns again.
                head.append(nn.Conv2d(num_chns, num_chns // 4, kernel_size=5, padding=2))
                head.append(nn.LeakyReLU(0.2, inplace=True))
            in_chns_shuffle = num_chns
        self.head = nn.Sequential(*head)

        self.body = _make_body(num_chns, depth)

        tail = []
        for _ in range(num_levels):
            tail.append(nn.Conv2d(num_chns, num_chns, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))  # 2x resolution, 4x fewer channels
            num_chns //= 4
        tail.append(nn.Conv2d(num_chns, in_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Map ``x`` of shape (N, in_chns, H, W) to a tensor of the same shape."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y