# NOTE: web-page residue captured with this file ("Spaces: Runtime error"); not part of the module.
from torch import nn | |
from torch.nn import BatchNorm2d, PReLU, Sequential, Module | |
from torchvision.models import resnet34 | |
from models.hyperstyle.hypernetworks.refinement_blocks import HyperRefinementBlock, RefinementBlock, RefinementBlockSeparable | |
from models.hyperstyle.hypernetworks.shared_weights_hypernet import SharedWeightsHypernet | |
class SharedWeightsHyperNetResNet(Module):
    """ResNet34-based hypernetwork producing one output per target layer.

    The input goes through a fresh stem (``conv1`` -> ``bn1`` -> ``relu``)
    followed by the residual stages ``layer1``..``layer4`` of a torchvision
    ``resnet34``. The resulting features are then passed to one refinement
    block per hypernetwork output; outputs that are not selected for tuning
    yield ``None``.

    Args:
        opts: Options object. The attributes read here are ``input_nc``
            (input channels of the stem convolution), ``n_hypernet_outputs``
            (number of outputs), ``layers_to_tune`` (which outputs get a
            refinement block) and ``device`` (forwarded to
            ``SharedWeightsHypernet``). ``opts`` itself is also handed to
            every ``RefinementBlock``.
    """

    def __init__(self, opts):
        super().__init__()

        # The stem is created from scratch (not taken from the ResNet) so it
        # can accept ``opts.input_nc`` input channels.
        self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = PReLU(64)

        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept as-is for compatibility with whatever
        # version this project pins -- confirm before changing.
        resnet_basenet = resnet34(pretrained=True)
        stages = (
            resnet_basenet.layer1,
            resnet_basenet.layer2,
            resnet_basenet.layer3,
            resnet_basenet.layer4,
        )
        # Flatten the four stages into a single Sequential, preserving order
        # (and therefore the ``body.<idx>`` parameter names).
        self.body = Sequential(*(unit for stage in stages for unit in stage))

        self.n_outputs = opts.n_hypernet_outputs
        self.layers_to_tune = self._parse_layers_to_tune(opts.layers_to_tune, self.n_outputs)

        # Outputs listed here are served by HyperRefinementBlocks built around
        # the single SharedWeightsHypernet below; every other tuned output gets
        # its own RefinementBlock.
        self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
        self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None, device=opts.device)

        tuned = set(self.layers_to_tune)
        shared = set(self.shared_layers)
        self.refinement_blocks = nn.ModuleList()
        for layer_idx in range(self.n_outputs):
            if layer_idx not in tuned:
                # Placeholder keeps list position == layer index, which is
                # what ``forward`` relies on.
                refinement_block = None
            elif layer_idx in shared:
                refinement_block = HyperRefinementBlock(self.shared_weight_hypernet, n_channels=512, inner_c=128)
            else:
                refinement_block = RefinementBlock(layer_idx, opts, n_channels=512, inner_c=256)
            self.refinement_blocks.append(refinement_block)

    @staticmethod
    def _parse_layers_to_tune(spec, n_outputs):
        """Turn a ``layers_to_tune`` option into a list of int layer indices.

        Args:
            spec: ``None``, a comma-separated string such as ``"0,2,3"``, or
                an iterable of values convertible with ``int``.
            n_outputs: Total number of hypernetwork outputs.

        Returns:
            The parsed indices in the given order, or ``list(range(n_outputs))``
            when ``spec`` is ``None`` or contains no indices.

        Raises:
            ValueError: If an entry is not a valid integer.
        """
        if spec is None:
            return list(range(n_outputs))
        if isinstance(spec, str):
            # Skip blank entries so stray/trailing commas do not hit int('').
            tokens = (token.strip() for token in spec.split(','))
            indices = [int(token) for token in tokens if token]
        else:
            indices = [int(idx) for idx in spec]
        return indices if indices else list(range(n_outputs))

    def forward(self, x):
        """Run the backbone and every refinement block on ``x``.

        Args:
            x: Input passed to ``conv1``.

        Returns:
            A list of length ``n_outputs``. Entry ``j`` is the output of
            refinement block ``j`` applied to the backbone features, or
            ``None`` when that output has no refinement block.
        """
        features = self.body(self.relu(self.bn1(self.conv1(x))))

        weight_deltas = []
        for idx in range(self.n_outputs):
            block = self.refinement_blocks[idx]
            weight_deltas.append(None if block is None else block(features))
        return weight_deltas
