ldkong committed
Commit af4f972
1 Parent(s): 3c3a705

Update app.py

Files changed (1)
  1. app.py +5 -157
app.py CHANGED
@@ -59,18 +59,6 @@ class RelationModuleMultiScale(torch.nn.Module):
         return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))
 
 
-class GradReverse(Function):
-    @staticmethod
-    def forward(ctx, x, beta):
-        ctx.beta = beta
-        return x.view_as(x)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input = grad_output.neg() * ctx.beta
-        return grad_input, None
-
-
 class TransferVAE_Video(nn.Module):
 
     def __init__(self):
@@ -133,86 +121,18 @@ class TransferVAE_Video(nn.Module):
             self.relation_domain_classifier_all += [relation_domain_classifier]
 
         self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)
-
         self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
         self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)
 
-
-    def domain_classifier_frame(self, feat, beta):
-        feat_fc_domain_frame = GradReverse.apply(feat, beta)
-        feat_fc_domain_frame = self.fc_feature_domain_frame(feat_fc_domain_frame)
-        feat_fc_domain_frame = self.relu(feat_fc_domain_frame)
-        pred_fc_domain_frame = self.fc_classifier_domain_frame(feat_fc_domain_frame)
-        return pred_fc_domain_frame
-
-
-    def domain_classifier_video(self, feat_video, beta):
-        feat_fc_domain_video = GradReverse.apply(feat_video, beta)
-        feat_fc_domain_video = self.fc_feature_domain_video(feat_fc_domain_video)
-        feat_fc_domain_video = self.relu(feat_fc_domain_video)
-        pred_fc_domain_video = self.fc_classifier_domain_video(feat_fc_domain_video)
-        return pred_fc_domain_video
-
-
-    def domain_classifier_latent(self, f):
-        feat_fc_domain_latent = self.fc_feature_domain_latent(f)
-        feat_fc_domain_latent = self.relu(feat_fc_domain_latent)
-        pred_fc_domain_latent = self.fc_classifier_doamin_latent(feat_fc_domain_latent)
-        return pred_fc_domain_latent
-
-
-    def domain_classifier_relation(self, feat_relation, beta):
-        pred_fc_domain_relation_video = None
-        for i in range(len(self.relation_domain_classifier_all)):
-            feat_relation_single = feat_relation[:,i,:].squeeze(1)
-            feat_fc_domain_relation_single = GradReverse.apply(feat_relation_single, beta)
-
-            pred_fc_domain_relation_single = self.relation_domain_classifier_all[i](feat_fc_domain_relation_single)
-
-            if pred_fc_domain_relation_video is None:
-                pred_fc_domain_relation_video = pred_fc_domain_relation_single.view(-1,1,2)
-            else:
-                pred_fc_domain_relation_video = torch.cat((pred_fc_domain_relation_video, pred_fc_domain_relation_single.view(-1,1,2)), 1)
-
-        pred_fc_domain_relation_video = pred_fc_domain_relation_video.view(-1,2)
-
-        return pred_fc_domain_relation_video
-
-
-    def get_trans_attn(self, pred_domain):
-        softmax = nn.Softmax(dim=1)
-        logsoftmax = nn.LogSoftmax(dim=1)
-        entropy = torch.sum(-softmax(pred_domain) * logsoftmax(pred_domain), 1)
-        weights = 1 - entropy
-        return weights
-
-
-    def get_general_attn(self, feat):
-        num_segments = feat.size()[1]
-        feat = feat.view(-1, feat.size()[-1]) # reshape features: 128x4x256 --> (128x4)x256
-        weights = self.attn_layer(feat) # e.g. (128x4)x1
-        weights = weights.view(-1, num_segments, weights.size()[-1]) # reshape attention weights: (128x4)x1 --> 128x4x1
-        weights = F.softmax(weights, dim=1) # softmax over segments ==> 128x4x1
-        return weights
-
-
-    def get_attn_feat_relation(self, feat_fc, pred_domain, num_segments):
-        weights_attn = self.get_trans_attn(pred_domain)
-        weights_attn = weights_attn.view(-1, num_segments-1, 1).repeat(1,1,feat_fc.size()[-1]) # reshape & repeat weights (e.g. 16 x 4 x 256)
-        feat_fc_attn = (weights_attn+1) * feat_fc
-        return feat_fc_attn, weights_attn[:,:,0]
-
 
     def encode_and_sample_post(self, x):
         if isinstance(x, list):
             conv_x = self.encoder_frame(x[0])
         else:
             conv_x = self.encoder_frame(x)
-
-        # pass the bidirectional lstm
+
         lstm_out, _ = self.z_lstm(conv_x)
-
-        # get f:
+
         backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
         frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
         lstm_out_f = torch.cat((frontal, backward), dim=1)
@@ -220,7 +140,6 @@ class TransferVAE_Video(nn.Module):
         f_logvar = self.f_logvar(lstm_out_f)
         f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
 
-        # pass to one direction rnn
         features, _ = self.z_rnn(lstm_out)
         z_mean = self.z_mean(features)
         z_logvar = self.z_logvar(features)
@@ -232,7 +151,6 @@ class TransferVAE_Video(nn.Module):
             for t in range(1,3,1):
                 conv_x = self.encoder_frame(x[t])
                 lstm_out, _ = self.z_lstm(conv_x)
-                # get f:
                 backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
                 frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
                 lstm_out_f = torch.cat((frontal, backward), dim=1)
@@ -243,7 +161,6 @@ class TransferVAE_Video(nn.Module):
                 f_post_list.append(f_post)
             f_mean = f_mean_list
             f_post = f_post_list
-        # f_mean and f_post are list if triple else not
         return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
 
 
@@ -260,7 +177,6 @@ class TransferVAE_Video(nn.Module):
 
 
     def reparameterize(self, mean, logvar, random_sampling=True):
-        # Reparametrization occurs only if random sampling is set to true, otherwise mean is returned
        if random_sampling is True:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5 * logvar)
@@ -269,88 +185,20 @@ class TransferVAE_Video(nn.Module):
        else:
            return mean
 
-    def sample_z_prior_train(self, z_post, random_sampling=True):
-        z_out = None
-        z_means = None
-        z_logvars = None
-        batch_size = z_post.shape[0]
-
-        z_t = torch.zeros(batch_size, self.z_dim).cpu()
-        h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-
-        for i in range(self.frames):
-            # two layer LSTM and two one-layer FC
-            h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
-            h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
-
-            z_mean_t = self.z_prior_mean(h_t_ly2)
-            z_logvar_t = self.z_prior_logvar(h_t_ly2)
-            z_prior = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
-            if z_out is None:
-                # If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
-                z_out = z_prior.unsqueeze(1)
-                z_means = z_mean_t.unsqueeze(1)
-                z_logvars = z_logvar_t.unsqueeze(1)
-            else:
-                # If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
-                z_out = torch.cat((z_out, z_prior.unsqueeze(1)), dim=1)
-                z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
-                z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
-            z_t = z_post[:,i,:]
-        return z_means, z_logvars, z_out
-
-    # If random sampling is true, reparametrization occurs else z_t is just set to the mean
-    def sample_z(self, batch_size, random_sampling=True):
-        z_out = None # This will ultimately store all z_s in the format [batch_size, frames, z_dim]
-        z_means = None
-        z_logvars = None
-
-        # All states are initially set to 0, especially z_0 = 0
-        z_t = torch.zeros(batch_size, self.z_dim).cpu()
-        # z_mean_t = torch.zeros(batch_size, self.z_dim)
-        # z_logvar_t = torch.zeros(batch_size, self.z_dim)
-        h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        for _ in range(self.frames):
-            # h_t, c_t = self.z_prior_lstm(z_t, (h_t, c_t))
-            # two layer LSTM and two one-layer FC
-            h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
-            h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
-
-            z_mean_t = self.z_prior_mean(h_t_ly2)
-            z_logvar_t = self.z_prior_logvar(h_t_ly2)
-            z_t = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
-            if z_out is None:
-                # If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
-                z_out = z_t.unsqueeze(1)
-                z_means = z_mean_t.unsqueeze(1)
-                z_logvars = z_logvar_t.unsqueeze(1)
-            else:
-                # If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
-                z_out = torch.cat((z_out, z_t.unsqueeze(1)), dim=1)
-                z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
-                z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
-        return z_means, z_logvars, z_out
 
     def forward(self, x, beta):
        _, _, f_post, _, _, z_post = self.encode_and_sample_post(x)
-
        if isinstance(f_post, list):
            f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
        else:
            f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
        zf = torch.cat((z_post, f_expand), dim=2)
-
        recon_x = self.decoder_frame(zf)
-
        return f_post, z_post, recon_x
 
 
+
+
 def name2seq(file_name):
     images = []
 
@@ -520,7 +368,7 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
 
     # == Forward ==
     with torch.no_grad():
-        f_post, z_post, recon_x = model(x, [0]*3)
+        f_post, z_post, recon_x = model(x, [0]*3)
 
     src_orig_sample = x[0, :, :, :, :]
     src_recon_sample = recon_x[0, :, :, :, :]
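For context only (this note and the sketch below are not part of the commit): the GradReverse class removed above is a standard gradient reversal layer, an identity in the forward pass whose backward pass multiplies the incoming gradient by -beta. The following is a minimal, self-contained sketch of how such a layer is typically wired into a frame-level domain classifier, mirroring the removed domain_classifier_frame; the feature dimension (512) and the FrameDomainClassifier wrapper are illustrative assumptions, not code from app.py.

import torch
import torch.nn as nn
from torch.autograd import Function


class GradReverse(Function):
    # Identity in the forward pass; the backward pass returns -beta * grad.
    @staticmethod
    def forward(ctx, x, beta):
        ctx.beta = beta
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is propagated to beta, hence the trailing None.
        return grad_output.neg() * ctx.beta, None


class FrameDomainClassifier(nn.Module):
    # Hypothetical stand-in for the removed fc_feature_domain_frame /
    # fc_classifier_domain_frame pair used by domain_classifier_frame.
    def __init__(self, feat_dim=512):
        super().__init__()
        self.fc_feature_domain_frame = nn.Linear(feat_dim, feat_dim)
        self.relu = nn.ReLU(inplace=True)
        self.fc_classifier_domain_frame = nn.Linear(feat_dim, 2)

    def forward(self, feat, beta):
        feat = GradReverse.apply(feat, beta)           # reverse gradients flowing back to the encoder
        feat = self.relu(self.fc_feature_domain_frame(feat))
        return self.fc_classifier_domain_frame(feat)   # source-vs-target domain logits


if __name__ == "__main__":
    clf = FrameDomainClassifier()
    feat = torch.randn(4, 512, requires_grad=True)
    logits = clf(feat, beta=1.0)
    logits.sum().backward()                            # feat.grad now holds the reversed gradient
    print(logits.shape, feat.grad.shape)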
 