OneRestore / model /OneRestore.py
gy65896's picture
Upload 36 files
2940390 verified
raw
history blame
11.2 kB
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 20 16:14:37 2021
@author: Administrator
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torchvision import transforms
import torch, math
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numbers
from thop import profile
import numpy as np
import time
from torchvision import transforms
class OneRestore(nn.Module):
def __init__(self, channel = 32):
super(OneRestore,self).__init__()
self.norm = lambda x: (x-0.5)/0.5
self.denorm = lambda x: (x+1)/2
self.in_conv = nn.Conv2d(3,channel,kernel_size=1,stride=1,padding=0,bias=False)
self.encoder = encoder(channel)
self.middle = backbone(channel)
self.decoder = decoder(channel)
self.out_conv = nn.Conv2d(channel,3,kernel_size=1,stride=1,padding=0,bias=False)
def forward(self,x,embedding):
x_in = self.in_conv(self.norm(x))
x_l, x_m, x_s, x_ss = self.encoder(x_in, embedding)
x_mid = self.middle(x_ss, embedding)
x_out = self.decoder(x_mid, x_ss, x_s, x_m, x_l, embedding)
out = self.out_conv(x_out) + x
return self.denorm(out)
class encoder(nn.Module):
def __init__(self,channel):
super(encoder,self).__init__()
self.el = ResidualBlock(channel)#16
self.em = ResidualBlock(channel*2)#32
self.es = ResidualBlock(channel*4)#64
self.ess = ResidualBlock(channel*8)#128
self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.conv_eltem = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#16 32
self.conv_emtes = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#32 64
self.conv_estess = nn.Conv2d(4*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 128
self.conv_esstesss = nn.Conv2d(8*channel,16*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 256
def forward(self,x,embedding):
elout = self.el(x, embedding)#16
x_emin = self.conv_eltem(self.maxpool(elout))#32
emout = self.em(x_emin, embedding)
x_esin = self.conv_emtes(self.maxpool(emout))
esout = self.es(x_esin, embedding)
x_esin = self.conv_estess(self.maxpool(esout))
essout = self.ess(x_esin, embedding)#128
return elout, emout, esout, essout#,esssout
class backbone(nn.Module):
def __init__(self,channel):
super(backbone,self).__init__()
self.s1 = ResidualBlock(channel*8)#128
self.s2 = ResidualBlock(channel*8)#128
def forward(self,x,embedding):
share1 = self.s1(x, embedding)
share2 = self.s2(share1, embedding)
return share2
class decoder(nn.Module):
def __init__(self,channel):
super(decoder,self).__init__()
self.dss = ResidualBlock(channel*8)#128
self.ds = ResidualBlock(channel*4)#64
self.dm = ResidualBlock(channel*2)#32
self.dl = ResidualBlock(channel)#16
#self.conv_dssstdss = nn.Conv2d(16*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#256 128
self.conv_dsstds = nn.Conv2d(8*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 64
self.conv_dstdm = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 32
self.conv_dmtdl = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False)#32 16
def _upsample(self,x,y):
_,_,H0,W0 = y.size()
return F.interpolate(x,size=(H0,W0),mode='bilinear')
def forward(self, x, x_ss, x_s, x_m, x_l, embedding):
dssout = self.dss(x + x_ss, embedding)
x_dsin = self.conv_dsstds(self._upsample(dssout, x_s))
dsout = self.ds(x_dsin + x_s, embedding)
x_dmin = self.conv_dstdm(self._upsample(dsout, x_m))
dmout = self.dm(x_dmin + x_m, embedding)
x_dlin = self.conv_dmtdl(self._upsample(dmout, x_l))
dlout = self.dl(x_dlin + x_l, embedding)
return dlout
class ResidualBlock(nn.Module): # Edge-oriented Residual Convolution Block 面向边缘的残差网络块 解决梯度消失的问题
def __init__(self, channel, norm=False):
super(ResidualBlock, self).__init__()
self.el = TransformerBlock(channel, num_heads=8, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
def forward(self, x,embedding):
return self.el(x,embedding)
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
class Cross_Attention(nn.Module):
def __init__(self,
dim,
num_heads,
bias,
q_dim = 324):
super(Cross_Attention, self).__init__()
self.dim = dim
self.num_heads = num_heads
sqrt_q_dim = int(math.sqrt(q_dim))
self.resize = transforms.Resize([sqrt_q_dim, sqrt_q_dim])
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.q = nn.Linear(q_dim, q_dim, bias=bias)
self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
self.kv_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x, query):
b,c,h,w = x.shape
q = self.q(query)
k, v = self.kv_dwconv(self.kv(x)).chunk(2, dim=1)
k = self.resize(k)
q = repeat(q, 'b l -> b head c l', head=self.num_heads, c=self.dim//self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
class Self_Attention(nn.Module):
def __init__(self,
dim,
num_heads,
bias):
super(Self_Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b,c,h,w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q,k,v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
class FeedForward(nn.Module):
def __init__(self,
dim,
ffn_expansion_factor,
bias):
super(FeedForward, self).__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
groups=hidden_features * 2, bias=bias)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
class TransformerBlock(nn.Module):
def __init__(self,
dim,
num_heads=8,
ffn_expansion_factor=2.66,
bias=False,
LayerNorm_type='WithBias'):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.cross_attn = Cross_Attention(dim, num_heads, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.self_attn = Self_Attention(dim, num_heads, bias)
self.norm3 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x, query):
x = x + self.cross_attn(self.norm1(x),query)
x = x + self.self_attn(self.norm2(x))
x = x + self.ffn(self.norm3(x))
return x
if __name__ == '__main__':
net = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
# x = torch.Tensor(np.random.random((2,3,256,256))).to("cuda" if torch.cuda.is_available() else "cpu")
# query = torch.Tensor(np.random.random((2, 324))).to("cuda" if torch.cuda.is_available() else "cpu")
# out = net(x, query)
# print(out.shape)
input = torch.randn(1, 3, 512, 512).to("cuda" if torch.cuda.is_available() else "cpu")
query = torch.Tensor(np.random.random((1, 324))).to("cuda" if torch.cuda.is_available() else "cpu")
macs, _ = profile(net, inputs=(input, query))
total = sum([param.nelement() for param in net.parameters()])
print('Macs = ' + str(macs/1000**3) + 'G')
print('Params = ' + str(total/1e6) + 'M')
from fvcore.nn import FlopCountAnalysis, parameter_count_table
flops = FlopCountAnalysis(net, (input, query))
print("FLOPs", flops.total()/1000**3)