ldkong committed
Commit 3c3a705 • Parent: 082d673

Update app.py

Files changed (1):
  1. app.py +62 -310
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 
-import argparse
 import cv2
 import imageio
 import math
@@ -29,7 +28,7 @@ class RelationModuleMultiScale(torch.nn.Module):
             self.relations_scales.append(relations_scale)
             self.subsample_scales.append(min(self.subsample_num, len(relations_scale)))
         self.num_frames = num_frames
-        self.fc_fusion_scales = nn.ModuleList() # high-tech modulelist
+        self.fc_fusion_scales = nn.ModuleList()
         for i in range(len(self.scales)):
             scale = self.scales[i]
             fc_fusion = nn.Sequential(nn.ReLU(), nn.Linear(scale * self.img_feature_dim, num_bottleneck), nn.ReLU())
@@ -60,31 +59,6 @@ class RelationModuleMultiScale(torch.nn.Module):
         return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument('--dataset', default='Sprite', help='datasets')
-parser.add_argument('--data_root', default='dataset', help='root directory for data')
-parser.add_argument('--num_class', type=int, default=15, help='the number of class for jester dataset')
-parser.add_argument('--input_type', default='image', choices=['feature', 'image'], help='the type of input')
-parser.add_argument('--src', default='domain_1', help='source domain')
-parser.add_argument('--tar', default='domain_2', help='target domain')
-parser.add_argument('--num_segments', type=int, default=8, help='the number of frame segment')
-parser.add_argument('--backbone', type=str, default="dcgan", choices=['dcgan', 'resnet101', 'I3Dpretrain','I3Dfinetune'], help='backbone')
-parser.add_argument('--channels', default=3, type=int, help='input channels for image inputs')
-parser.add_argument('--add_fc', default=1, type=int, metavar='M', help='number of additional fc layers (excluding the last fc layer) (e.g. 0, 1, 2)')
-parser.add_argument('--fc_dim', type=int, default=1024, help='dimension of added fc')
-parser.add_argument('--frame_aggregation', type=str, default='trn', choices=[ 'rnn', 'trn'], help='aggregation of frame features (none if baseline_type is not video)')
-parser.add_argument('--dropout_rate', default=0.5, type=float, help='dropout ratio for frame-level feature (default: 0.5)')
-parser.add_argument('--f_dim', type=int, default=512, help='dim of f')
-parser.add_argument('--z_dim', type=int, default=512, help='dimensionality of z_t')
-parser.add_argument('--f_rnn_layers', type=int, default=1, help='number of layers (content lstm)')
-parser.add_argument('--use_bn', type=str, default='none', choices=['none', 'AdaBN', 'AutoDIAL'], help='normalization-based methods')
-parser.add_argument('--prior_sample', type=str, default='random', choices=['random', 'post'], help='how to sample prior')
-parser.add_argument('--batch_size', default=128, type=int, help='batch size')
-parser.add_argument('--use_attn', type=str, default='TransAttn', choices=['none', 'TransAttn', 'general'], help='attention mechanism')
-parser.add_argument('--data_threads', type=int, default=5, help='number of data loading threads')
-opt = parser.parse_args(args=[])
-
-
 class GradReverse(Function):
     @staticmethod
     def forward(ctx, x, beta):
@@ -99,157 +73,70 @@ class GradReverse(Function):
 
 class TransferVAE_Video(nn.Module):
 
-    def __init__(self, opt):
+    def __init__(self):
         super(TransferVAE_Video, self).__init__()
