ldkong commited on
Commit
1b79989
β€’
1 Parent(s): 1ee9227

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -39
app.py CHANGED
@@ -13,10 +13,6 @@ import torch.nn as nn
13
  import torch.nn.functional as F
14
 
15
 
16
- def log_gpu_memory():
17
- print(subprocess.check_output('nvidia-smi').decode('utf-8'))
18
-
19
-
20
  class RelationModuleMultiScale(torch.nn.Module):
21
 
22
  def __init__(self, img_feature_dim, num_bottleneck, num_frames):
@@ -129,42 +125,19 @@ class TransferVAE_Video(nn.Module):
129
 
130
 
131
  def encode_and_sample_post(self, x):
132
- if isinstance(x, list):
133
- conv_x = self.encoder_frame(x[0])
134
- else:
135
- conv_x = self.encoder_frame(x)
136
-
137
  lstm_out, _ = self.z_lstm(conv_x)
138
-
139
  backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
140
  frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
141
  lstm_out_f = torch.cat((frontal, backward), dim=1)
142
  f_mean = self.f_mean(lstm_out_f)
143
  f_logvar = self.f_logvar(lstm_out_f)
144
  f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
145
-
146
  features, _ = self.z_rnn(lstm_out)
147
  z_mean = self.z_mean(features)
148
  z_logvar = self.z_logvar(features)
149
  z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)
150
-
151
- if isinstance(x, list):
152
- f_mean_list = [f_mean]
153
- f_post_list = [f_post]
154
- for t in range(1,3,1):
155
- conv_x = self.encoder_frame(x[t])
156
- lstm_out, _ = self.z_lstm(conv_x)
157
- backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
158
- frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
159
- lstm_out_f = torch.cat((frontal, backward), dim=1)
160
- f_mean = self.f_mean(lstm_out_f)
161
- f_logvar = self.f_logvar(lstm_out_f)
162
- f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
163
- f_mean_list.append(f_mean)
164
- f_post_list.append(f_post)
165
- f_mean = f_mean_list
166
- f_post = f_post_list
167
- return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
168
 
169
 
170
  def decoder_frame(self,zf):
@@ -190,7 +163,7 @@ class TransferVAE_Video(nn.Module):
190
 
191
 
192
  def forward(self, x, beta):
193
- _, _, f_post, _, _, z_post = self.encode_and_sample_post(x)
194
  if isinstance(f_post, list):
195
  f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
196
  else:
@@ -269,15 +242,11 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
269
  plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
270
 
271
 
272
- log_gpu_memory()
273
-
274
  # == Load Model ==
275
  model = TransferVAE_Video()
276
  model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
277
  model.eval()
278
 
279
- log_gpu_memory()
280
-
281
 
282
  def run(source, action_source, hair_source, top_source, bottom_source, target, action_target, hair_target, top_target, bottom_target):
283
 
@@ -338,12 +307,11 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
338
  images_target = name2seq(file_name_target)
339
  x = torch.cat((images_source, images_target), dim=0)
340
 
341
-
342
- log_gpu_memory()
343
  # == Forward ==
344
  with torch.no_grad():
345
  f_post, z_post, recon_x = model(x, [0]*3)
346
- log_gpu_memory()
347
 
348
  src_orig_sample = x[0, :, :, :, :]
349
  src_recon_sample = recon_x[0, :, :, :, :]
@@ -389,9 +357,7 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
389
  recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
390
  tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
391
 
392
- log_gpu_memory()
393
  MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
394
- log_gpu_memory()
395
 
396
  a = concat('MyPlot_')
397
 
 
13
  import torch.nn.functional as F
14
 
15
 
 
 
 
 
16
  class RelationModuleMultiScale(torch.nn.Module):
17
 
18
  def __init__(self, img_feature_dim, num_bottleneck, num_frames):
 
125
 
126
 
127
  def encode_and_sample_post(self, x):
128
+ conv_x = self.encoder_frame(x)
 
 
 
 
129
  lstm_out, _ = self.z_lstm(conv_x)
 
130
  backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
131
  frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
132
  lstm_out_f = torch.cat((frontal, backward), dim=1)
133
  f_mean = self.f_mean(lstm_out_f)
134
  f_logvar = self.f_logvar(lstm_out_f)
135
  f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
 
136
  features, _ = self.z_rnn(lstm_out)
137
  z_mean = self.z_mean(features)
138
  z_logvar = self.z_logvar(features)
139
  z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)
140
+ return f_post, z_post
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  def decoder_frame(self,zf):
 
163
 
164
 
165
  def forward(self, x, beta):
166
+ f_post, z_post = self.encode_and_sample_post(x)
167
  if isinstance(f_post, list):
168
  f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
169
  else:
 
242
  plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
243
 
244
 
 
 
245
  # == Load Model ==
246
  model = TransferVAE_Video()
247
  model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
248
  model.eval()
249
 
 
 
250
 
251
  def run(source, action_source, hair_source, top_source, bottom_source, target, action_target, hair_target, top_target, bottom_target):
252
 
 
307
  images_target = name2seq(file_name_target)
308
  x = torch.cat((images_source, images_target), dim=0)
309
 
310
+
 
311
  # == Forward ==
312
  with torch.no_grad():
313
  f_post, z_post, recon_x = model(x, [0]*3)
314
+
315
 
316
  src_orig_sample = x[0, :, :, :, :]
317
  src_recon_sample = recon_x[0, :, :, :, :]
 
357
  recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
358
  tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
359
 
 
360
  MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
 
361
 
362
  a = concat('MyPlot_')
363