Update app.py
Browse files
app.py
CHANGED
@@ -59,18 +59,6 @@ class RelationModuleMultiScale(torch.nn.Module):
|
|
59 |
return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))
|
60 |
|
61 |
|
62 |
-
class GradReverse(Function):
|
63 |
-
@staticmethod
|
64 |
-
def forward(ctx, x, beta):
|
65 |
-
ctx.beta = beta
|
66 |
-
return x.view_as(x)
|
67 |
-
|
68 |
-
@staticmethod
|
69 |
-
def backward(ctx, grad_output):
|
70 |
-
grad_input = grad_output.neg() * ctx.beta
|
71 |
-
return grad_input, None
|
72 |
-
|
73 |
-
|
74 |
class TransferVAE_Video(nn.Module):
|
75 |
|
76 |
def __init__(self):
|
@@ -133,86 +121,18 @@ class TransferVAE_Video(nn.Module):
|
|
133 |
self.relation_domain_classifier_all += [relation_domain_classifier]
|
134 |
|
135 |
self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)
|
136 |
-
|
137 |
self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
|
138 |
self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)
|
139 |
|
140 |
-
|
141 |
-
def domain_classifier_frame(self, feat, beta):
|
142 |
-
feat_fc_domain_frame = GradReverse.apply(feat, beta)
|
143 |
-
feat_fc_domain_frame = self.fc_feature_domain_frame(feat_fc_domain_frame)
|
144 |
-
feat_fc_domain_frame = self.relu(feat_fc_domain_frame)
|
145 |
-
pred_fc_domain_frame = self.fc_classifier_domain_frame(feat_fc_domain_frame)
|
146 |
-
return pred_fc_domain_frame
|
147 |
-
|
148 |
-
|
149 |
-
def domain_classifier_video(self, feat_video, beta):
|
150 |
-
feat_fc_domain_video = GradReverse.apply(feat_video, beta)
|
151 |
-
feat_fc_domain_video = self.fc_feature_domain_video(feat_fc_domain_video)
|
152 |
-
feat_fc_domain_video = self.relu(feat_fc_domain_video)
|
153 |
-
pred_fc_domain_video = self.fc_classifier_domain_video(feat_fc_domain_video)
|
154 |
-
return pred_fc_domain_video
|
155 |
-
|
156 |
-
|
157 |
-
def domain_classifier_latent(self, f):
|
158 |
-
feat_fc_domain_latent = self.fc_feature_domain_latent(f)
|
159 |
-
feat_fc_domain_latent = self.relu(feat_fc_domain_latent)
|
160 |
-
pred_fc_domain_latent = self.fc_classifier_doamin_latent(feat_fc_domain_latent)
|
161 |
-
return pred_fc_domain_latent
|
162 |
-
|
163 |
-
|
164 |
-
def domain_classifier_relation(self, feat_relation, beta):
|
165 |
-
pred_fc_domain_relation_video = None
|
166 |
-
for i in range(len(self.relation_domain_classifier_all)):
|
167 |
-
feat_relation_single = feat_relation[:,i,:].squeeze(1)
|
168 |
-
feat_fc_domain_relation_single = GradReverse.apply(feat_relation_single, beta)
|
169 |
-
|
170 |
-
pred_fc_domain_relation_single = self.relation_domain_classifier_all[i](feat_fc_domain_relation_single)
|
171 |
-
|
172 |
-
if pred_fc_domain_relation_video is None:
|
173 |
-
pred_fc_domain_relation_video = pred_fc_domain_relation_single.view(-1,1,2)
|
174 |
-
else:
|
175 |
-
pred_fc_domain_relation_video = torch.cat((pred_fc_domain_relation_video, pred_fc_domain_relation_single.view(-1,1,2)), 1)
|
176 |
-
|
177 |
-
pred_fc_domain_relation_video = pred_fc_domain_relation_video.view(-1,2)
|
178 |
-
|
179 |
-
return pred_fc_domain_relation_video
|
180 |
-
|
181 |
-
|
182 |
-
def get_trans_attn(self, pred_domain):
|
183 |
-
softmax = nn.Softmax(dim=1)
|
184 |
-
logsoftmax = nn.LogSoftmax(dim=1)
|
185 |
-
entropy = torch.sum(-softmax(pred_domain) * logsoftmax(pred_domain), 1)
|
186 |
-
weights = 1 - entropy
|
187 |
-
return weights
|
188 |
-
|
189 |
-
|
190 |
-
def get_general_attn(self, feat):
|
191 |
-
num_segments = feat.size()[1]
|
192 |
-
feat = feat.view(-1, feat.size()[-1]) # reshape features: 128x4x256 --> (128x4)x256
|
193 |
-
weights = self.attn_layer(feat) # e.g. (128x4)x1
|
194 |
-
weights = weights.view(-1, num_segments, weights.size()[-1]) # reshape attention weights: (128x4)x1 --> 128x4x1
|
195 |
-
weights = F.softmax(weights, dim=1) # softmax over segments ==> 128x4x1
|
196 |
-
return weights
|
197 |
-
|
198 |
-
|
199 |
-
def get_attn_feat_relation(self, feat_fc, pred_domain, num_segments):
|
200 |
-
weights_attn = self.get_trans_attn(pred_domain)
|
201 |
-
weights_attn = weights_attn.view(-1, num_segments-1, 1).repeat(1,1,feat_fc.size()[-1]) # reshape & repeat weights (e.g. 16 x 4 x 256)
|
202 |
-
feat_fc_attn = (weights_attn+1) * feat_fc
|
203 |
-
return feat_fc_attn, weights_attn[:,:,0]
|
204 |
-
|
205 |
|
206 |
def encode_and_sample_post(self, x):
|
207 |
if isinstance(x, list):
|
208 |
conv_x = self.encoder_frame(x[0])
|
209 |
else:
|
210 |
conv_x = self.encoder_frame(x)
|
211 |
-
|
212 |
-
# pass the bidirectional lstm
|
213 |
lstm_out, _ = self.z_lstm(conv_x)
|
214 |
-
|
215 |
-
# get f:
|
216 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
217 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
218 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
@@ -220,7 +140,6 @@ class TransferVAE_Video(nn.Module):
|
|
220 |
f_logvar = self.f_logvar(lstm_out_f)
|
221 |
f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
|
222 |
|
223 |
-
# pass to one direction rnn
|
224 |
features, _ = self.z_rnn(lstm_out)
|
225 |
z_mean = self.z_mean(features)
|
226 |
z_logvar = self.z_logvar(features)
|
@@ -232,7 +151,6 @@ class TransferVAE_Video(nn.Module):
|
|
232 |
for t in range(1,3,1):
|
233 |
conv_x = self.encoder_frame(x[t])
|
234 |
lstm_out, _ = self.z_lstm(conv_x)
|
235 |
-
# get f:
|
236 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
237 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
238 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
@@ -243,7 +161,6 @@ class TransferVAE_Video(nn.Module):
|
|
243 |
f_post_list.append(f_post)
|
244 |
f_mean = f_mean_list
|
245 |
f_post = f_post_list
|
246 |
-
# f_mean and f_post are list if triple else not
|
247 |
return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
|
248 |
|
249 |
|
@@ -260,7 +177,6 @@ class TransferVAE_Video(nn.Module):
|
|
260 |
|
261 |
|
262 |
def reparameterize(self, mean, logvar, random_sampling=True):
|
263 |
-
# Reparametrization occurs only if random sampling is set to true, otherwise mean is returned
|
264 |
if random_sampling is True:
|
265 |
eps = torch.randn_like(logvar)
|
266 |
std = torch.exp(0.5 * logvar)
|
@@ -269,88 +185,20 @@ class TransferVAE_Video(nn.Module):
|
|
269 |
else:
|
270 |
return mean
|
271 |
|
272 |
-
def sample_z_prior_train(self, z_post, random_sampling=True):
|
273 |
-
z_out = None
|
274 |
-
z_means = None
|
275 |
-
z_logvars = None
|
276 |
-
batch_size = z_post.shape[0]
|
277 |
-
|
278 |
-
z_t = torch.zeros(batch_size, self.z_dim).cpu()
|
279 |
-
h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
280 |
-
c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
281 |
-
h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
282 |
-
c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
283 |
-
|
284 |
-
for i in range(self.frames):
|
285 |
-
# two layer LSTM and two one-layer FC
|
286 |
-
h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
|
287 |
-
h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
|
288 |
-
|
289 |
-
z_mean_t = self.z_prior_mean(h_t_ly2)
|
290 |
-
z_logvar_t = self.z_prior_logvar(h_t_ly2)
|
291 |
-
z_prior = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
|
292 |
-
if z_out is None:
|
293 |
-
# If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
|
294 |
-
z_out = z_prior.unsqueeze(1)
|
295 |
-
z_means = z_mean_t.unsqueeze(1)
|
296 |
-
z_logvars = z_logvar_t.unsqueeze(1)
|
297 |
-
else:
|
298 |
-
# If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
|
299 |
-
z_out = torch.cat((z_out, z_prior.unsqueeze(1)), dim=1)
|
300 |
-
z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
|
301 |
-
z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
|
302 |
-
z_t = z_post[:,i,:]
|
303 |
-
return z_means, z_logvars, z_out
|
304 |
-
|
305 |
-
# If random sampling is true, reparametrization occurs else z_t is just set to the mean
|
306 |
-
def sample_z(self, batch_size, random_sampling=True):
|
307 |
-
z_out = None # This will ultimately store all z_s in the format [batch_size, frames, z_dim]
|
308 |
-
z_means = None
|
309 |
-
z_logvars = None
|
310 |
-
|
311 |
-
# All states are initially set to 0, especially z_0 = 0
|
312 |
-
z_t = torch.zeros(batch_size, self.z_dim).cpu()
|
313 |
-
# z_mean_t = torch.zeros(batch_size, self.z_dim)
|
314 |
-
# z_logvar_t = torch.zeros(batch_size, self.z_dim)
|
315 |
-
h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
316 |
-
c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
317 |
-
h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
318 |
-
c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
|
319 |
-
for _ in range(self.frames):
|
320 |
-
# h_t, c_t = self.z_prior_lstm(z_t, (h_t, c_t))
|
321 |
-
# two layer LSTM and two one-layer FC
|
322 |
-
h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
|
323 |
-
h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
|
324 |
-
|
325 |
-
z_mean_t = self.z_prior_mean(h_t_ly2)
|
326 |
-
z_logvar_t = self.z_prior_logvar(h_t_ly2)
|
327 |
-
z_t = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
|
328 |
-
if z_out is None:
|
329 |
-
# If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
|
330 |
-
z_out = z_t.unsqueeze(1)
|
331 |
-
z_means = z_mean_t.unsqueeze(1)
|
332 |
-
z_logvars = z_logvar_t.unsqueeze(1)
|
333 |
-
else:
|
334 |
-
# If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
|
335 |
-
z_out = torch.cat((z_out, z_t.unsqueeze(1)), dim=1)
|
336 |
-
z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
|
337 |
-
z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
|
338 |
-
return z_means, z_logvars, z_out
|
339 |
|
340 |
def forward(self, x, beta):
|
341 |
_, _, f_post, _, _, z_post = self.encode_and_sample_post(x)
|
342 |
-
|
343 |
if isinstance(f_post, list):
|
344 |
f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
|
345 |
else:
|
346 |
f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
|
347 |
zf = torch.cat((z_post, f_expand), dim=2)
|
348 |
-
|
349 |
recon_x = self.decoder_frame(zf)
|
350 |
-
|
351 |
return f_post, z_post, recon_x
|
352 |
|
353 |
|
|
|
|
|
354 |
def name2seq(file_name):
|
355 |
images = []
|
356 |
|
@@ -520,7 +368,7 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
|
|
520 |
|
521 |
# == Forward ==
|
522 |
with torch.no_grad():
|
523 |
-
|
524 |
|
525 |
src_orig_sample = x[0, :, :, :, :]
|
526 |
src_recon_sample = recon_x[0, :, :, :, :]
|
|
|
59 |
return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))
|
60 |
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
class TransferVAE_Video(nn.Module):
|
63 |
|
64 |
def __init__(self):
|
|
|
121 |
self.relation_domain_classifier_all += [relation_domain_classifier]
|
122 |
|
123 |
self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)
|
|
|
124 |
self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
|
125 |
self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
def encode_and_sample_post(self, x):
|
129 |
if isinstance(x, list):
|
130 |
conv_x = self.encoder_frame(x[0])
|
131 |
else:
|
132 |
conv_x = self.encoder_frame(x)
|
133 |
+
|
|
|
134 |
lstm_out, _ = self.z_lstm(conv_x)
|
135 |
+
|
|
|
136 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
137 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
138 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
|
|
140 |
f_logvar = self.f_logvar(lstm_out_f)
|
141 |
f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
|
142 |
|
|
|
143 |
features, _ = self.z_rnn(lstm_out)
|
144 |
z_mean = self.z_mean(features)
|
145 |
z_logvar = self.z_logvar(features)
|
|
|
151 |
for t in range(1,3,1):
|
152 |
conv_x = self.encoder_frame(x[t])
|
153 |
lstm_out, _ = self.z_lstm(conv_x)
|
|
|
154 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
155 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
156 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
|
|
161 |
f_post_list.append(f_post)
|
162 |
f_mean = f_mean_list
|
163 |
f_post = f_post_list
|
|
|
164 |
return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
|
165 |
|
166 |
|
|
|
177 |
|
178 |
|
179 |
def reparameterize(self, mean, logvar, random_sampling=True):
|
|
|
180 |
if random_sampling is True:
|
181 |
eps = torch.randn_like(logvar)
|
182 |
std = torch.exp(0.5 * logvar)
|
|
|
185 |
else:
|
186 |
return mean
|
187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
def forward(self, x, beta):
|
190 |
_, _, f_post, _, _, z_post = self.encode_and_sample_post(x)
|
|
|
191 |
if isinstance(f_post, list):
|
192 |
f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
|
193 |
else:
|
194 |
f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
|
195 |
zf = torch.cat((z_post, f_expand), dim=2)
|
|
|
196 |
recon_x = self.decoder_frame(zf)
|
|
|
197 |
return f_post, z_post, recon_x
|
198 |
|
199 |
|
200 |
+
|
201 |
+
|
202 |
def name2seq(file_name):
|
203 |
images = []
|
204 |
|
|
|
368 |
|
369 |
# == Forward ==
|
370 |
with torch.no_grad():
|
371 |
+
f_post, z_post, recon_x = model(x, [0]*3)
|
372 |
|
373 |
src_orig_sample = x[0, :, :, :, :]
|
374 |
src_recon_sample = recon_x[0, :, :, :, :]
|