-        self.f_dim = opt.f_dim
-        self.z_dim = opt.z_dim
-        self.fc_dim = opt.fc_dim
-        self.channels = opt.channels
-        self.input_type = opt.input_type
-        self.frames = opt.num_segments
-        self.use_bn = opt.use_bn
-        self.frame_aggregation = opt.frame_aggregation
-        self.batch_size = opt.batch_size
-        self.use_attn = opt.use_attn
-        self.dropout_rate = opt.dropout_rate
-        self.num_class = opt.num_class
-        self.prior_sample = opt.prior_sample
+        self.f_dim = 512
+        self.z_dim = 512
+        self.fc_dim = 1024
+        self.channels = 3
+        self.frames = 8
+        self.batch_size = 128
+        self.dropout_rate = 0.5
+        self.num_class = 15
+        self.prior_sample = 'random'
 
-        if self.input_type == 'image':
-            import dcgan_64
-            self.encoder = dcgan_64.encoder(self.fc_dim, self.channels)
-            self.decoder = dcgan_64.decoder_woSkip(self.z_dim + self.f_dim, self.channels)
-            self.fc_output_dim = self.fc_dim
-        elif self.input_type == 'feature':
-            if opt.backbone == 'resnet101':
-                model_backnone = getattr(torchvision.models, opt.backbone)(True) # model_test is only used for getting the dim #
-                self.input_dim = model_backnone.fc.in_features
-            elif opt.backbone == 'I3Dpretrain':
-                self.input_dim = 2048
-            elif opt.backbone == 'I3Dfinetune':
-                self.input_dim = 2048
-            self.add_fc = opt.add_fc
-            self.enc_fc_layer1 = nn.Linear(self.input_dim, self.fc_dim)
-            self.dec_fc_layer1 = nn.Linear(self.fc_dim, self.input_dim)
-            self.fc_output_dim = self.fc_dim
-
-            if self.use_bn == 'shared':
-                self.bn_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
-                self.bn_dec_layer1 = nn.BatchNorm1d(self.input_dim)
-            elif self.use_bn == 'separated':
-                self.bn_S_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
-                self.bn_T_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
-                self.bn_S_dec_layer1 = nn.BatchNorm1d(self.input_dim)
-                self.bn_T_dec_layer1 = nn.BatchNorm1d(self.input_dim)
-
-            if self.add_fc > 1:
-                self.enc_fc_layer2 = nn.Linear(self.fc_dim, self.fc_dim)
-                self.dec_fc_layer2 = nn.Linear(self.fc_dim, self.fc_dim)
-                self.fc_output_dim = self.fc_dim
-                ## use batchnormalization or not (if yes whether the source and target share the same batchnormalization)
-                if self.use_bn == 'shared':
-                    self.bn_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
-                    self.bn_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
-                elif self.use_bn == 'separated':
-                    self.bn_S_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
-                    self.bn_T_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
-                    self.bn_S_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
-                    self.bn_T_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
-
-            if self.add_fc > 2:
-                self.enc_fc_layer3 = nn.Linear(self.fc_dim, self.fc_dim)
-                self.dec_fc_layer3 = nn.Linear(self.fc_dim, self.fc_dim)
-                self.fc_output_dim = self.fc_dim
-                ## use batchnormalization or not (if yes whether the source and target share the same batchnormalization)
-                if self.use_bn == 'shared':
-                    self.bn_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
-                    self.bn_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
-                elif self.use_bn == 'separated':
-                    self.bn_S_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
-                    self.bn_T_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
-                    self.bn_S_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
-                    self.bn_T_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
-
-        self.z_2_out = nn.Linear(self.z_dim + self.f_dim, self.fc_output_dim)
+        import dcgan_64
+        self.encoder = dcgan_64.encoder(self.fc_dim, self.channels)
+        self.decoder = dcgan_64.decoder_woSkip(self.z_dim + self.f_dim, self.channels)
+        self.fc_output_dim = self.fc_dim
 
-
-        ## nonlinearity and dropout
         self.relu = nn.LeakyReLU(0.1)
         self.dropout_f = nn.Dropout(p=self.dropout_rate)
         self.dropout_v = nn.Dropout(p=self.dropout_rate)
-        # -------------------------------
 
