amosyou committed on
Commit
d813284
1 Parent(s): 8d1784b

add demo + inference files

Browse files
Files changed (3) hide show
  1. app.py +54 -0
  2. networks.py +366 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+
4
+ import torch
5
+ from networks import define_G
6
+ from torchvision import transforms
7
+
8
REPO_ID = "Launchpad/ditto"
FILENAME = "model.pth"

# Download the checkpoint from the Hugging Face Hub and deserialize it.
# map_location="cpu" remaps any tensors that were saved from a CUDA device,
# so loading also works on hosts without a GPU.
# NOTE(review): torch.load unpickles the file; this is only acceptable because
# the checkpoint comes from the fixed repository named in REPO_ID.
model_dict = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location="cpu",
)

# Build the generator and switch it to inference mode.
generator = define_G(input_nc=3, output_nc=3, ngf=64, netG="resnet_9blocks", norm="instance")
generator.load_state_dict(model_dict)
generator.eval()

# set up transforms for model
# encode: input image -> float tensor resized to 256x256
encode = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
])
# transform: tensor -> PIL image for display
transform = transforms.ToPILImage()
25
+
26
def generate_pokemon(pet_img):
    """Run the generator on a pet image and return the result as a PIL image.

    Parameters:
        pet_img -- the image handed over by the Gradio input component, or
                   None when no image was provided

    Returns the generated PIL image, or None when ``pet_img`` is None.
    """
    # Gradio passes None when the user submits without an image.
    if pet_img is None:
        return None

    # encode image (module-level ``encode``: ToTensor + resize to 256x256)
    encoded_img = encode(pet_img)

    # ToTensor yields values in [0, 1]; shift them to [-1, 1] so the input
    # range matches the generator's Tanh output range.
    # NOTE(review): assumes the checkpoint was trained on [-1, 1]-normalized
    # images -- confirm against the training pipeline.
    model_input = (encoded_img * 2.0 - 1.0).unsqueeze(0)  # add batch dimension

    # evaluate model on pet image
    with torch.no_grad():
        generated_img = generator(model_input)

    # The generator ends in Tanh, so its output lies in [-1, 1], while
    # ToPILImage expects float tensors in [0, 1]. Drop the batch dimension,
    # rescale, and clamp before converting.
    generated_img = ((generated_img.squeeze(0) + 1.0) / 2.0).clamp(0.0, 1.0)

    # transform to PIL image
    return transform(generated_img)
36
+
37
with gr.Blocks() as demo:
    # Header row: project logo on the left, project description on the right.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(
                "https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/Ditto.png",
                elem_id="logo-img",
                show_label=False,
                show_share_button=False,
                show_download_button=False,
            )

        with gr.Column(scale=3):
            gr.Markdown("""Ditto is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Fall 2022) that transfers styles of Pokemon sprites onto pet images using GANs and contrastive learning.
<br/><br/>
**Model**: [ditto](https://huggingface.co/Launchpad/ditto)
<br/>
**Developed by**: Kiran Suresh, Annie Lee, Chloe Wong, Tony Xin, Sebastian Zhao
"""
            )

    # Second row: image-in / image-out interface wired to generate_pokemon.
    with gr.Row():
        gr.Interface(generate_pokemon, gr.Image(), "image")

# Only start the server when this file is run as a script.
if __name__ == '__main__':
    demo.launch()
networks.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+
7
+
8
+ ###############################################################################
9
+ # Helper Functions
10
+ ###############################################################################
11
+
12
+
13
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def forward(self, x):
        """Return ``x`` itself (no copy, no computation)."""
        return x
16
+
17
+
18
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested normalization layer.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    The returned callable takes the number of channels and builds the layer:
        batch    -- nn.BatchNorm2d with learnable affine parameters and running statistics
        instance -- nn.InstanceNorm2d without affine parameters and without running statistics
        none     -- an Identity module (the channel argument is ignored)

    Raises NotImplementedError for any other name.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)

    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

    if norm_type == 'none':
        def norm_layer(x):
            # x (the channel count) is accepted only to match the factory signature.
            return Identity()
        return norm_layer

    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
35
+
36
+
37
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer -- the optimizer of the network
        opt       -- object holding the experiment flags. opt.lr_policy selects the
                     policy (linear | step | plateau | cosine); depending on the
                     policy, opt.epoch_count, opt.n_epochs, opt.n_epochs_decay and
                     opt.lr_decay_iters are read as well.

    For 'linear', the multiplier stays at 1.0 until epoch + opt.epoch_count exceeds
    opt.n_epochs and then decays linearly over the next <opt.n_epochs_decay> epochs.
    For the other policies (step, plateau, and cosine), the stock PyTorch schedulers are used.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises NotImplementedError when opt.lr_policy is not one of the names above.
    """
    policy = opt.lr_policy

    if policy == 'linear':
        def lambda_rule(epoch):
            # Number of epochs past the constant-rate phase (never negative).
            decayed = max(0, epoch + opt.epoch_count - opt.n_epochs)
            return 1.0 - decayed / float(opt.n_epochs_decay + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

    if policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)

    if policy == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)

    if policy == 'cosine':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)

    # Raise (not return) the error, and actually format the policy name into it.
    raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)
62
+
63
+
64
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Modules whose class name contains 'Conv' or 'Linear' get their weight initialized
    with the chosen method and their bias (if any) set to zero. BatchNorm2d modules with
    affine parameters get weight ~ N(1.0, init_gain) and zero bias.

    Raises NotImplementedError for an unknown init_type (when a Conv/Linear module is visited).
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False the layer has weight/bias set to None, so
            # there is nothing to initialize; guard instead of crashing.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
94
+
95
+
96
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=()):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2; empty means stay on the current device

    Return an initialized network. When gpu_ids is non-empty the network is moved to
    gpu_ids[0] and the returned object is a torch.nn.DataParallel wrapper around it.
    """
    # The default is an immutable empty tuple rather than a shared mutable list.
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available(), 'gpu_ids were given but CUDA is not available'
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, list(gpu_ids))  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
111
+
112
+
113
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=()):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the base number of filters passed on to the generator class
        netG (str)         -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- if use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator that has been passed through <init_net>.

    Two families of generators are available:
        U-Net: [unet_128] (built with 7 downsamplings) and [unet_256] (built with 8 downsamplings)
        The original U-Net paper: https://arxiv.org/abs/1505.04597

        Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
        Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
        The code is adapted from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).

    Raises NotImplementedError when netG is not one of the names above.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
