import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from options import get_parser_main_model
opts = get_parser_main_model().parse_args()
def init_weights(m):
    # uniform init in [-0.08, 0.08] for every parameter (weights and biases alike)
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
class ModalityFusion(nn.Module):
    def __init__(self, img_size=64, ref_nshot=4, bottleneck_bits=512, ngf=32, seq_latent_dim=512, mode='train'):
        super().__init__()
        self.mode = mode
        self.bottleneck_bits = bottleneck_bits
        self.ref_nshot = ref_nshot
        # merge the ref_nshot per-shot sequence features at each position into one 512-dim feature
        self.fc_merge = nn.Linear(seq_latent_dim * opts.ref_nshot, 512)
        n_downsampling = int(math.log(img_size, 2))
        # the max multiplier for image feature channels is 2 ** n_downsampling
        mult_max = 2 ** n_downsampling
        # fuse image and sequence features into (mu, log_sigma) of the latent distribution
        self.fc_fusion = nn.Linear(ngf * mult_max + seq_latent_dim, opts.bottleneck_bits * 2, bias=True)
    def forward(self, seq_feat, img_feat, ref_pad_mask=None):
        # prepend an always-valid mask entry for the CLS position, then zero out padded references
        cls_one_pad = torch.ones((1, 1, 1), device=seq_feat.device).repeat(seq_feat.size(0), 1, 1)
        ref_pad_mask = torch.cat([cls_one_pad, ref_pad_mask], dim=-1)
        seq_feat = seq_feat * ref_pad_mask.transpose(1, 2)
        # regroup (B * ref_nshot, L, D) -> (B, L, ref_nshot * D) so the shots can be merged per position
        seq_feat_ = seq_feat.view(seq_feat.size(0) // self.ref_nshot, self.ref_nshot, seq_feat.size(-2), seq_feat.size(-1))
        seq_feat_ = seq_feat_.transpose(1, 2)
        seq_feat_ = seq_feat_.contiguous().view(seq_feat_.size(0), seq_feat_.size(1), seq_feat_.size(2) * seq_feat_.size(3))
        seq_feat_ = self.fc_merge(seq_feat_)
        # fuse the CLS sequence feature with the image feature to predict the latent distribution
        seq_feat_cls = seq_feat_[:, 0]
        feat_cat = torch.cat((img_feat, seq_feat_cls), -1)
        dist_param = self.fc_fusion(feat_cat)
        output = {}
        mu = dist_param[..., :self.bottleneck_bits]
        log_sigma = dist_param[..., self.bottleneck_bits:]
        if self.mode == 'train':
            # reparameterize the latent code and compute the KL divergence to the standard normal prior
            epsilon = torch.randn(*mu.size(), device=mu.device)
            z = mu + torch.exp(log_sigma / 2) * epsilon
            kl = 0.5 * torch.mean(torch.exp(log_sigma) + torch.square(mu) - 1. - log_sigma)
            output['latent'] = z
            output['kl_loss'] = kl
            seq_feat_[:, 0] = z
        else:
            # at inference time, use the distribution mean deterministically
            output['latent'] = mu
            output['kl_loss'] = 0.0
            seq_feat_[:, 0] = mu
        # merged sequence features with the (re)sampled latent written into the CLS slot
        latent_feat_seq = seq_feat_
        return output, latent_feat_seq
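
if __name__ == '__main__':
    # Minimal shape smoke test -- a sketch, not part of the original module.
    # It assumes the options parser defaults give opts.ref_nshot == 4 and
    # opts.bottleneck_bits == 512 so the layer sizes below match; batch size B
    # and sequence length L (including the CLS slot at position 0) are arbitrary.
    B, nshot, L, D = 2, 4, 51, 512
    fusion = ModalityFusion(img_size=64, ref_nshot=nshot, bottleneck_bits=512, ngf=32, seq_latent_dim=D, mode='train')
    init_weights(fusion)
    seq_feat = torch.randn(B * nshot, L, D)          # per-shot sequence features
    img_feat = torch.randn(B, 32 * 64)               # ngf * 2 ** log2(img_size) = 2048 channels
    ref_pad_mask = torch.ones(B * nshot, 1, L - 1)   # 1 = valid reference position, 0 = padding
    output, latent_feat_seq = fusion(seq_feat, img_feat, ref_pad_mask)
    # expected: latent (B, 512), scalar kl_loss, latent_feat_seq (B, L, 512)
    print(output['latent'].shape, output['kl_loss'].item(), latent_feat_seq.shape)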