-        ## Disentangle strcuture
-        # -------------------------------
-        #self.hidden_dim = opt.rnn_size
-        self.hidden_dim = opt.z_dim
-        self.f_rnn_layers = opt.f_rnn_layers
+        self.hidden_dim = 512
+        self.f_rnn_layers = 1
 
-        # Prior of content is a uniform Gaussian and prior of the dynamics is an LSTM
         self.z_prior_lstm_ly1 = nn.LSTMCell(self.z_dim, self.hidden_dim)
         self.z_prior_lstm_ly2 = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
 
         self.z_prior_mean = nn.Linear(self.hidden_dim, self.z_dim)
         self.z_prior_logvar = nn.Linear(self.hidden_dim, self.z_dim)
 
-        # POSTERIOR DISTRIBUTION NETWORKS
-        # content and motion features share one lstm
         self.z_lstm = nn.LSTM(self.fc_output_dim, self.hidden_dim, self.f_rnn_layers, bidirectional=True, batch_first=True)
         self.f_mean = nn.Linear(self.hidden_dim * 2, self.f_dim)
         self.f_logvar = nn.Linear(self.hidden_dim * 2, self.f_dim)
 
         self.z_rnn = nn.RNN(self.hidden_dim * 2, self.hidden_dim, batch_first=True)
-        # Each timestep is for each z so no reshaping and feature mixing
         self.z_mean = nn.Linear(self.hidden_dim, self.z_dim)
         self.z_logvar = nn.Linear(self.hidden_dim, self.z_dim)
-        # -------------------------------
 
-        ## z_t constraints
-        # -------------------------------
-        ## adversarial loss for frame features z_t
         self.fc_feature_domain_frame = nn.Linear(self.z_dim, self.z_dim)
         self.fc_classifier_domain_frame = nn.Linear(self.z_dim, 2)
 
-        ## #------ aggregate frame-based features (frame feature --> video feature) ------#
-        if self.frame_aggregation == 'rnn':
-            self.bilstm = nn.LSTM(self.z_dim, self.z_dim * 2, self.f_rnn_layers, bidirectional=True, batch_first=True)
-            self.feat_aggregated_dim = self.z_dim * 2
-        elif self.frame_aggregation == 'trn': # 4. TRN (ECCV 2018) ==> fix segment # for both train/val
-            self.num_bottleneck = 256 # 256
-            self.TRN = RelationModuleMultiScale(self.z_dim, self.num_bottleneck, self.frames)
-            self.bn_trn_S = nn.BatchNorm1d(self.num_bottleneck)
-            self.bn_trn_T = nn.BatchNorm1d(self.num_bottleneck)
-            self.feat_aggregated_dim = self.num_bottleneck
+        self.num_bottleneck = 256
+        self.TRN = RelationModuleMultiScale(self.z_dim, self.num_bottleneck, self.frames)
+        self.bn_trn_S = nn.BatchNorm1d(self.num_bottleneck)
+        self.bn_trn_T = nn.BatchNorm1d(self.num_bottleneck)
+        self.feat_aggregated_dim = self.num_bottleneck
 
-        ## adversarial loss for video features
         self.fc_feature_domain_video = nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim)
         self.fc_classifier_domain_video = nn.Linear(self.feat_aggregated_dim, 2)
 
-        ## adversarial loss for each relation of features
-        if self.frame_aggregation == 'trn':
-            self.relation_domain_classifier_all = nn.ModuleList()
-            for i in range(self.frames-1):
-                relation_domain_classifier = nn.Sequential(
-                    nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
-                    nn.ReLU(),
-                    nn.Linear(self.feat_aggregated_dim, 2)
-                )
-                self.relation_domain_classifier_all += [relation_domain_classifier]
+        self.relation_domain_classifier_all = nn.ModuleList()
+        for i in range(self.frames-1):
+            relation_domain_classifier = nn.Sequential(
+                nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
+                nn.ReLU(),
+                nn.Linear(self.feat_aggregated_dim, 2)
+            )
+            self.relation_domain_classifier_all += [relation_domain_classifier]
 
