MoonLite / lle.py
KhadgaA's picture
moonknight
dc604d9
import torch.nn as nn
from utils import (
ConvRep5,
ConvRep3,
ConvRepPoint,
DropBlock,
QuadraticConnectionUnit,
QuadraticConnectionUnitS,
)
class SYELLENet(nn.Module):
    """Low-light enhancement network in its re-parameterizable (training) form.

    Every convolution is a ``ConvRep*`` block from ``utils``. :meth:`slim`
    collapses each of them into a single weight/bias pair and loads the
    result into a plain-``nn.Conv2d`` :class:`SYELLENetS`.

    Args:
        channels: Number of feature channels used by head, body, attention
            and tail.
        rep_scale: Forwarded unchanged to every ``ConvRep*`` block (its
            meaning is defined in ``utils``).
    """

    def __init__(self, channels, rep_scale=4):
        super(SYELLENet, self).__init__()
        # Kept so slim() can build a matching SYELLENetS.
        self.channels = channels
        # 3-channel input -> `channels` features via two branches combined by
        # a QuadraticConnectionUnit.
        self.head = QuadraticConnectionUnit(
            nn.Sequential(
                ConvRep5(3, channels, rep_scale=rep_scale),
                nn.PReLU(channels),
                ConvRep3(channels, channels, rep_scale=rep_scale),
            ),
            ConvRep5(3, channels, rep_scale=rep_scale),
            channels,
        )
        # The third argument uses `channels` (as `head` does) rather than a
        # hard-coded 12. It presumably sizes the unit's per-channel bias —
        # confirm in utils. With 12 the body only worked when channels == 12.
        self.body = QuadraticConnectionUnit(
            ConvRep3(channels, channels, rep_scale=rep_scale),
            ConvRepPoint(channels, channels, rep_scale=rep_scale),
            channels,
        )
        # Global average pool to 1x1 -> two point convs -> sigmoid: produces
        # per-channel gating weights that forward() multiplies onto the
        # features.
        self.att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvRepPoint(channels, channels, rep_scale=rep_scale),
            nn.PReLU(channels),
            ConvRepPoint(channels, channels, rep_scale=rep_scale),
            nn.Sigmoid(),
        )
        # Projection back to a 3-channel image.
        self.tail = ConvRep3(channels, 3, rep_scale=rep_scale)
        # Second output head, used only by forward_warm(). SYELLENetS has no
        # counterpart, so slim() does not export it.
        self.tail_warm = ConvRep3(channels, 3, rep_scale=rep_scale)
        # Applied to the input in forward_warm() only.
        self.drop = DropBlock(3)

    def forward(self, x):
        """Run head -> body -> channel gating -> tail and return the output."""
        x = self.head(x)
        x = self.body(x)
        x = self.att(x) * x
        return self.tail(x)

    def forward_warm(self, x):
        """Warm-up pass: DropBlock on the input, then head -> body.

        Returns:
            tuple: ``(self.tail(features), self.tail_warm(features))``.
        """
        # NOTE(review): unlike forward(), the attention gating is not applied
        # here — presumably intentional for the warm-up stage; confirm.
        x = self.drop(x)
        x = self.head(x)
        x = self.body(x)
        return self.tail(x), self.tail_warm(x)

    def slim(self):
        """Fold this network into an equivalent :class:`SYELLENetS`.

        For every ``ConvRep3``/``ConvRep5``/``ConvRepPoint`` whose weight key
        exists in the slim network, the folded ``(weight, bias)`` returned by
        its own ``slim()`` is written under the same name. Modules with no
        counterpart in the slim network, such as ``tail_warm``, are skipped.
        ``QuadraticConnectionUnit`` biases and ``nn.PReLU`` slopes are copied
        unchanged.

        Returns:
            SYELLENetS: A newly built network loaded with the folded weights.
        """
        net_slim = SYELLENetS(self.channels)
        weight_slim = net_slim.state_dict()
        rep_types = (ConvRep3, ConvRep5, ConvRepPoint)
        for name, mod in self.named_modules():
            weight_key = f'{name}.weight'
            bias_key = f'{name}.bias'
            if isinstance(mod, rep_types):
                if weight_key not in weight_slim:
                    continue
                w, b = mod.slim()
                if 'block2' in name:
                    # NOTE(review): presumably bakes the training unit's 0.1
                    # output scale into its second branch, because the slim
                    # unit has no scale — confirm against utils.
                    w = w * 0.1
                    b = b * 0.1
                weight_slim[weight_key] = w
                weight_slim[bias_key] = b
            elif isinstance(mod, QuadraticConnectionUnit):
                weight_slim[bias_key] = mod.bias
            elif isinstance(mod, nn.PReLU):
                weight_slim[weight_key] = mod.weight
        net_slim.load_state_dict(weight_slim)
        return net_slim
class SYELLENetS(nn.Module):
    """Plain-layer ("slim") counterpart of SYELLENet for inference.

    Mirrors SYELLENet's head / body / attention / tail layout, with every
    re-parameterizable block replaced by a single ``nn.Conv2d``.
    ``SYELLENet.slim()`` builds one of these and loads its folded weights
    into it.

    Args:
        channels: Number of feature channels.
    """

    def __init__(self, channels):
        super(SYELLENetS, self).__init__()
        self.channels = channels
        # 3-channel input -> `channels` features (5x5 convs use stride 1 and
        # padding 2, 3x3 convs use stride 1 and padding 1, so spatial size is
        # preserved).
        self.head = QuadraticConnectionUnitS(
            nn.Sequential(
                nn.Conv2d(3, channels, 5, 1, 2),
                nn.PReLU(channels),
                nn.Conv2d(channels, channels, 3, 1, 1),
            ),
            nn.Conv2d(3, channels, 5, 1, 2),
            channels,
        )
        # `channels` (as in `head`) instead of a hard-coded 12. It presumably
        # sizes the unit's per-channel bias — confirm in utils. With 12 the
        # model only worked when channels == 12.
        self.body = QuadraticConnectionUnitS(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.Conv2d(channels, channels, 1),
            channels,
        )
        # Global average pool to 1x1 -> two 1x1 convs -> sigmoid: per-channel
        # gating weights multiplied onto the features in forward().
        self.att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels, 1),
            nn.PReLU(channels),
            nn.Conv2d(channels, channels, 1),
            nn.Sigmoid(),
        )
        # Projection back to a 3-channel image.
        self.tail = nn.Conv2d(channels, 3, 3, 1, 1)

    def forward(self, x):
        """Run head -> body -> channel gating -> tail and return the output."""
        x = self.head(x)
        x = self.body(x)
        x = self.att(x) * x
        return self.tail(x)