import torch
import torch.nn as nn


eps = 1e-8  # small constant guarding the Sinkhorn updates against division by zero

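# Sinkhorn-Knopp normalisation for entropic optimal transport: starting from a
# row-softmax of the scores, alternating row/column rescalings pull the
# coupling's marginals towards the target distributions r and c.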
def sinkhorn(M, r, c, iteration):
    p = torch.softmax(M, dim=-1)
    u = torch.ones_like(r)
    v = torch.ones_like(c)
    for _ in range(iteration):
        u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)  # rescale rows towards r
        v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)  # rescale columns towards c
    p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
    return p

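# Dustbin-augmented assignment (in the style of SuperGlue): a learnable dustbin
# score is appended as an extra row and column so that keypoints without a true
# correspondence can be assigned to the bin instead of being forced to match.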
def sink_algorithm(M, dustbin, iteration):
    M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
    M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
    # Target marginals: one unit of mass per keypoint, with a larger budget for
    # the dustbin entries so they can absorb multiple unmatched keypoints.
    r = torch.ones([M.shape[0], M.shape[1] - 1], device=M.device)
    r = torch.cat([r, torch.ones([M.shape[0], 1], device=M.device) * M.shape[1]], dim=-1)
    c = torch.ones([M.shape[0], M.shape[2] - 1], device=M.device)
    c = torch.cat([c, torch.ones([M.shape[0], 1], device=M.device) * M.shape[2]], dim=-1)
    p = sinkhorn(M, r, c, iteration)
    return p


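# Multi-head attention over (batch, channels, num_keypoints) feature maps.
# 'self' attention exchanges context among keypoints of the same image;
# 'cross' attention exchanges context between the two images. Each block
# finishes with a residual MLP update on the concatenated features.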
class attention_block(nn.Module):
    def __init__(self, channels, head, attn_type):
        assert attn_type in ('self', 'cross'), 'invalid attention type'
        nn.Module.__init__(self)
        self.head = head
        self.attn_type = attn_type
        self.head_dim = channels // head
        self.query_filter = nn.Conv1d(channels, channels, kernel_size=1)
        self.key_filter = nn.Conv1d(channels, channels, kernel_size=1)
        self.value_filter = nn.Conv1d(channels, channels, kernel_size=1)
        self.attention_filter = nn.Sequential(nn.Conv1d(2 * channels, 2 * channels, kernel_size=1),
                                              nn.SyncBatchNorm(2 * channels), nn.ReLU(),
                                              nn.Conv1d(2 * channels, channels, kernel_size=1))
        self.mh_filter = nn.Conv1d(channels, channels, kernel_size=1)
    def forward(self, fea1, fea2):
        batch_size, n, m = fea1.shape[0], fea1.shape[2], fea2.shape[2]
        # Project to queries/keys/values and split the channels into heads.
        query1, key1, value1 = [f(fea1).view(batch_size, self.head_dim, self.head, -1)
                                for f in (self.query_filter, self.key_filter, self.value_filter)]
        query2, key2, value2 = [f(fea2).view(batch_size, self.head_dim, self.head, -1)
                                for f in (self.query_filter, self.key_filter, self.value_filter)]
        if self.attn_type == 'self':
            # Each image attends to its own keypoints.
            score1 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query1, key1) / self.head_dim ** 0.5, dim=-1)
            score2 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query2, key2) / self.head_dim ** 0.5, dim=-1)
            add_value1 = torch.einsum('bhnm,bdhm->bdhn', score1, value1)
            add_value2 = torch.einsum('bhnm,bdhm->bdhn', score2, value2)
        else:
            # Queries from one image attend to keys/values of the other.
            score1 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query1, key2) / self.head_dim ** 0.5, dim=-1)
            score2 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query2, key1) / self.head_dim ** 0.5, dim=-1)
            add_value1 = torch.einsum('bhnm,bdhm->bdhn', score1, value2)
            add_value2 = torch.einsum('bhnm,bdhm->bdhn', score2, value1)
        # Merge the heads and apply the residual MLP update.
        add_value1 = self.mh_filter(add_value1.contiguous().view(batch_size, self.head * self.head_dim, n))
        add_value2 = self.mh_filter(add_value2.contiguous().view(batch_size, self.head * self.head_dim, m))
        fea1 = fea1 + self.attention_filter(torch.cat([fea1, add_value1], dim=1))
        fea2 = fea2 + self.attention_filter(torch.cat([fea2, add_value2], dim=1))
        return fea1, fea2


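# End-to-end matcher: encode keypoint positions, refine descriptors with
# stacked self/cross attention, score all pairs, and resolve the assignment
# with dustbin-augmented Sinkhorn.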
class matcher(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.use_score_encoding = config.use_score_encoding
        self.layer_num = config.layer_num
        self.sink_iter = config.sink_iter
        # MLP lifting keypoint positions (and optionally detection scores)
        # to the descriptor dimension.
        self.position_encoder = nn.Sequential(
            nn.Conv1d(3 if config.use_score_encoding else 2, 32, kernel_size=1),
            nn.SyncBatchNorm(32), nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64), nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128), nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256), nn.ReLU(),
            nn.Conv1d(256, config.net_channels, kernel_size=1))
        # Learnable dustbin score; created device-agnostic so it moves with the module.
        self.dustbin = nn.Parameter(torch.tensor(1.0))
        self.self_attention_block = nn.ModuleList(
            [attention_block(config.net_channels, config.head, 'self') for _ in range(config.layer_num)])
        self.cross_attention_block = nn.ModuleList(
            [attention_block(config.net_channels, config.head, 'cross') for _ in range(config.layer_num)])
        self.final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1)

    def forward(self, data, test_mode=True):
        # L2-normalise the descriptors and move channels to dim 1 for Conv1d.
        desc1, desc2 = data['desc1'], data['desc2']
        desc1, desc2 = torch.nn.functional.normalize(desc1, dim=-1), torch.nn.functional.normalize(desc2, dim=-1)
        desc1, desc2 = desc1.transpose(1, 2), desc2.transpose(1, 2)
        if test_mode:
            encode_x1, encode_x2 = data['x1'], data['x2']
        else:
            encode_x1, encode_x2 = data['aug_x1'], data['aug_x2']
        if not self.use_score_encoding:
            # Keep only the (x, y) coordinates, dropping the detection score.
            encode_x1, encode_x2 = encode_x1[:, :, :2], encode_x2[:, :, :2]
        encode_x1, encode_x2 = encode_x1.transpose(1, 2), encode_x2.transpose(1, 2)

        # Add positional encodings, then interleave self- and cross-attention.
        x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2)
        aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2
        for i in range(self.layer_num):
            aug_desc1, aug_desc2 = self.self_attention_block[i](aug_desc1, aug_desc2)
            aug_desc1, aug_desc2 = self.cross_attention_block[i](aug_desc1, aug_desc2)

        # Pairwise similarity followed by dustbin-augmented Sinkhorn.
        aug_desc1, aug_desc2 = self.final_project(aug_desc1), self.final_project(aug_desc2)
        desc_mat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
        p = sink_algorithm(desc_mat, self.dustbin, self.sink_iter[0])
        return {'p': p}
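

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original code. The config
    # fields below mirror those read in matcher.__init__, but the values are
    # illustrative assumptions. Running in eval mode without a distributed
    # process group assumes a PyTorch version whose SyncBatchNorm falls back
    # to ordinary batch-norm statistics in that setting.
    from types import SimpleNamespace

    config = SimpleNamespace(use_score_encoding=False, layer_num=2,
                             sink_iter=[10], net_channels=128, head=4)
    model = matcher(config).eval()
    batch, n1, n2 = 2, 100, 120
    data = {
        'desc1': torch.randn(batch, n1, config.net_channels),
        'desc2': torch.randn(batch, n2, config.net_channels),
        'x1': torch.randn(batch, n1, 3),  # (x, y, score) per keypoint
        'x2': torch.randn(batch, n2, 3),
    }
    with torch.no_grad():
        out = model(data, test_mode=True)
    print(out['p'].shape)  # expected: (batch, n1 + 1, n2 + 1), dustbins included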