-        ## classifier for action prediction task
         self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)
-
-        ## classifier for prediction domains
+
         self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
         self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)
-
-        ## attention option
-        if self.use_attn == 'general':
-            self.attn_layer = nn.Sequential(
-                nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
-                nn.Tanh(),
-                nn.Linear(self.feat_aggregated_dim, 1)
-            )
+
 
     def domain_classifier_frame(self, feat, beta):
         feat_fc_domain_frame = GradReverse.apply(feat, beta)
@@ -258,6 +145,7 @@ class TransferVAE_Video(nn.Module):
         pred_fc_domain_frame = self.fc_classifier_domain_frame(feat_fc_domain_frame)
         return pred_fc_domain_frame
 
+
     def domain_classifier_video(self, feat_video, beta):
         feat_fc_domain_video = GradReverse.apply(feat_video, beta)
         feat_fc_domain_video = self.fc_feature_domain_video(feat_fc_domain_video)
@@ -265,17 +153,19 @@ class TransferVAE_Video(nn.Module):
         pred_fc_domain_video = self.fc_classifier_domain_video(feat_fc_domain_video)
         return pred_fc_domain_video
 
+
     def domain_classifier_latent(self, f):
         feat_fc_domain_latent = self.fc_feature_domain_latent(f)
         feat_fc_domain_latent = self.relu(feat_fc_domain_latent)
         pred_fc_domain_latent = self.fc_classifier_doamin_latent(feat_fc_domain_latent)
         return pred_fc_domain_latent
 
+
    def domain_classifier_relation(self, feat_relation, beta):
        pred_fc_domain_relation_video = None
        for i in range(len(self.relation_domain_classifier_all)):
-            feat_relation_single = feat_relation[:,i,:].squeeze(1) # 128x1x256 --> 128x256
-            feat_fc_domain_relation_single = GradReverse.apply(feat_relation_single, beta) # the same beta for all relations (for now)
+            feat_relation_single = feat_relation[:,i,:].squeeze(1)
+            feat_fc_domain_relation_single = GradReverse.apply(feat_relation_single, beta)
 
            pred_fc_domain_relation_single = self.relation_domain_classifier_all[i](feat_fc_domain_relation_single)
 
@@ -288,6 +178,7 @@ class TransferVAE_Video(nn.Module):
 
         return pred_fc_domain_relation_video
 
+
     def get_trans_attn(self, pred_domain):
         softmax = nn.Softmax(dim=1)
         logsoftmax = nn.LogSoftmax(dim=1)
@@ -295,6 +186,7 @@ class TransferVAE_Video(nn.Module):
         weights = 1 - entropy
         return weights
 
+
     def get_general_attn(self, feat):
         num_segments = feat.size()[1]
         feat = feat.view(-1, feat.size()[-1]) # reshape features: 128x4x256 --> (128x4)x256
@@ -303,15 +195,11 @@ class TransferVAE_Video(nn.Module):
         weights = F.softmax(weights, dim=1) # softmax over segments ==> 128x4x1
         return weights
 
+
     def get_attn_feat_relation(self, feat_fc, pred_domain, num_segments):
-        if self.use_attn == 'TransAttn':
-            weights_attn = self.get_trans_attn(pred_domain)
-        elif self.use_attn == 'general':
-            weights_attn = self.get_general_attn(feat_fc)
-
+        weights_attn = self.get_trans_attn(pred_domain)
         weights_attn = weights_attn.view(-1, num_segments-1, 1).repeat(1,1,feat_fc.size()[-1]) # reshape & repeat weights (e.g. 16 x 4 x 256)
         feat_fc_attn = (weights_attn+1) * feat_fc
-
         return feat_fc_attn, weights_attn[:,:,0]
 
 
@@ -357,94 +245,18 @@ class TransferVAE_Video(nn.Module):
             f_post = f_post_list
         # f_mean and f_post are list if triple else not
         return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
