add demo + inference files
- app.py +54 -0
- networks.py +366 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,54 @@
import gradio as gr
from huggingface_hub import hf_hub_download

import torch
from networks import define_G
from torchvision import transforms

REPO_ID = "Launchpad/ditto"
FILENAME = "model.pth"

# model_dict = torch.load("model.pth")
model_dict = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location="cpu"  # keep the demo runnable on CPU-only hardware
)
generator = define_G(input_nc=3, output_nc=3, ngf=64, netG="resnet_9blocks", norm="instance")
generator.load_state_dict(model_dict)
generator.eval()

# set up transforms for the model
encode = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256))
])
transform = transforms.ToPILImage()


def generate_pokemon(pet_img):
    # encode image and add a batch dimension
    encoded_img = encode(pet_img).unsqueeze(0)

    # evaluate model on pet image
    with torch.no_grad():
        generated_img = generator(encoded_img).squeeze(0)

    # the generator ends in Tanh, so map its [-1, 1] output back to [0, 1]
    generated_img = (generated_img + 1) / 2

    # transform to PIL image
    return transform(generated_img)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image("https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/Ditto.png", elem_id="logo-img", show_label=False, show_share_button=False, show_download_button=False)

        with gr.Column(scale=3):
            gr.Markdown("""Ditto is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Fall 2022) that transfers the style of Pokemon sprites onto pet images using GANs and contrastive learning.
            <br/><br/>
            **Model**: [ditto](https://huggingface.co/Launchpad/ditto)
            <br/>
            **Developed by**: Kiran Suresh, Annie Lee, Chloe Wong, Tony Xin, Sebastian Zhao
            """
            )
    with gr.Row():
        gr.Interface(generate_pokemon, gr.Image(), "image")

if __name__ == '__main__':
    demo.launch()
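A minimal sketch for running the same inference path locally, outside the Gradio UI; it reuses the demo's generate_pokemon and assumes a test image at the hypothetical path pet.jpg (importing app downloads the checkpoint and builds, but does not launch, the interface):

from PIL import Image

from app import generate_pokemon  # pulls in the loaded generator and transforms

pet = Image.open("pet.jpg").convert("RGB")   # hypothetical local test image
sprite = generate_pokemon(pet)               # 256x256 PIL image in sprite style
sprite.save("pokemon_sprite.png")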
networks.py
ADDED
@@ -0,0 +1,366 @@
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler


###############################################################################
# Helper Functions
###############################################################################


class Identity(nn.Module):
    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x):
            return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler
    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator
    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- whether to use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    Returns a generator
    Our current implementation provides two types of generators:
        U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
        The original U-Net paper: https://arxiv.org/abs/1505.04597
        Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
        A Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
        We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
    The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)


##############################################################################
# Classes
##############################################################################


class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
    We adapt Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator
        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block
        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.
        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers.
            use_bias (bool)     -- whether the conv layer uses bias
        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)
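As a quick smoke test of define_G, a short sketch assuming a 256x256 RGB input as in the demo (the generator's two stride-2 downsampling and two upsampling stages preserve the spatial size):

import torch
from networks import define_G

# same generator configuration the demo uses; default gpu_ids=[] keeps it on CPU
net = define_G(input_nc=3, output_nc=3, ngf=64, netG="resnet_9blocks", norm="instance")
net.eval()

x = torch.randn(1, 3, 256, 256)  # stand-in for a preprocessed RGB image
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]); values in [-1, 1] from the final Tanh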
requirements.txt
ADDED
@@ -0,0 +1,3 @@
torch==1.13.0
torchvision==0.14.0
gradio==4.26.0