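# ToonModels / model.py
# Uploaded by akshitapps ("Upload 3 files", commit 11be304, 1.74 kB).
# Hugging Face file-viewer header (raw / history / blame / no virus),
# retained here as a comment so the file remains valid Python.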
from networks import ResnetBlock
import functools
import torch
import torch.nn as nn
# Commented-out alternative, kept for reference:
# class GenerativeModel():
#     def __init__(self):
#         self.model = networks.define_G(3, 3, 64, "global", 4, 9, 1, 3, "instance", gpu_ids=[0])
class GlobalGenerator(nn.Module):
    """Convolutional image generator assembled as a single ``nn.Sequential``.

    The layer stack is built in this order:

    1. ``ReflectionPad2d(3)`` + 7x7 convolution from ``input_nc`` to ``ngf``
       channels, followed by the normalization layer and ReLU.
    2. ``n_downsampling`` stride-2 3x3 convolutions, each doubling the
       channel count (each followed by normalization and ReLU).
    3. ``n_blocks`` ``ResnetBlock`` modules operating on
       ``ngf * 2**n_downsampling`` channels.
    4. ``n_downsampling`` stride-2 3x3 transposed convolutions, each halving
       the channel count (each followed by normalization and ReLU).
    5. ``ReflectionPad2d(3)`` + 7x7 convolution to ``output_nc`` channels,
       followed by ``Tanh``.

    Args:
        input_nc: Number of channels expected by the first convolution.
        output_nc: Number of channels produced by the last convolution.
        ngf: Number of filters produced by the first convolution.
        n_downsampling: Number of stride-2 convolutions, and of matching
            stride-2 transposed convolutions.
        n_blocks: Number of ``ResnetBlock`` modules; must be >= 0.
        norm_layer: Callable that takes a channel count and returns a
            normalization module.
        padding_type: Forwarded to every ``ResnetBlock``. The first and last
            convolutions always use reflection padding regardless of this
            value.

    Raises:
        ValueError: If ``n_blocks`` is negative.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_downsampling=4, n_blocks=9,
                 norm_layer=functools.partial(nn.InstanceNorm2d, affine=False),
                 padding_type='reflect'):
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, after which a negative count would silently build a
        # network without any residual blocks.
        if n_blocks < 0:
            raise ValueError("n_blocks must be >= 0, got {!r}".format(n_blocks))
        super(GlobalGenerator, self).__init__()

        # One in-place ReLU instance is reused at every activation position
        # and is also handed to each ResnetBlock.
        activation = nn.ReLU(True)

        # NOTE: a layer's position in the Sequential becomes part of its
        # parameter name (``model.<index>.weight``), so the order in which
        # layers are appended below must stay stable for saved weights to load.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            activation,
        ]

        # Downsample: each step halves the resolution and doubles the channels.
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(ngf * mult * 2),
                activation,
            ]

        # Residual blocks at the widest channel count.
        mult = 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type,
                                  activation=activation, norm_layer=norm_layer)]

        # Upsample: mirror of the downsampling path. ``mult`` is always an
        # even power of two here, so integer division is exact.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(ngf * mult // 2),
                activation,
            ]

        # Output head: project back to ``output_nc`` channels; Tanh bounds the
        # result to [-1, 1].
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Run ``input`` through the sequential layer stack and return the result.

        NOTE(review): presumably ``input`` is a ``(N, input_nc, H, W)`` tensor
        with H and W divisible by ``2**n_downsampling`` so the output has the
        same spatial size -- confirm against callers.
        """
        return self.model(input)