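"""RAFT-style recurrent restoration network.

Defines ConvGRU/SepConvGRU recurrent cells and RAFTNET, an EDSR-style
residual network whose body is applied recurrently to a hidden feature state.
"""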
import torch
import torch.nn as nn
import torch.cuda.amp as amp

from lambda_networks import LambdaLayer  # only referenced by commented-out code below
from model import attention              # only referenced by commented-out code below
from model import common
class ConvGRU(nn.Module):
    """Convolutional GRU cell: a GRU whose gates are 3x3 convolutions."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)  # update gate
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)  # reset gate
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)  # candidate state

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
        # blend the previous hidden state with the candidate: h' = (1-z)*h + z*q
        return (1-z) * h + z * q
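# Illustrative usage with the default sizes (hypothetical tensors):
#   gru = ConvGRU()
#   h = torch.zeros(1, 128, 32, 32)  # hidden state
#   x = torch.zeros(1, 320, 32, 32)  # input features (192 + 128 channels)
#   h = gru(h, x)                    # -> (1, 128, 32, 32)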
class SepConvGRU(nn.Module):
    """Separable ConvGRU: two GRU updates with 1x5 and then 5x1 gate convolutions."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, h, x):
        # horizontal pass (1x5 gates)
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q
        # vertical pass (5x1 gates)
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q
        return h
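# Note: the two-pass 1x5/5x1 factorization follows RAFT's SepConvGRU; it
# approximates a 5x5 gate receptive field with fewer parameters than a single
# 5x5 convolution.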
def make_model(args, parent=False):
    # factory entry point; the surrounding framework constructs models by
    # calling make_model(args)
    return RAFTNET(args)
class RAFTNET(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(RAFTNET, self).__init__()
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]  # read for interface parity; not used below

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        # mean-shift layers are kept for checkpoint compatibility; forward()
        # normalizes with (x - 0.5) / 0.5 instead
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        # msa = attention.PyramidAttention()

        # head: shallow feature extractor (one conv plus two residual blocks)
        m_head = [conv(args.n_colors, n_feats, kernel_size)]
        for i in range(2):
            m_head.append(common.ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale))

        # tail: shared feature-to-image reconstruction
        m_tail = []
        for i in range(2):
            m_tail.append(common.ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale))
        m_tail.append(conv(n_feats, args.n_colors, kernel_size))

        # middle recurrent part
        # layer = LambdaLayer(
        #     dim=n_feats,
        #     dim_out=n_feats,
        #     r=23,  # receptive field for relative positional encoding (23 x 23)
        #     dim_k=16,
        #     heads=4,
        #     dim_u=4,
        #     norm=args.normalization,
        # )

        # body: residual trunk applied at every recurrent step
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale
            ) for _ in range(n_resblocks//2)
        ]
        # m_body.append(layer)
        for i in range(n_resblocks//2):
            m_body.append(common.ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale))
        m_body.append(conv(n_feats, n_feats, kernel_size))

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        # hidden-state encoder: maps head features to the initial hidden state
        self.hidden_encoder = nn.Sequential(
            common.ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale),
            common.ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale),
            common.ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale),
        )

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)
        # self.gru = ConvGRU(hidden_dim=64, input_dim=64)

        self.recurrence = args.recurrence
        self.detach = args.detach
        # self.step_detach = args.step_detach
        self.amp = args.amp
    def forward(self, x):
        with amp.autocast(self.amp):
            # normalize from [0, 1] to [-1, 1]
            x = (x - 0.5) / 0.5
            x = self.head(x)
            hidden = self.hidden_encoder(x)
            output_lst = [None] * self.recurrence
            for i in range(self.recurrence):
                res = self.body(hidden)
                gru_out = res + hidden
                output = self.tail(gru_out)
                # map back to [0, 1]
                output_lst[i] = output * 0.5 + 0.5
                # carry the refined features forward as the next step's hidden
                # state, optionally cutting gradients between steps
                hidden = gru_out.detach() if self.detach else gru_out
        return output_lst
    def load_state_dict(self, state_dict, strict=True):
        # copy matching parameters; 'tail' weights are exempt from strict
        # matching so checkpoints with a different reconstruction head load
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
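

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not the repo's training entry point).
    # The attribute names mirror what this file reads from `args`; the values
    # are placeholder assumptions, and `model.common` must resolve to the
    # repo's EDSR-style common.py for the construction above to work.
    from types import SimpleNamespace

    args = SimpleNamespace(
        n_resblocks=16, n_feats=64, scale=[2], rgb_range=1, n_colors=3,
        res_scale=1.0, recurrence=1, detach=False, amp=False,
    )
    model = make_model(args)
    with torch.no_grad():
        outputs = model(torch.rand(1, 3, 48, 48))
    # one output per recurrent step, each the same size as the input
    print([tuple(o.shape) for o in outputs])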