Update app.py
app.py CHANGED
@@ -13,10 +13,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def log_gpu_memory():
-    print(subprocess.check_output('nvidia-smi').decode('utf-8'))
-
-
 class RelationModuleMultiScale(torch.nn.Module):
 
     def __init__(self, img_feature_dim, num_bottleneck, num_frames):
@@ -129,42 +125,19 @@ class TransferVAE_Video(nn.Module):
 
 
     def encode_and_sample_post(self, x):
-        if isinstance(x, list):
-            conv_x = self.encoder_frame(x[0])
-        else:
-            conv_x = self.encoder_frame(x)
-
+        conv_x = self.encoder_frame(x)
         lstm_out, _ = self.z_lstm(conv_x)
-
         backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
         frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
         lstm_out_f = torch.cat((frontal, backward), dim=1)
         f_mean = self.f_mean(lstm_out_f)
         f_logvar = self.f_logvar(lstm_out_f)
         f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
-
         features, _ = self.z_rnn(lstm_out)
         z_mean = self.z_mean(features)
         z_logvar = self.z_logvar(features)
         z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)
-
-        if isinstance(x, list):
-            f_mean_list = [f_mean]
-            f_post_list = [f_post]
-            for t in range(1,3,1):
-                conv_x = self.encoder_frame(x[t])
-                lstm_out, _ = self.z_lstm(conv_x)
-                backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
-                frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
-                lstm_out_f = torch.cat((frontal, backward), dim=1)
-                f_mean = self.f_mean(lstm_out_f)
-                f_logvar = self.f_logvar(lstm_out_f)
-                f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
-                f_mean_list.append(f_mean)
-                f_post_list.append(f_post)
-            f_mean = f_mean_list
-            f_post = f_post_list
-        return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
+        return f_post, z_post
 
 
     def decoder_frame(self,zf):
@@ -190,7 +163,7 @@ class TransferVAE_Video(nn.Module):
 
 
     def forward(self, x, beta):
-        f_mean, f_logvar, f_post, z_mean, z_logvar, z_post = self.encode_and_sample_post(x)
+        f_post, z_post = self.encode_and_sample_post(x)
         if isinstance(f_post, list):
             f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
         else:
@@ -269,15 +242,11 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
     plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
 
 
-log_gpu_memory()
-
 # == Load Model ==
 model = TransferVAE_Video()
 model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
 model.eval()
 
-log_gpu_memory()
-
 
 def run(source, action_source, hair_source, top_source, bottom_source, target, action_target, hair_target, top_target, bottom_target):
 
@@ -338,12 +307,11 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
     images_target = name2seq(file_name_target)
     x = torch.cat((images_source, images_target), dim=0)
 
-
-    log_gpu_memory()
+
     # == Forward ==
     with torch.no_grad():
         f_post, z_post, recon_x = model(x, [0]*3)
-
+
 
     src_orig_sample = x[0, :, :, :, :]
     src_recon_sample = recon_x[0, :, :, :, :]
@@ -389,9 +357,7 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
     recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
     tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
 
-    log_gpu_memory()
     MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
-    log_gpu_memory()
 
     a = concat('MyPlot_')
 
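For reference, the simplified encode_and_sample_post now takes one stacked video tensor and returns only (f_post, z_post), and forward unpacks that pair before decoding. Below is a minimal sketch of the new calling contract, assuming a 5-D input batch shaped (batch, frames, channels, height, width); the concrete sizes are illustrative assumptions, not values from app.py, while the [0]*3 beta placeholder and the three-value forward return follow the run() code in the diff above.

import torch

# Illustrative batch: 2 videos, 8 frames, 64x64 RGB (shapes are assumptions,
# not taken from the repository; model config must match).
x = torch.randn(2, 8, 3, 64, 64)

model = TransferVAE_Video()
model.eval()

with torch.no_grad():
    # New contract after this commit: one tensor in, two posterior samples out.
    f_post, z_post = model.encode_and_sample_post(x)
    # forward(x, beta) returns the posteriors plus the reconstruction,
    # matching the unpacking used in run().
    f_post, z_post, recon_x = model(x, [0] * 3)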