# LicenseGAN / nets/esrgan.py
# Egrt, commit 497bc49: "Update nets/esrgan.py"
import math
import torch
from torch import nn
class DenseResidualBlock(nn.Module):
    """Densely connected residual block.

    Five 3x3 convolutions are chained with dense connections: each one
    receives the channel-wise concatenation of the block input and the
    outputs of every earlier convolution. The output of the fifth
    convolution is scaled by ``res_scale`` and added to the block input.

    Args:
        filters: channel count of the block input and of each convolution's output.
        res_scale: multiplier applied to the residual branch before the skip add.
    """

    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def conv_unit(in_channels, activate=True):
            # 3x3 conv, stride 1, padding 1: spatial size is kept, output has `filters` channels.
            modules = [nn.Conv2d(in_channels, filters, 3, 1, 1, bias=True)]
            if activate:
                modules.append(nn.GELU())
            return nn.Sequential(*modules)

        # The k-th unit sees k * filters channels: the input plus k-1 earlier outputs.
        self.b1 = conv_unit(filters)
        self.b2 = conv_unit(2 * filters)
        self.b3 = conv_unit(3 * filters)
        self.b4 = conv_unit(4 * filters)
        # No activation on the last unit: its output is the raw residual.
        self.b5 = conv_unit(5 * filters, activate=False)
        # A plain list is not registered as a submodule by nn.Module, so the
        # parameters are registered only once, under the b1..b5 names above.
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        features = x
        for layer in self.blocks:
            new = layer(features)
            features = torch.cat((features, new), dim=1)
        # `new` now holds the output of the last unit (b5).
        return x + new * self.res_scale
class ResidualInResidualDenseBlock(nn.Module):
    """Residual-in-residual dense block.

    Three DenseResidualBlock modules are applied in sequence. Their combined
    output is scaled by ``res_scale`` and added back to the input (an outer
    skip connection around the inner residual blocks).

    Args:
        filters: channel count of the input and output.
        res_scale: multiplier for the outer residual branch. The inner
            DenseResidualBlocks use their own default scale.
    """

    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        # Built in order, so the state_dict keys are dense_blocks.0/1/2.
        self.dense_blocks = nn.Sequential(*(DenseResidualBlock(filters) for _ in range(3)))

    def forward(self, x):
        residual = self.dense_blocks(x)
        return x + residual * self.res_scale
class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling stage.

    A 3x3 convolution expands the channels by ``up_scale ** 2``. PixelShuffle
    then rearranges those channels into an ``up_scale``-times larger spatial
    grid with the original channel count, and a GELU follows.

    Args:
        in_channels: channel count of the input (and of the output).
        up_scale: integer factor by which height and width are multiplied.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.gelu(self.pixel_shuffle(self.conv(x)))
class Generator(nn.Module):
    """ESRGAN-style super-resolution generator.

    Pipeline: 3x3 stem convolution, then ``num_res_blocks`` residual-in-residual
    dense blocks, then a 3x3 convolution whose output is added back to the stem
    output (global skip). After that comes a stack of 2x sub-pixel upsampling
    stages and finally two 3x3 output convolutions with a GELU between them and
    no activation on the final output.

    Args:
        scale_factor: total upscaling factor. Each upsampling stage doubles
            height and width, so this must be a power of two (1, 2, 4, 8, ...).
        channels: number of image channels in the input and the output.
        filters: number of feature channels used in the network body.
        num_res_blocks: number of ResidualInResidualDenseBlock modules.

    Raises:
        ValueError: if ``scale_factor`` is not a power of two >= 1.
    """

    def __init__(self, scale_factor, channels=3, filters=64, num_res_blocks=6):
        super(Generator, self).__init__()
        upsample_block_num = self._num_upsample_stages(scale_factor)
        # Stem convolution: image channels -> feature channels.
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        # Stack of residual-in-residual dense blocks.
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
        # Convolution after the dense blocks, before the global skip connection.
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        # One 2x upsampling stage per factor of two in scale_factor.
        self.upsample = nn.Sequential(*[UpsampleBLock(filters, 2) for _ in range(upsample_block_num)])
        # Output head: feature channels -> image channels.
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )

    @staticmethod
    def _num_upsample_stages(scale_factor):
        """Return k such that ``scale_factor == 2 ** k`` with k >= 0.

        math.log2 is exact for powers of two, unlike int(math.log(x, 2)).
        The explicit check rejects factors that int() truncation would
        otherwise silently map to a smaller scale (e.g. 3 -> a single 2x stage).
        """
        message = "scale_factor must be a power of two (1, 2, 4, ...), got {!r}".format(scale_factor)
        if scale_factor <= 0:
            raise ValueError(message)
        stages = round(math.log2(scale_factor))
        if stages < 0 or 2 ** stages != scale_factor:
            raise ValueError(message)
        return stages

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Global residual connection around the dense-block trunk.
        out = torch.add(out1, out2)
        upsample = self.upsample(out)
        out = self.conv3(upsample)
        return out
class Discriminator(nn.Module):
    """VGG-style discriminator producing one probability per input image.

    Layout:
      - a 3x3 stem convolution with GELU;
      - seven conv-BatchNorm-GELU stages, where the stride-2 stages halve the
        spatial size and the channels grow from 64 to 512;
      - global average pooling;
      - two 1x1 convolutions acting as a dense head;
      - a final sigmoid.

    Args:
        channels: number of input image channels. The default of 3 matches
            the previous hard-coded value and the Generator's default output.
    """

    # (in_channels, out_channels, stride) for each conv-BN-GELU stage after the stem.
    _STAGES = (
        (64, 64, 2),
        (64, 128, 1),
        (128, 128, 2),
        (128, 256, 1),
        (256, 256, 2),
        (256, 512, 1),
        (512, 512, 2),
    )

    def __init__(self, channels=3):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(channels, 64, kernel_size=3, padding=1), nn.GELU()]
        for in_ch, out_ch, stride in self._STAGES:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.GELU(),
            ]
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        # Same module order as the original explicit list, so the state_dict
        # keys ("net.<index>.*") and the parameter-initialisation order are unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, channels, H, W) tensor to sigmoid scores of shape (batch,)."""
        batch_size = x.size(0)
        # After pooling and the last 1x1 conv the output is (batch, 1, 1, 1).
        return torch.sigmoid(self.net(x).view(batch_size))
if __name__ == "__main__":
    # torchsummary is only needed for this smoke test, so it is imported here.
    from torchsummary import summary

    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Generator(8)
    model = model.to(device)
    # Layer-by-layer summary for a 3 x 12 x 24 input (8x upscale -> 3 x 96 x 192).
    summary(model, input_size=(3, 12, 24))