ldkong committed on
Commit 159a124 • 1 Parent(s): af4f972

Update app.py

Files changed (1)
  1. app.py +8 -10
app.py CHANGED
@@ -5,13 +5,11 @@ import imageio
 import math
 from math import ceil
 import matplotlib.pyplot as plt
-import matplotlib.animation as animation
 import numpy as np
 from PIL import Image
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Function


 class RelationModuleMultiScale(torch.nn.Module):
@@ -301,7 +299,7 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s


 # == Load Model ==
-model = TransferVAE_Video(opt)
+model = TransferVAE_Video()
 model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
 model.eval()

@@ -393,24 +391,24 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
     tar_recon = tar_recon_sample[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

     # Zt
-    f_expand_src = 0 * src_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
+    f_expand_src = 0 * src_f_post.unsqueeze(1).expand(-1, 8, 512)
     zf_src = torch.cat((src_z_post, f_expand_src), dim=2)
     recon_x_src = model.decoder_frame(zf_src)
     src_Zt = recon_x_src.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

-    f_expand_tar = 0 * tar_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
-    zf_tar = torch.cat((tar_z_post, f_expand_tar), dim=2) # batch,frames,(z_dim+f_dim)
+    f_expand_tar = 0 * tar_f_post.unsqueeze(1).expand(-1, 8, 512)
+    zf_tar = torch.cat((tar_z_post, f_expand_tar), dim=2)
     recon_x_tar = model.decoder_frame(zf_tar)
     tar_Zt = recon_x_tar.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

     # Zf_Zt
-    f_expand_src = src_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
-    zf_srcZf_tarZt = torch.cat((tar_z_post, f_expand_src), dim=2) # batch,frames,(z_dim+f_dim)
+    f_expand_src = src_f_post.unsqueeze(1).expand(-1, 8, 512)
+    zf_srcZf_tarZt = torch.cat((tar_z_post, f_expand_src), dim=2)
     recon_x_srcZf_tarZt = model.decoder_frame(zf_srcZf_tarZt)
     src_Zf_tar_Zt = recon_x_srcZf_tarZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

-    f_expand_tar = tar_f_post.unsqueeze(1).expand(-1, 8, opt.f_dim)
-    zf_tarZf_srcZt = torch.cat((src_z_post, f_expand_tar), dim=2) # batch,frames,(z_dim+f_dim)
+    f_expand_tar = tar_f_post.unsqueeze(1).expand(-1, 8, 512)
+    zf_tarZf_srcZt = torch.cat((src_z_post, f_expand_tar), dim=2)
     recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
     tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))

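The rewritten block in run() is the latent-swap visualisation: judging from the Zt and Zf_Zt labels, f_post is the per-clip static code (of size f_dim, now hardcoded to 512) and z_post the per-frame dynamic code, so zeroing f decodes only the dynamics, while pairing one clip's f with the other clip's z recombines appearance and motion. The sketch below shows that pattern in isolation; the toy decoder_frame (a single nn.Linear), the dynamic-code size Z_DIM = 32, and the random posteriors are placeholders for illustration, not the actual TransferVAE_Video components.

import torch
import torch.nn as nn

# Hypothetical stand-ins mirroring the constants in the diff above:
# 8 frames per clip, f_dim = 512; Z_DIM and the decoder are placeholders.
N_FRAMES, F_DIM, Z_DIM = 8, 512, 32

# Toy substitute for model.decoder_frame: maps (B, T, z_dim + f_dim) to flattened frames.
decoder_frame = nn.Linear(Z_DIM + F_DIM, 3 * 64 * 64)

# Posteriors shaped like src_f_post / src_z_post in run():
# one static vector per clip, one dynamic vector per frame.
src_f_post = torch.randn(1, F_DIM)
src_z_post = torch.randn(1, N_FRAMES, Z_DIM)
tar_z_post = torch.randn(1, N_FRAMES, Z_DIM)

# "Zt": zero out the static code so only the dynamics drive the decoder.
f_zero = 0 * src_f_post.unsqueeze(1).expand(-1, N_FRAMES, F_DIM)
src_Zt_frames = decoder_frame(torch.cat((src_z_post, f_zero), dim=2))

# "Zf_Zt": pair the source static code with the target dynamics (appearance transfer).
f_src = src_f_post.unsqueeze(1).expand(-1, N_FRAMES, F_DIM)
src_Zf_tar_Zt_frames = decoder_frame(torch.cat((tar_z_post, f_src), dim=2))

print(src_Zt_frames.shape, src_Zf_tar_Zt_frames.shape)  # torch.Size([1, 8, 12288]) twice

Replacing opt.f_dim with the literal 512 and dropping the opt argument from TransferVAE_Video presumably lets app.py run without building an argparse-style opt namespace; the deleted matplotlib.animation and torch.autograd.Function imports look like unused-import cleanup.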