class SharedWeightsHyperNetResNetSeparable(Module):
    """ResNet34-based hypernetwork using ``RefinementBlockSeparable`` blocks.

    The input goes through a fresh stem (``conv1`` -> ``bn1`` -> ``relu``)
    followed by the residual stages ``layer1``..``layer4`` of a torchvision
    ``resnet34``. The resulting features are then passed to one refinement
    block per hypernetwork output; outputs that are not selected for tuning
    yield ``None``. Tuned outputs outside ``shared_layers`` use
    ``RefinementBlockSeparable``.

    Args:
        opts: Options object. The attributes read here are ``input_nc``
            (input channels of the stem convolution), ``n_hypernet_outputs``
            (number of outputs), ``layers_to_tune`` (which outputs get a
            refinement block) and ``device`` (forwarded to
            ``SharedWeightsHypernet``). ``opts`` itself is also handed to
            every ``RefinementBlockSeparable``.
    """

    def __init__(self, opts):
        super().__init__()

        # The stem is created from scratch (not taken from the ResNet) so it
        # can accept ``opts.input_nc`` input channels.
        self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = PReLU(64)

        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept as-is for compatibility with whatever
        # version this project pins -- confirm before changing.
        resnet_basenet = resnet34(pretrained=True)
        stages = (
            resnet_basenet.layer1,
            resnet_basenet.layer2,
            resnet_basenet.layer3,
            resnet_basenet.layer4,
        )
        # Flatten the four stages into a single Sequential, preserving order
        # (and therefore the ``body.<idx>`` parameter names).
        self.body = Sequential(*(unit for stage in stages for unit in stage))

        self.n_outputs = opts.n_hypernet_outputs
        self.layers_to_tune = self._parse_layers_to_tune(opts.layers_to_tune, self.n_outputs)

        # Outputs listed here are served by HyperRefinementBlocks built around
        # the single SharedWeightsHypernet below; every other tuned output gets
        # its own RefinementBlockSeparable.
        self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
        self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None, device=opts.device)

        tuned = set(self.layers_to_tune)
        shared = set(self.shared_layers)
        self.refinement_blocks = nn.ModuleList()
        for layer_idx in range(self.n_outputs):
            if layer_idx not in tuned:
                # Placeholder keeps list position == layer index, which is
                # what ``forward`` relies on.
                refinement_block = None
            elif layer_idx in shared:
                refinement_block = HyperRefinementBlock(self.shared_weight_hypernet, n_channels=512, inner_c=128)
            else:
                refinement_block = RefinementBlockSeparable(layer_idx, opts, n_channels=512, inner_c=256)
            self.refinement_blocks.append(refinement_block)

    @staticmethod
    def _parse_layers_to_tune(spec, n_outputs):
        """Turn a ``layers_to_tune`` option into a list of int layer indices.

        Args:
            spec: ``None``, a comma-separated string such as ``"0,2,3"``, or
                an iterable of values convertible with ``int``.
            n_outputs: Total number of hypernetwork outputs.

        Returns:
            The parsed indices in the given order, or ``list(range(n_outputs))``
            when ``spec`` is ``None`` or contains no indices.

        Raises:
            ValueError: If an entry is not a valid integer.
        """
        if spec is None:
            return list(range(n_outputs))
        if isinstance(spec, str):
            # Skip blank entries so stray/trailing commas do not hit int('').
            tokens = (token.strip() for token in spec.split(','))
            indices = [int(token) for token in tokens if token]
        else:
            indices = [int(idx) for idx in spec]
        return indices if indices else list(range(n_outputs))

    def forward(self, x):
        """Run the backbone and every refinement block on ``x``.

        Args:
            x: Input passed to ``conv1``.

        Returns:
            A list of length ``n_outputs``. Entry ``j`` is the output of
            refinement block ``j`` applied to the backbone features, or
            ``None`` when that output has no refinement block.
        """
        features = self.body(self.relu(self.bn1(self.conv1(x))))

        weight_deltas = []
        for idx in range(self.n_outputs):
            block = self.refinement_blocks[idx]
            weight_deltas.append(None if block is None else block(features))
        return weight_deltas