148
+
149
+ ##############################################################################
150
+ # Classes
151
+ ##############################################################################
152
+
153
class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters produced by the first conv layer
                                   (and consumed by the last one)
            norm_layer          -- normalization layer factory (a class or a functools.partial)
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in the ResNet blocks: reflect | replicate | zero
                                   (the first and last conv always use ReflectionPad2d(3))
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()

        # Convs get a bias only with InstanceNorm2d. isinstance (rather than
        # an exact type comparison) also unwraps subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # NOTE: the order of layers below defines the state_dict keys
        # (model.<index>....); do not reorder.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            half = ngf * mult // 2  # mult is a power of two >= 2, so this is exact
            model += [nn.ConvTranspose2d(ngf * mult, half,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(half),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]  # output values lie in [-1, 1]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
210
+
211
+
212
class ResnetBlock(nn.Module):
    """Residual block: two 3x3 conv stages whose output is added to the input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the block's conv stack via <build_conv_block>.

        The skip connection itself is applied in <forward>.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the two-stage convolutional stack.

        Parameters:
            dim (int)           -- the number of channels in (and out of) each conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer factory
            use_dropout (bool)  -- if a Dropout(0.5) follows the first stage.
            use_bias (bool)     -- if the conv layers use a bias

        Returns nn.Sequential of: [pad] conv norm ReLU [dropout] [pad] conv norm.
        Raises NotImplementedError for an unknown padding_type.
        """
        # 'reflect'/'replicate' use an explicit padding module with an
        # unpadded conv; 'zero' uses the conv's own zero padding instead.
        if padding_type == 'reflect':
            make_pad, conv_padding = nn.ReflectionPad2d, 0
        elif padding_type == 'replicate':
            make_pad, conv_padding = nn.ReplicationPad2d, 0
        elif padding_type == 'zero':
            make_pad, conv_padding = None, 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        layers = []
        for first_stage in (True, False):
            if make_pad is not None:
                layers.append(make_pad(1))
            layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias))
            layers.append(norm_layer(dim))
            # Only the first stage is followed by the non-linearity (and dropout).
            if first_stage:
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return x plus the conv stack applied to x (skip connection)."""
        return x + self.conv_block(x)
267
+
268
+
269
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested UnetSkipConnectionBlock modules."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if num_downs == 7,
                               an image of size 128x128 becomes 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the outermost block's conv layer
            norm_layer      -- normalization layer
            use_dropout     -- forwarded only to the intermediate ngf * 8 blocks

        The network is built inside-out: start from the innermost block and
        repeatedly wrap it in the next outer block.
        """
        super(UnetGenerator, self).__init__()
        widest = ngf * 8

        # innermost (bottleneck) block
        nested = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)

        # intermediate blocks that stay at ngf * 8 filters
        for _ in range(num_downs - 5):
            nested = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=nested, norm_layer=norm_layer, use_dropout=use_dropout)

        # step the filter count down: ngf * 8 -> ngf * 4 -> ngf * 2 -> ngf
        for outer_nc, inner_nc in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            nested = UnetSkipConnectionBlock(outer_nc, inner_nc, input_nc=None, submodule=nested, norm_layer=norm_layer)

        # outermost block maps input_nc channels in and output_nc channels out
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=nested, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Run the input through the nested blocks."""
        return self.model(input)
298
+
299
+
300
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features;
                              defaults to outer_nc when None
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
                              (not used when innermost is True)
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer factory (a class or a functools.partial)
            use_dropout (bool)  -- if use dropout layers (intermediate blocks only).
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        # Convs get a bias only with InstanceNorm2d. isinstance (rather than
        # an exact type comparison) also unwraps subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # The submodule returns its input concatenated with its output,
            # hence inner_nc * 2 input channels for the up-convolution.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No submodule here, so the up-convolution sees inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the block output; non-outermost blocks concatenate x with it along dim 1."""
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            # NOTE: downrelu is an in-place LeakyReLU, so x has already been
            # modified by it when the concatenation takes place.
            return torch.cat([x, self.model(x)], 1)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.13.0
2
+ torchvision==0.14.0
3
+ gradio==4.26.0