ldkong committed on
Commit
2db6196
β€’
1 Parent(s): 6524c35

Update app.py

Files changed (1)
  1. app.py +14 -2
app.py CHANGED
```diff
@@ -12,6 +12,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
+def log_gpu_memory():
+    print(subprocess.check_output('nvidia-smi').decode('utf-8'))
+
+
 class RelationModuleMultiScale(torch.nn.Module):
 
     def __init__(self, img_feature_dim, num_bottleneck, num_frames):
@@ -264,11 +268,15 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
     plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
 
 
+log_gpu_memory()
+
 # == Load Model ==
 model = TransferVAE_Video()
 model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
 model.eval()
-
+
+log_gpu_memory()
+
 
 def run(source, action_source, hair_source, top_source, bottom_source, target, action_target, hair_target, top_target, bottom_target):
 
@@ -330,10 +338,12 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
     x = torch.cat((images_source, images_target), dim=0)
 
 
+    log_gpu_memory()
     # == Forward ==
     with torch.no_grad():
         f_post, z_post, recon_x = model(x, [0]*3)
-
+    log_gpu_memory()
+
     src_orig_sample = x[0, :, :, :, :]
     src_recon_sample = recon_x[0, :, :, :, :]
     src_f_post = f_post[0, :].unsqueeze(0)
@@ -378,7 +388,9 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
     recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
     tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
 
+    log_gpu_memory()
     MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
+    log_gpu_memory()
 
     a = concat('MyPlot_')
 
```
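The helper this commit adds shells out to `nvidia-smi` via `subprocess.check_output`, so it only works if app.py imports `subprocess` somewhere above (no new import appears in the diff) and the host actually exposes an NVIDIA GPU; note that the model itself is loaded with `map_location=torch.device('cpu')`. Below is a minimal sketch of a more defensive variant, assuming only the standard library plus the `torch` import already present in app.py; everything beyond the `log_gpu_memory` name is illustrative, not part of the commit:

```python
import shutil
import subprocess

import torch


def log_gpu_memory():
    # Same call the commit adds, but only when the nvidia-smi binary exists.
    if shutil.which('nvidia-smi') is not None:
        print(subprocess.check_output('nvidia-smi').decode('utf-8'))
    elif torch.cuda.is_available():
        # Fall back to torch's own allocator counters (bytes -> MiB).
        print(f'allocated: {torch.cuda.memory_allocated() / 2**20:.1f} MiB, '
              f'reserved: {torch.cuda.memory_reserved() / 2**20:.1f} MiB')
    else:
        print('No GPU visible; nothing to log.')
```

Dropped in at the same points as in the diff (after model loading, around the forward pass, and around `MyPlot`), this logs a memory snapshot at each stage without crashing a CPU-only Space.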