File size: 2,653 Bytes
06f26d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-07-12 20:35:28

import math
from torch import nn
import torch.nn.functional as F

class SRCNN(nn.Module):
    """Convolutional super-resolution network with sub-pixel (PixelShuffle) upsampling.

    Layout:
        * head: one 5x5 conv mapping ``in_chns`` -> ``num_chns`` channels;
        * body: ``depth - 1`` blocks of 5x5 conv + LeakyReLU(0.2) at ``num_chns``;
        * tail: ``log2(sf)`` stages of 3x3 conv (``num_chns`` -> ``4 * num_chns``),
          LeakyReLU(0.2) and ``PixelShuffle(2)``, then a 5x5 conv to ``out_chns``.

    Every convolution is padded to keep the spatial size, and each tail stage
    doubles height and width, so an input of shape ``(N, in_chns, H, W)``
    produces ``(N, out_chns, H * sf, W * sf)``.

    Args:
        in_chns: number of input channels.
        out_chns: number of output channels; ``None`` means ``in_chns``.
        num_chns: channel width of the hidden feature maps.
        depth: number of conv layers before the upsampling tail (head + body).
        sf: upscaling factor; must be a power of two >= 1 (1, 2, 4, 8, ...).

    Raises:
        ValueError: if ``sf`` is not a power of two >= 1.
    """

    def __init__(self, in_chns, out_chns=None, num_chns=64, depth=8, sf=4):
        super().__init__()
        self.sf = sf
        out_chns = in_chns if out_chns is None else out_chns

        # Each tail stage is an exact 2x upsampler, so only powers of two can be
        # realised. Reject anything else instead of silently building a network
        # whose real scale differs from ``sf``.
        num_up = round(math.log2(sf)) if sf >= 1 else -1
        if num_up < 0 or 2 ** num_up != sf:
            raise ValueError(f"sf must be a power of two >= 1, got {sf!r}")

        self.head = nn.Conv2d(in_chns, num_chns, kernel_size=5, padding=2)

        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        tail = []
        for _ in range(num_up):
            # Expand to 4x channels so PixelShuffle(2) returns to num_chns at 2x size.
            tail.append(nn.Conv2d(num_chns, num_chns * 4, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
        tail.append(nn.Conv2d(num_chns, out_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Upscale ``x`` from ``(N, in_chns, H, W)`` to ``(N, out_chns, H*sf, W*sf)``."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y

class SRCNNFSR(nn.Module):
    """Same-resolution CNN that processes features at a reduced spatial size.

    Layout (``L = log2(down_scale_factor)``):
        * head: ``L`` stages of ``PixelUnshuffle(2)`` + 3x3 conv to ``num_chns``.
          Every stage but the last is followed by a 5x5 conv squeezing to
          ``num_chns // 4`` channels and LeakyReLU(0.2), so the next unshuffle
          restores ``num_chns`` channels;
        * body: ``depth - 1`` blocks of 5x5 conv + LeakyReLU(0.2) at ``num_chns``;
        * tail: ``L`` stages of 3x3 conv (channel count unchanged),
          LeakyReLU(0.2) and ``PixelShuffle(2)``, each dividing the channel
          count by 4, then a 5x5 conv back to ``in_chns``.

    An input of shape ``(N, in_chns, H, W)`` maps to ``(N, in_chns, H, W)``.
    ``H`` and ``W`` must be divisible by ``down_scale_factor`` for the unshuffles.

    Args:
        in_chns: number of input (and output) channels.
        down_scale_factor: internal spatial reduction; a power of two >= 1.
        num_chns: channel width of the low-resolution feature maps. Must be
            divisible by ``down_scale_factor ** 2`` (i.e. ``4 ** L``) so every
            tail PixelShuffle receives a multiple of 4 channels.
        depth: number of conv layers in the body plus one.
        sf: stored as ``self.sf``; not used to build any layer in this class.

    Raises:
        ValueError: on an invalid ``down_scale_factor`` / ``num_chns`` combination.
    """

    def __init__(self, in_chns, down_scale_factor=2, num_chns=64, depth=8, sf=4):
        super().__init__()
        # NOTE(review): ``sf`` does not affect the architecture below; the spatial
        # round trip is governed entirely by ``down_scale_factor``.
        self.sf = sf

        # Each head/tail stage is an exact 2x (un)shuffle, so only powers of two work.
        num_levels = round(math.log2(down_scale_factor)) if down_scale_factor >= 1 else -1
        if num_levels < 0 or 2 ** num_levels != down_scale_factor:
            raise ValueError(
                f"down_scale_factor must be a power of two >= 1, got {down_scale_factor!r}"
            )
        # Tail stage k feeds num_chns // 4**k channels into PixelShuffle(2), which
        # needs a multiple of 4 every time; that holds iff num_chns % 4**L == 0.
        if num_chns % (4 ** num_levels) != 0:
            raise ValueError(
                f"num_chns ({num_chns}) must be divisible by "
                f"{4 ** num_levels} for down_scale_factor={down_scale_factor}"
            )
        # With no head stages the body's first conv sees the raw input directly.
        if num_levels == 0 and in_chns != num_chns:
            raise ValueError(
                f"down_scale_factor=1 requires in_chns == num_chns, "
                f"got {in_chns} and {num_chns}"
            )

        head = []
        in_chns_shuffle = in_chns * 4  # channels after the first PixelUnshuffle(2)
        for level in range(num_levels):
            head.append(nn.PixelUnshuffle(2))
            head.append(nn.Conv2d(in_chns_shuffle, num_chns, kernel_size=3, padding=1))
            if level + 1 < num_levels:
                # Squeeze to num_chns // 4 so the next unshuffle gives back num_chns.
                head.append(nn.Conv2d(num_chns, num_chns // 4, kernel_size=5, padding=2))
                head.append(nn.LeakyReLU(0.2, inplace=True))
                in_chns_shuffle = num_chns
        self.head = nn.Sequential(*head)

        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        tail = []
        chns = num_chns
        for _ in range(num_levels):
            # The conv keeps the width, so each PixelShuffle(2) divides it by 4.
            tail.append(nn.Conv2d(chns, chns, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
            chns //= 4
        tail.append(nn.Conv2d(chns, in_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Map ``(N, in_chns, H, W)`` to a tensor of the same shape."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y