+
 
     def decoder_frame(self,zf):
-        if self.input_type == 'image':
-            recon_x = self.decoder(zf)
-            return recon_x
-
-        if self.input_type == 'feature':
-            zf = self.z_2_out(zf) # batch,frames,(z_dim+f_dim) -> batch,frames,fc_output_dim
-            zf = self.relu(zf)
-
-            if self.add_fc > 2:
-                zf = self.dec_fc_layer3(zf)
-                if self.use_bn == 'shared':
-                    zf = self.bn_dec_layer3(zf)
-                elif self.use_bn == 'separated':
-                    zf_src = self.bn_S_dec_layer3(zf[:self.batchsize,:,:])
-                    zf_tar = self.bn_T_dec_layer3(zf[self.batchsize:,:,:])
-                    zf = torch.cat([zf_src,zf_tar],axis=0)
-                zf = self.relu(zf)
-
-            if self.add_fc > 1:
-                zf = self.dec_fc_layer2(zf)
-                if self.use_bn == 'shared':
-                    zf = self.bn_dec_layer2(zf)
-                elif self.use_bn == 'separated':
-                    zf_src = self.bn_S_dec_layer2(zf[:self.batchsize,:,:])
-                    zf_tar = self.bn_T_dec_layer2(zf[self.batchsize:,:,:])
-                    zf = torch.cat([zf_src,zf_tar],axis=0)
-                zf = self.relu(zf)
-
-
-            zf = self.dec_fc_layer1(zf)
-            if self.use_bn == 'shared':
-                zf = self.bn_dec_layer2(zf)
-            elif self.use_bn == 'separated':
-                zf_src = self.bn_S_dec_layer2(zf[:self.batchsize,:,:])
-                zf_tar = self.bn_T_dec_layer2(zf[self.batchsize:,:,:])
-                zf = torch.cat([zf_src,zf_tar],axis=0)
-            recon_x = self.relu(zf)
-            return recon_x
+        recon_x = self.decoder(zf)
+        return recon_x
+
 
     def encoder_frame(self, x):
-        if self.input_type == 'image':
-            # input x is list of length Frames [batchsize, channels, size, size]
-            # convert it to [batchsize, frames, channels, size, size]
-            # [batch_size, frames, channels, size, size] to [batch_size * frames, channels, size, size]
-            x_shape = x.shape
-            x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
-            x_embed = self.encoder(x)[0]
-            # to [batch_size,frames,embed_dim]
-
-            return x_embed.view(x_shape[0], x_shape[1], -1)
-
-
-        if self.input_type == 'feature':
-            # input is [batchsize, framew, input_dim]
-            x_embed = self.enc_fc_layer1(x)
-            ## use batchnormalization or not (if yes whether the source and target share the same batchnormalization)
-            if self.use_bn == 'shared':
-                x_embed = self.bn_enc_layer1(x_embed)
-            elif self.use_bn == 'separated':
-                x_embed_src = self.bn_S_enc_layer1(x_embed[:self.batchsize,:,:])
-                x_embed_tar = self.bn_T_enc_layer1(x_embed[self.batchsize:,:,:])
-                x_embed = torch.cat([x_embed_src,x_embed_tar],axis=0)
-            x_embed = self.relu(x_embed)
-
-            if self.add_fc > 1:
-                x_embed = self.enc_fc_layer2(x_embed)
-                if self.use_bn == 'shared':
-                    x_embed = self.bn_enc_layer2(x_embed)
-                elif self.use_bn == 'separated':
-                    x_embed_src = self.bn_S_enc_layer2(x_embed[:self.batchsize,:,:])
-                    x_embed_tar = self.bn_T_enc_layer2(x_embed[self.batchsize:,:,:])
-                    x_embed = torch.cat([x_embed_src,x_embed_tar],axis=0)
-                x_embed = self.relu(x_embed)
-
-            if self.add_fc > 2:
-                x_embed = self.enc_fc_layer3(x_embed)
-                if self.use_bn == 'shared':
-                    x_embed = self.bn_enc_layer3(x_embed)
-                elif self.use_bn == 'separated':
-                    x_embed_src = self.bn_S_enc_layer3(x_embed[:self.batchsize,:,:])
-                    x_embed_tar = self.bn_T_enc_layer3(x_embed[self.batchsize:,:,:])
-                    x_embed = torch.cat([x_embed_src,x_embed_tar],axis=0)
-                x_embed = self.relu(x_embed)
-
-            ## [batchsize, frame, output_dim]
-            return x_embed
+        x_shape = x.shape
+        x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
+        x_embed = self.encoder(x)[0]
+        return x_embed.view(x_shape[0], x_shape[1], -1)
 
 
     def reparameterize(self, mean, logvar, random_sampling=True):
