""" | |
Scalable, Detailed and Mask-free Universal Photometric Stereo Network (CVPR2023) | |
# Copyright (c) 2023 Satoshi Ikehata | |
# All rights reserved. | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
import math | |
import torch.nn.init as init | |
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False, attention_dropout=0.1, dim_feedforward=512,
                 q_bucket_size=1024, k_bucket_size=2048, attn_mode='Normal'):
        super(MultiHeadAttentionBlock, self).__init__()
        self.attn_mode = attn_mode
        self.dim_V = dim_out
        self.dim_Q = dim_in
        self.dim_K = dim_in
        self.num_heads = num_heads
        self.fc_q = nn.Linear(self.dim_Q, self.dim_V, bias=False)  # dim_in -> dim_out (queries)
        self.fc_k = nn.Linear(self.dim_K, self.dim_V, bias=False)  # dim_in -> dim_out (keys)
        self.fc_v = nn.Linear(self.dim_K, self.dim_V, bias=False)  # dim_in -> dim_out (values)
        if ln:
            self.ln0 = nn.LayerNorm(self.dim_Q)
            self.ln1 = nn.LayerNorm(self.dim_V)
        self.dropout_attn = nn.Dropout(attention_dropout)
        self.fc_o1 = nn.Linear(self.dim_V, dim_feedforward, bias=False)
        self.fc_o2 = nn.Linear(dim_feedforward, self.dim_V, bias=False)
        self.dropout1 = nn.Dropout(attention_dropout)
        self.dropout2 = nn.Dropout(attention_dropout)
        # memory-efficient attention bucket sizes (used when attn_mode == 'Efficient')
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size
    # memory-efficient attention: summarize one (query-chunk, key/value-chunk) pair
    def summarize_qkv_chunk(self, q, k, v):
        weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        weight_max = weight.amax(dim=-1, keepdim=True).detach()
        weight = weight - weight_max  # subtract the per-chunk max for numerical stability
        exp_weight = self.dropout_attn(weight.exp())  # attention dropout
        weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)
        return exp_weight.sum(dim=-1), weighted_value, rearrange(weight_max, '... 1 -> ...')
    def memory_efficient_attention(
        self,
        q, k, v,
        q_bucket_size=512,
        k_bucket_size=1024,
        eps=1e-8,
    ):
        scale = q.shape[-1] ** -0.5
        q = q * scale
        summarize_qkv_fn = self.summarize_qkv_chunk

        # chunk all the inputs along the token dimension
        q_chunks = q.split(q_bucket_size, dim=-2)
        k_chunks = k.split(k_bucket_size, dim=-2)
        v_chunks = v.split(k_bucket_size, dim=-2)

        # loop through all chunks and accumulate
        values = []
        weights = []
        for q_chunk in q_chunks:
            exp_weights = []
            weighted_values = []
            weight_maxes = []
            for (k_chunk, v_chunk) in zip(k_chunks, v_chunks):
                exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                    q_chunk,
                    k_chunk,
                    v_chunk
                )
                exp_weights.append(exp_weight_chunk)
                weighted_values.append(weighted_value_chunk)
                weight_maxes.append(weight_max_chunk)

            weight_maxes = torch.stack(weight_maxes, dim=-1)
            weighted_values = torch.stack(weighted_values, dim=-1)
            exp_weights = torch.stack(exp_weights, dim=-1)

            # renormalize the per-chunk partial sums against the shared (global) max
            global_max = weight_maxes.amax(dim=-1, keepdim=True)
            renorm_factor = (weight_maxes - global_max).exp().detach()

            exp_weights = exp_weights * renorm_factor
            weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

            all_values = weighted_values.sum(dim=-1)
            all_weights = exp_weights.sum(dim=-1)

            values.append(all_values)
            weights.append(all_weights)

        values = torch.cat(values, dim=2)
        weights = torch.cat(weights, dim=2)
        normalized_values = values / (rearrange(weights, '... -> ... 1') + eps)
        return normalized_values
    def forward(
        self,
        x, y,
    ):
        x = x if getattr(self, 'ln0', None) is None else self.ln0(x)  # pre-normalization
        Q = self.fc_q(x)  # input dim -> embed dim
        K, V = self.fc_k(y), self.fc_v(y)  # input dim -> embed dim
        dim_split = self.dim_V // self.num_heads  # per-head dimension for multi-head attention
        if self.attn_mode == 'Efficient':
            q_bucket_size = self.q_bucket_size
            k_bucket_size = self.k_bucket_size
            Q_ = torch.stack(Q.split(int(dim_split), 2), 1)
            K_ = torch.stack(K.split(int(dim_split), 2), 1)
            V_ = torch.stack(V.split(int(dim_split), 2), 1)
            A = self.memory_efficient_attention(Q_, K_, V_, q_bucket_size=q_bucket_size, k_bucket_size=k_bucket_size)
            A = A.reshape(-1, A.shape[2], A.shape[3])
            Q_ = Q_.reshape(-1, Q_.shape[2], Q_.shape[3])
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)
        else:  # 'Normal': standard (dense) attention
            Q_ = torch.cat(Q.split(int(dim_split), 2), 0)
            K_ = torch.cat(K.split(int(dim_split), 2), 0)
            V_ = torch.cat(V.split(int(dim_split), 2), 0)
            # note: scales by sqrt(dim_V) rather than the per-head sqrt(dim_V / num_heads),
            # which may not be the standard multi-head scaling
            A = self.dropout_attn(torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2))
            A = A.bmm(V_)  # attention output A(Q, K, V)
            O = torch.cat((Q_ + A).split(Q.size(0), 0), 2)
        O_ = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        O = O + self.dropout2(self.fc_o2(self.dropout1(F.gelu(self.fc_o1(O_)))))  # residual feed-forward
        return O
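
# --------------------------------------------------------------------------
# Illustrative sanity check (not part of the original source). A minimal
# sketch, assuming 4-D tensors of shape (batch, heads, tokens, head_dim),
# showing that the chunked memory_efficient_attention above matches standard
# dense softmax attention once dropout is inactive (eval mode). The function
# name and the shapes are hypothetical.
# --------------------------------------------------------------------------
def _check_efficient_attention_matches_dense():
    torch.manual_seed(0)
    block = MultiHeadAttentionBlock(dim_in=32, dim_out=32, num_heads=4, attn_mode='Efficient').eval()
    q = torch.randn(2, 4, 100, 8)
    k = torch.randn(2, 4, 150, 8)
    v = torch.randn(2, 4, 150, 8)
    with torch.no_grad():
        chunked = block.memory_efficient_attention(q, k, v, q_bucket_size=32, k_bucket_size=64)
        dense = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1) @ v
    assert torch.allclose(chunked, dense, atol=1e-4)
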
class SAB(nn.Module):  # self-attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout=0.1, dim_feedforward=512, attn_mode='Normal'):
        super(SAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)

    def forward(self, X):
        return self.mab(X, X)


class CAB(nn.Module):  # cross-attention block
    def __init__(self, dim_in, dim_out, num_heads=4, ln=False, attention_dropout=0.1, dim_feedforward=512, attn_mode='Normal'):
        super(CAB, self).__init__()
        self.mab = MultiHeadAttentionBlock(dim_in, dim_out, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode)

    def forward(self, q, kv):
        return self.mab(q, kv)
class PMA(nn.Module):  # pooling by multi-head attention: learned seed tokens attend to the input set
    def __init__(self, dim, num_heads, num_seeds, ln=False, attn_mode='Normal'):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        init.xavier_uniform_(self.S)
        self.mab = MultiHeadAttentionBlock(dim, dim, num_heads, ln=ln, attn_mode=attn_mode)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)
class CommunicationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab=3, dim_hidden=384, dim_feedforward=1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CommunicationBlock, self).__init__()
        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)

    def forward(self, x):
        x = self.enc(x)
        return x
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab=3, dim_hidden=384, dim_feedforward=1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(CrossAttentionBlock, self).__init__()
        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        self.dim_hidden = dim_hidden
        modules_enc = []
        modules_enc.append(CAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(CAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.layers = nn.ModuleList(modules_enc)

    def forward(self, q, kv):
        for k in range(len(self.layers)):
            q = self.layers[k](q, kv)  # q provides the queries, kv provides keys and values
        return q  # the output token length matches that of q
class AggregationBlock(nn.Module):
    def __init__(self, dim_input, num_enc_sab=3, num_outputs=1, dim_hidden=384, dim_feedforward=1024, num_heads=8, ln=False, attention_dropout=0.1, use_efficient_attention=False):
        super(AggregationBlock, self).__init__()
        self.num_outputs = num_outputs
        self.dim_hidden = dim_hidden
        if use_efficient_attention:
            attn_mode = 'Efficient'
        else:
            attn_mode = 'Normal'
        modules_enc = []
        modules_enc.append(SAB(dim_input, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        for k in range(num_enc_sab):
            modules_enc.append(SAB(dim_hidden, dim_hidden, num_heads, ln=ln, attention_dropout=attention_dropout, dim_feedforward=dim_feedforward, attn_mode=attn_mode))
        self.enc = nn.Sequential(*modules_enc)
        modules_dec = []
        modules_dec.append(PMA(dim_hidden, num_heads, num_outputs, attn_mode=attn_mode))  # no dropout should follow the PMA
        self.dec = nn.Sequential(*modules_dec)

    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        x = x.view(-1, self.num_outputs * self.dim_hidden)
        return x
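
# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original source). A minimal
# example with hypothetical tensor shapes: CommunicationBlock maps a token set
# to a token set, CrossAttentionBlock lets query tokens attend to a key/value
# set, and AggregationBlock pools a token set down to num_outputs vectors.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    tokens = torch.randn(2, 256, 384)   # (batch, num_tokens, feature_dim), hypothetical
    queries = torch.randn(2, 16, 384)

    comm = CommunicationBlock(dim_input=384, num_enc_sab=2, dim_hidden=384, use_efficient_attention=True)
    cross = CrossAttentionBlock(dim_input=384, num_enc_sab=2, dim_hidden=384)
    agg = AggregationBlock(dim_input=384, num_enc_sab=2, num_outputs=1, dim_hidden=384)

    print(comm(tokens).shape)            # torch.Size([2, 256, 384])
    print(cross(queries, tokens).shape)  # torch.Size([2, 16, 384])
    print(agg(tokens).shape)             # torch.Size([2, 384])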