import math

import torch
import torch.nn as nn

def init_weights(m):
    # Initialize every parameter of the module uniformly in [-0.08, 0.08].
    for _, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
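
# Note: init_weights walks named_parameters() recursively, so a single call on the
# top-level module covers every submodule. A hypothetical usage sketch:
#     fusion = ModalityFusion()
#     init_weights(fusion)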


class ModalityFusion(nn.Module):
    # Fuses sequence-modality and image-modality features into a shared VAE latent code.
    def __init__(self, img_size=64, ref_nshot=4, bottleneck_bits=512, ngf=32, seq_latent_dim=512, mode='train'):
        super().__init__()
        self.mode = mode
        self.bottleneck_bits = bottleneck_bits
        self.ref_nshot = ref_nshot
        # Merge the ref_nshot per-reference sequence features (concatenated along
        # the channel dim) back down to a single seq_latent_dim-sized feature.
        self.fc_merge = nn.Linear(seq_latent_dim * ref_nshot, seq_latent_dim)
        n_downsampling = int(math.log(img_size, 2))
        mult_max = 2 ** n_downsampling  # the max multiplier for img feat channels is 2 ** n_downsampling
        # Project [image feature; sequence CLS feature] to (mu, log_sigma) of the latent Gaussian.
        self.fc_fusion = nn.Linear(ngf * mult_max + seq_latent_dim, bottleneck_bits * 2, bias=True)

    def forward(self, seq_feat, img_feat, ref_pad_mask=None):
        # seq_feat: (B * ref_nshot, L, seq_latent_dim), one row per reference glyph, CLS token at position 0
        # img_feat: (B, ngf * mult_max)
        # ref_pad_mask: (B * ref_nshot, 1, L - 1), 1 for valid positions, 0 for padding
        if ref_pad_mask is None:
            # No mask given: treat every position as valid.
            ref_pad_mask = torch.ones(seq_feat.size(0), 1, seq_feat.size(1) - 1, device=seq_feat.device)

        # Prepend an always-valid slot for the CLS token, then zero out padded positions.
        cls_one_pad = torch.ones(seq_feat.size(0), 1, 1, device=seq_feat.device)
        ref_pad_mask = torch.cat([cls_one_pad, ref_pad_mask], dim=-1)
        seq_feat = seq_feat * ref_pad_mask.transpose(1, 2)

        # Group the ref_nshot reference glyphs per sample: (B, ref_nshot, L, D) -> (B, L, ref_nshot * D),
        # then merge them back down to (B, L, D).
        seq_feat_ = seq_feat.view(seq_feat.size(0) // self.ref_nshot, self.ref_nshot, seq_feat.size(-2), seq_feat.size(-1))
        seq_feat_ = seq_feat_.transpose(1, 2)
        seq_feat_ = seq_feat_.contiguous().view(seq_feat_.size(0), seq_feat_.size(1), seq_feat_.size(2) * seq_feat_.size(3))
        seq_feat_ = self.fc_merge(seq_feat_)
        seq_feat_cls = seq_feat_[:, 0]

        # Fuse both modalities and predict the latent distribution parameters.
        feat_cat = torch.cat((img_feat, seq_feat_cls), -1)
        dist_param = self.fc_fusion(feat_cat)

        output = {}
        mu = dist_param[..., :self.bottleneck_bits]
        log_sigma = dist_param[..., self.bottleneck_bits:]

        if self.mode == 'train':
            # Reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, I).
            epsilon = torch.randn(*mu.size(), device=mu.device)
            z = mu + torch.exp(log_sigma / 2) * epsilon
            # KL divergence between N(mu, sigma^2) and the standard normal prior.
            kl = 0.5 * torch.mean(torch.exp(log_sigma) + torch.square(mu) - 1. - log_sigma)
            output['latent'] = z
            output['kl_loss'] = kl
        else:
            # At inference time, use the distribution mean deterministically.
            output['latent'] = mu
            output['kl_loss'] = 0.0

        # Write the fused latent code into the CLS slot so downstream decoders
        # receive it alongside the merged sequence features.
        seq_feat_[:, 0] = output['latent']
        latent_feat_seq = seq_feat_

        return output, latent_feat_seq
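

if __name__ == '__main__':
    # Minimal smoke test: a sketch under assumed shapes, not part of the original
    # training pipeline. B, L and the sizes below are illustrative assumptions
    # chosen to match the module defaults (L counts the CLS slot at position 0).
    B, nshot, L, dim = 2, 4, 52, 512
    fusion = ModalityFusion(img_size=64, ref_nshot=nshot, bottleneck_bits=512,
                            ngf=32, seq_latent_dim=dim, mode='train')
    init_weights(fusion)
    seq_feat = torch.randn(B * nshot, L, dim)
    img_feat = torch.randn(B, 32 * 2 ** 6)          # ngf * 2 ** log2(img_size) = 2048
    ref_pad_mask = torch.ones(B * nshot, 1, L - 1)  # all reference positions valid
    output, latent_feat_seq = fusion(seq_feat, img_feat, ref_pad_mask)
    print(output['latent'].shape)    # torch.Size([2, 512])
    print(float(output['kl_loss']))  # scalar KL term, always >= 0
    print(latent_feat_seq.shape)     # torch.Size([2, 52, 512])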