"""

Scalable, Detailed and Mask-free Universal Photometric Stereo Network (CVPR2023)

# Copyright (c) 2023 Satoshi Ikehata

# All rights reserved.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math
import torch.nn.init as init

class MultiHeadAttentionBlock(nn.Module):
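    """Transformer-style attention block.

    Optionally layer-normalizes the query input, applies multi-head scaled
    dot-product attention with a residual connection from the projected query,
    then a GELU feed-forward sub-layer with another residual connection.
    attn_mode selects dense attention ('Normal') or chunked, memory-efficient
    attention ('Efficient').
    """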
    def __init__(self, dim_in, dim_out, num_heads, ln=False, attention_dropout = 0.1, dim_feedforward = 512,
                 q_bucket_size = 1024, k_bucket_size = 2048, attn_mode = 'Normal'):
        super(MultiHeadAttentionBlock, self).__init__()
        if attn_mode == 'Efficient':
            self.q_bucket_size = q_bucket_size
            self.k_bucket_size = k_bucket_size
        self.attn_mode = attn_mode
        self.dim_V = dim_out
        self.dim_Q = dim_in
        self.dim_K = dim_in
        self.num_heads = num_heads
        self.fc_q = nn.Linear(self.dim_Q, self.dim_V, bias=False) # dim_in -> dim_out
        self.fc_k = nn.Linear(self.dim_K, self.dim_V, bias=False) # dim_in -> dim_out
        self.fc_v = nn.Linear(self.dim_K, self.dim_V, bias=False) # dim_in -> dim_out
        if ln:
            self.ln0 = nn.LayerNorm(self.dim_Q)
            self.ln1 = nn.LayerNorm(self.dim_V)

        self.dropout_attn = nn.Dropout(attention_dropout)
        self.fc_o1 = nn.Linear(self.dim_V, dim_feedforward, bias=False)
        self.fc_o2 = nn.Linear(dim_feedforward, self.dim_V, bias=False)
        self.dropout1 = nn.Dropout(attention_dropout)
        self.dropout2 = nn.Dropout(attention_dropout)

        # memory efficient attention related parameters
        # can be overridden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size
     

    # memory efficient attention
    def summarize_qkv_chunk(self, q, k, v):
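        """Attention statistics for one (q, k) chunk pair.

        Returns the row-wise sums of the (max-shifted) exponentiated logits,
        the corresponding weighted values, and the row-wise logit maxima used
        later to renormalize across chunks.
        """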
        weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        weight_max = weight.amax(dim = -1, keepdim = True).detach()
        weight = weight - weight_max

        exp_weight = self.dropout_attn(weight.exp()) # attention_dropout
        weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)
        return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')

    def memory_efficient_attention(
        self,
        q, k, v,
        q_bucket_size = 512,
        k_bucket_size = 1024,
        eps = 1e-8,
    ):
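        """Chunked (bucketed) scaled dot-product attention.

        Queries and keys/values are split into buckets, per-bucket softmax
        statistics are accumulated, and the partial results are renormalized
        with the running row-wise maxima for numerical stability, so the full
        attention matrix over all keys is never materialized at once.
        """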
        scale = q.shape[-1] ** -0.5
        q = q * scale

        summarize_qkv_fn = self.summarize_qkv_chunk

        # chunk all the inputs
        q_chunks = q.split(q_bucket_size, dim = -2)
        k_chunks = k.split(k_bucket_size, dim = -2)
        v_chunks = v.split(k_bucket_size, dim = -2)

        # loop through all chunks and accumulate
        values = []
        weights = []
        for q_chunk in q_chunks:
            exp_weights = []
            weighted_values = []
            weight_maxes = []  
            for (k_chunk, v_chunk) in zip(k_chunks, v_chunks):
                exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                    q_chunk,
                    k_chunk,
                    v_chunk
                )
                exp_weights.append(exp_weight_chunk)
                weighted_values.append(weighted_value_chunk)
                weight_maxes.append(weight_max_chunk)
            weight_maxes = torch.stack(weight_maxes, dim = -1)
            weighted_values = torch.stack(weighted_values, dim = -1)
            exp_weights = torch.stack(exp_weights, dim = -1)

            global_max = weight_maxes.amax(dim = -1, keepdim = True)
            renorm_factor = (weight_maxes - global_max).exp().detach()

            exp_weights = exp_weights * renorm_factor
            weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')         

            all_values = weighted_values.sum(dim = -1)
            all_weights = exp_weights.sum(dim = -1)
            values.append(all_values)
            weights.append(all_weights)    
        values = torch.cat(values, dim=2)
        weights = torch.cat(weights, dim=2)
        normalized_values = values / (rearrange(weights, '... -> ... 1') + eps)
        return normalized_values


    def forward(
        self,
        x, y,
    ):
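        """Attend from query set x to key/value set y (x == y for self-attention)."""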
        
        x = x if getattr(self, 'ln0', None) is None else self.ln0(x) # pre-normalization
        Q = self.fc_q(x) # input_dim -> embed dim       
        K, V = self.fc_k(y), self.fc_v(y) # input_dim -> embed dim
        dim_split = self.dim_V // self.num_heads # multi-head attention
        if self.attn_mode == 'Efficient':
            q_bucket_size = self.q_bucket_size
            k_bucket_size = self.k_bucket_size
            
            Q_ = torch.stack(Q.split(int(dim_split), 2), 1)
            K_ = torch.stack(K.split(int(dim_split), 2), 1)
            V_ = torch.stack(V.split(int(dim_split), 2), 1) 
            
            A = self.memory_efficient_attention(Q_, K_, V_, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
            A = A.reshape(-1, A.shape[2], A.shape[3])
            Q_ = Q_.reshape(-1, Q_.shape[2], Q_.shape[3])
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)
        else: # 'Normal': dense (non-chunked) attention
            Q_ = torch.cat(Q.split(int(dim_split), 2), 0)
            K_ = torch.cat(K.split(int(dim_split), 2), 0)
            V_ = torch.cat(V.split(int(dim_split), 2), 0)
            A = self.dropout_attn(torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)) # NOTE: scaled by sqrt(dim_V) rather than the per-head dimension sqrt(dim_split) used in standard multi-head attention
            A = A.bmm(V_) # attention output A(Q, K, V)
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)

        
        O_ = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        O = O + self.dropout2(self.fc_o2(self.dropout1(F.gelu(self.fc_o1(O_))))) 
        return O


class SAB(nn.Module): # self attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout = 0.1, dim_feedforward = 512, attn_mode = 'Normal'):
        super(SAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)
    def forward(self, X):
        return self.mab(X, X)

class CAB(nn.Module): # cross attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout = 0.1, dim_feedforward = 512, attn_mode = 'Normal'):
        super(CAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)
    def forward(self, q, kv):       
        return self.mab(q, kv)
    
class PMA(nn.Module):
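    """Pooling by multi-head attention: num_seeds learnable seed tokens attend
    over the input set and aggregate it into num_seeds output tokens."""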
    def __init__(self, dim, num_heads, num_seeds, ln=False, attn_mode='Normal'):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        init.xavier_uniform_(self.S)
        self.mab = MultiHeadAttentionBlock(dim, dim, num_heads, ln=ln, attn_mode=attn_mode)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)

class CommunicationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab = 3, dim_hidden=384, dim_feedforward = 1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CommunicationBlock, self).__init__()

        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)

    def forward(self, x):
        x = self.enc(x)
        return x

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab = 3, dim_hidden=384, dim_feedforward = 1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CrossAttentionBlock, self).__init__()

        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(CAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(CAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.layers = nn.ModuleList(modules_enc)
    def forward(self, q, kv):
        for k in range(len(self.layers)):
            q = self.layers[k](q, kv) # q: query tokens, kv: key/value tokens
        return q # the output token length equals that of q
    
class AggregationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab = 3, num_outputs = 1, dim_hidden=384, dim_feedforward = 1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(AggregationBlock, self).__init__()

        self.num_outputs = num_outputs
        self.dim_hidden = dim_hidden

        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'

        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout = attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)
        modules_dec = []
        modules_dec.append(PMA(dim_hidden, num_heads, num_outputs, attn_mode=attn_mode)) # no dropout after the PMA
        self.dec = nn.Sequential(*modules_dec)
        
    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        x = x.view(-1, self.num_outputs * self.dim_hidden)
        return x
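

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original release): a smoke test under
# assumed, illustrative tensor shapes showing how the blocks above can be
# wired together. The sizes below are hypothetical; the actual SDM-UniPS
# pipeline feeds feature tokens produced elsewhere in the network.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, n_tokens, n_obs, dim = 2, 64, 16, 384   # hypothetical sizes

    x = torch.randn(batch, n_tokens, dim)   # a set of feature tokens (queries)
    y = torch.randn(batch, n_obs, dim)      # a second set used as keys/values

    comm = CommunicationBlock(dim, num_enc_sab=2, dim_hidden=384, num_heads=8, ln=True)
    # note: later CAB layers reuse the original kv, so dim must equal dim_hidden here
    cross = CrossAttentionBlock(dim, num_enc_sab=2, dim_hidden=384, num_heads=8, ln=True)
    aggr = AggregationBlock(dim, num_enc_sab=2, num_outputs=1, dim_hidden=384, num_heads=8, ln=True)
    comm_eff = CommunicationBlock(dim, num_enc_sab=1, dim_hidden=384, num_heads=8,
                                  ln=True, use_efficient_attention=True)

    print(comm(x).shape)       # torch.Size([2, 64, 384])
    print(cross(x, y).shape)   # torch.Size([2, 64, 384])
    print(aggr(x).shape)       # torch.Size([2, 384])
    print(comm_eff(x).shape)   # torch.Size([2, 64, 384])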