@@ -458,7 +270,7 @@ class TransferVAE_Video(nn.Module):
         return mean
 
     def sample_z_prior_train(self, z_post, random_sampling=True):
-        z_out = None # This will ultimately store all z_s in the format [batch_size, frames, z_dim]
+        z_out = None
         z_means = None
         z_logvars = None
         batch_size = z_post.shape[0]
@@ -526,77 +338,17 @@ class TransferVAE_Video(nn.Module):
         return z_means, z_logvars, z_out
 
     def forward(self, x, beta):
-        # beta [beta_relation, beta_video, beta_frame]
-        f_mean, f_logvar, f_post, z_mean_post, z_logvar_post, z_post = self.encode_and_sample_post(x)
-        if self.prior_sample == 'random':
-            z_mean_prior, z_logvar_prior, z_prior = self.sample_z(z_post.size(0),random_sampling=False)
-        elif self.prior_sample == 'post':
-            z_mean_prior, z_logvar_prior, z_prior = self.sample_z_prior_train(z_post, random_sampling=False)
-
+        _, _, f_post, _, _, z_post = self.encode_and_sample_post(x)
 
         if isinstance(f_post, list):
             f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
         else:
             f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
-        zf = torch.cat((z_post, f_expand), dim=2) # batch,frames,(z_dim+f_dim)
+        zf = torch.cat((z_post, f_expand), dim=2)
 
-        ## reconcstruct x
         recon_x = self.decoder_frame(zf)
 
