|
import torch.nn as nn |
|
from utils import ( |
|
ConvRep5, |
|
ConvRep3, |
|
ConvRepPoint, |
|
DropBlock, |
|
QuadraticConnectionUnit, |
|
QuadraticConnectionUnitS, |
|
) |
|
|
|
|
|
class SYELLENet(nn.Module):
    """Training-time enhancement network built from re-parameterizable convs.

    The network maps a 3-channel input to a 3-channel output. ``slim()``
    folds every ``ConvRep*`` block into a plain ``nn.Conv2d`` and returns the
    corresponding ``SYELLENetS``.

    Args:
        channels: Width of the internal feature maps.
        rep_scale: Forwarded unchanged to every ``ConvRep*`` block.
    """

    def __init__(self, channels, rep_scale=4):
        super().__init__()
        self.channels = channels

        # Head: quadratic connection of a (5x5 -> PReLU -> 3x3) branch and a
        # single 5x5 branch, both mapping 3 input channels to `channels`.
        self.head = QuadraticConnectionUnit(
            nn.Sequential(
                ConvRep5(3, channels, rep_scale=rep_scale),
                nn.PReLU(channels),
                ConvRep3(channels, channels, rep_scale=rep_scale),
            ),
            ConvRep5(3, channels, rep_scale=rep_scale),
            channels,
        )

        # NOTE(review): the head passes `channels` as the third argument, but
        # the body passes a literal 12 (SYELLENetS mirrors this). Confirm what
        # this argument means in utils.QuadraticConnectionUnit before changing
        # `channels` away from 12.
        self.body = QuadraticConnectionUnit(
            ConvRep3(channels, channels, rep_scale=rep_scale),
            ConvRepPoint(channels, channels, rep_scale=rep_scale),
            12,
        )

        # Channel attention gate. Global average pooling is followed by two
        # point-wise re-parameterizable convs and a sigmoid; forward()
        # multiplies the result into the features.
        self.att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvRepPoint(channels, channels, rep_scale=rep_scale),
            nn.PReLU(channels),
            ConvRepPoint(channels, channels, rep_scale=rep_scale),
            nn.Sigmoid(),
        )

        self.tail = ConvRep3(channels, 3, rep_scale=rep_scale)

        # Auxiliary output head, used only by forward_warm(). SYELLENetS has
        # no `tail_warm`, so slim() skips it.
        self.tail_warm = ConvRep3(channels, 3, rep_scale=rep_scale)
        self.drop = DropBlock(3)

    def forward(self, x):
        """Run head -> body -> attention gate -> tail and return the output."""
        x = self.head(x)
        x = self.body(x)
        x = self.att(x) * x
        return self.tail(x)

    def forward_warm(self, x):
        """Warm-up forward pass.

        Applies ``DropBlock`` to the input and skips the attention gate.

        Returns:
            A tuple ``(tail(features), tail_warm(features))``.
        """
        x = self.drop(x)
        x = self.head(x)
        x = self.body(x)
        return self.tail(x), self.tail_warm(x)

    def slim(self):
        """Fold this network into a plain-convolution ``SYELLENetS``.

        Each ``ConvRep*`` block whose ``<name>.weight`` exists in the slim
        network's state dict is collapsed into one weight/bias pair via its
        own ``slim()`` method. Blocks with no counterpart, such as
        ``tail_warm``, are skipped. ``QuadraticConnectionUnit`` biases and
        ``PReLU`` weights are copied unchanged.

        Returns:
            A new ``SYELLENetS(self.channels)`` with the fused weights loaded.
        """
        import torch  # the file only imports torch.nn at module level

        rep_types = (ConvRep3, ConvRep5, ConvRepPoint)
        net_slim = SYELLENetS(self.channels)
        weight_slim = net_slim.state_dict()

        # Fusion is pure weight arithmetic, so no autograd graph is needed.
        with torch.no_grad():
            for name, mod in self.named_modules():
                if isinstance(mod, rep_types):
                    w_key = f"{name}.weight"
                    b_key = f"{name}.bias"
                    if w_key not in weight_slim:
                        continue
                    w, b = mod.slim()
                    if "block2" in name:
                        # NOTE(review): presumably this compensates for a 0.1
                        # factor that differs between the second branch of
                        # QuadraticConnectionUnit and QuadraticConnectionUnitS.
                        # Confirm against utils.
                        w, b = w * 0.1, b * 0.1
                    weight_slim[w_key] = w
                    weight_slim[b_key] = b
                elif isinstance(mod, QuadraticConnectionUnit):
                    weight_slim[f"{name}.bias"] = mod.bias
                elif isinstance(mod, nn.PReLU):
                    weight_slim[f"{name}.weight"] = mod.weight

        net_slim.load_state_dict(weight_slim)
        return net_slim
|
|
|
|
|
class SYELLENetS(nn.Module):
    """Inference-time (slim) form of the enhancement network.

    Plain ``nn.Conv2d`` layers are registered under the same module names
    that ``SYELLENet`` uses. This lets fused weights from ``SYELLENet.slim()``
    load directly by state-dict key.

    Args:
        channels: Width of the internal feature maps.
    """

    def __init__(self, channels):
        super().__init__()
        width = channels

        # Head: (5x5 -> PReLU -> 3x3) branch plus a single 5x5 branch.
        head_main = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=5, stride=1, padding=2),
            nn.PReLU(width),
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
        )
        head_side = nn.Conv2d(3, width, kernel_size=5, stride=1, padding=2)
        self.head = QuadraticConnectionUnitS(head_main, head_side, width)

        # Body: 3x3 branch plus point-wise branch. The literal 12 matches the
        # training-time network; its meaning lives in utils (unverified here).
        body_main = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1)
        body_side = nn.Conv2d(width, width, kernel_size=1)
        self.body = QuadraticConnectionUnitS(body_main, body_side, 12)

        # Channel attention gate: pool -> 1x1 -> PReLU -> 1x1 -> sigmoid.
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(width, width, kernel_size=1),
            nn.PReLU(width),
            nn.Conv2d(width, width, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.att = nn.Sequential(*gate_layers)

        self.tail = nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run head -> body -> attention gate -> tail and return the output."""
        feats = self.body(self.head(x))
        weights = self.att(feats)
        return self.tail(weights * feats)
|
|