-        ## For constraints on z_post [batch,frame,z_dim] and f_post [batch,f_dim]
-        pred_domain_all = [] # list save domain predictions (1) z_post (frame level) (2) each z_post_relation (if trn) (3) z_post (video level) (4)f_post
-
-        #1. adversarial on z_post (frame level)
-        z_post_feat = z_post.view(-1, z_post.size()[-1]) # e.g. 32 x 5 x 2048 --> 160 x 2048
-        z_post_feat = self.dropout_f(z_post_feat)
-        pred_fc_domain_frame = self.domain_classifier_frame(z_post_feat, beta[2])
-        pred_fc_domain_frame = pred_fc_domain_frame.view((z_post.size(0), self.frames) + pred_fc_domain_frame.size()[-1:])
-        pred_domain_all.append(pred_fc_domain_frame)
-
-        #2 adversarial on z_post (video level, relation level if trn is used)
-
-        if self.frame_aggregation == 'rnn':
-            self.bilstm.flatten_parameters()
-            z_post_video_feat, _ = self.bilstm(z_post)
-            backward = z_post_video_feat[:, 0, self.z_dim:2 * self.z_dim]
-            frontal = z_post_video_feat[:, self.frames - 1, 0:self.z_dim]
-            z_post_video_feat = torch.cat((frontal, backward), dim=1)
-            pred_fc_domain_relation = []
-            pred_domain_all.append(pred_fc_domain_relation)
-
-        elif self.frame_aggregation == 'trn':
-            z_post_video_relation = self.TRN(z_post) ## [batch, frame-1, self.feat_aggregated_dim]
-
-            # adversarial branch for each relation
-            pred_fc_domain_relation = self.domain_classifier_relation(z_post_video_relation, beta[0])
-            pred_domain_all.append(pred_fc_domain_relation.view((z_post.size(0), z_post_video_relation.size()[1]) + pred_fc_domain_relation.size()[-1:]))
-
-            # transferable attention
-            if self.use_attn != 'none': # get the attention weighting
-                z_post_video_relation_attn, _ = self.get_attn_feat_relation(z_post_video_relation, pred_fc_domain_relation, self.frames)
-
-            # sum up relation features (ignore 1-relation)
-            z_post_video_feat = torch.sum(z_post_video_relation_attn, 1)
-
-
-        z_post_video_feat = self.dropout_v(z_post_video_feat)
-
-        pred_fc_domain_video = self.domain_classifier_video(z_post_video_feat, beta[1])
-        pred_fc_domain_video = pred_fc_domain_video.view((z_post.size(0),) + pred_fc_domain_video.size()[-1:])
-        pred_domain_all.append(pred_fc_domain_video)
-
-
-        #3. video prediction
-        pred_video_class = self.pred_classifier_video(z_post_video_feat)
-
-        #4. domain prediction on f
-        if isinstance(f_post, list):
-            pred_fc_domain_latent = self.domain_classifier_latent(f_post[0])
-        else:
-            pred_fc_domain_latent = self.domain_classifier_latent(f_post)
-        pred_domain_all.append(pred_fc_domain_latent)
-
-        return f_mean, f_logvar, f_post, z_mean_post, z_logvar_post, z_post, z_mean_prior, z_logvar_prior, z_prior, recon_x, pred_domain_all, pred_video_class
+        return f_post, z_post, recon_x
 
 
 def name2seq(file_name):
@@ -700,6 +452,12 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
     plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
 
 
+# == Load Model ==
+model = TransferVAE_Video()
+model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
+model.eval()
+
+
 def run(domain_source, action_source, hair_source, top_source, bottom_source, domain_target, action_target, hair_target, top_target, bottom_target):
 
     # == Source Avatar ==
@@ -760,15 +518,9 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
     x = torch.cat((images_source, images_target), dim=0)
 
 
-    # == Load Model ==
-    model = TransferVAE_Video(opt)
-    model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
-    model.eval()
-
-
     # == Forward ==
     with torch.no_grad():
-        f_mean, f_logvar, f_post, z_post_mean, z_post_logvar, z_post, z_prior_mean, z_prior_logvar, z_prior, recon_x, pred_domain_all, pred_video_class = model(x, [0]*3)
+        f_post, z_post, recon_x = model(x, [0]*3)
 
     src_orig_sample = x[0, :, :, :, :]
     src_recon_sample = recon_x[0, :, :, :, :]
@@ -824,12 +576,12 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
 gr.Interface(
     run,
     inputs=[
-        gr.Textbox(value="Source Avatar - Human", interactive=False),
+        gr.Textbox(value="Source Avatar - Human", show_label=False, interactive=False),
         gr.Radio(choices=["slash", "spellcard", "walk"], value="slash"),
         gr.Radio(choices=["green", "yellow", "rose", "red", "wine"], value="green"),
         gr.Radio(choices=["brown", "blue", "white"], value="brown"),
         gr.Radio(choices=["white", "golden", "red", "silver"], value="white"),
-        gr.Textbox(value="Target Avatar - Alien", interactive=False),
+        gr.Textbox(value="Target Avatar - Alien", show_label=False, interactive=False),
         gr.Radio(choices=["slash", "spellcard", "walk"], value="walk"),
         gr.Radio(choices=["violet", "silver", "purple", "grey", "golden"], value="golden"),
         gr.Radio(choices=["grey", "khaki", "linen", "ocre"], value="ocre"),
 