from inspect import isfunction
import math

import torch
from torch import nn, einsum
import torch.nn.functional as F

from .blocks import get_norm, zero_module


def QKV_Attention(qkv, num_heads):
    """
    Apply multi-head scaled dot-product attention to a packed QKV tensor.

    The channel axis is laid out as [q | k | v] (split with ``chunk(3)``);
    each of the three chunks is then viewed as ``num_heads`` heads of
    ``C // (3 * num_heads)`` channels.

    :param qkv: an [N x (3 * C) x T] tensor of Qs, Ks, and Vs.
    :param num_heads: number of attention heads; must be positive and divide C.
    :return: an [N x C x T] tensor after attention.
    :raises ValueError: if the channel count is not divisible by 3, or the
        per-q/k/v channel count is not divisible by ``num_heads``.
    """
    B, C, HW = qkv.shape
    if C % 3 != 0:
        raise ValueError('QKV shape is wrong: {}, {}, {}'.format(B, C, HW))
    # Validate the head split up front instead of failing inside .view().
    if num_heads <= 0 or C % (3 * num_heads) != 0:
        raise ValueError(
            'QKV channels ({}) must be divisible by 3 * num_heads ({})'.format(C, 3 * num_heads))
    split_size = C // (3 * num_heads)
    q, k, v = qkv.chunk(3, dim=1)
    # Scale q and k each by d**-0.25 so that their dot product is scaled by 1/sqrt(d).
    scale = 1.0 / math.sqrt(math.sqrt(split_size))
    weight = torch.einsum(
        'bct,bcs->bts',
        (q * scale).view(B * num_heads, split_size, HW),
        (k * scale).view(B * num_heads, split_size, HW),
    )
    # Softmax in float32 (presumably for stability under half precision), then cast back.
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    ret = torch.einsum('bts,bcs->bct', weight, v.reshape(B * num_heads, split_size, HW))
    return ret.reshape(B, -1, HW)


class AttentionBlock(nn.Module):
    """
    Residual multi-head self-attention over flattened spatial positions.

    An input of shape (B, C, *spatial) is flattened to (B, C, prod(spatial)),
    normalized, projected to packed q/k/v with a 1x1 conv, attended over,
    projected back with a 1x1 conv, and added to the un-normalized input.

    :param in_channels: number of channels C; must be divisible by num_heads.
    :param num_heads: number of attention heads.
    :param qkv_bias: currently unused (the qkv conv is always built with a bias).
    :param sr_ratio: currently unused.
    :param linear: currently unused.
    """

    def __init__(self, in_channels, num_heads, qkv_bias=False, sr_ratio=1, linear=True):
        super().__init__()
        self.num_heads = num_heads
        # NOTE(review): get_norm comes from .blocks; presumably builds a GroupNorm — confirm.
        self.norm = get_norm(in_channels, 'Group')
        self.qkv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels * 3, kernel_size=1)
        # NOTE(review): zero_module presumably zero-initializes this conv so the block
        # starts out as an identity mapping — confirm in .blocks.
        self.proj = zero_module(nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=1))

    def forward(self, x):
        """
        :param x: a [B x C x ...] tensor.
        :return: a tensor with the same shape as ``x``.
        """
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)  # B x C x HW
        # Normalize only the attention branch so the skip connection carries the raw input.
        qkv = self.qkv(self.norm(x))  # B x C x HW -> B x 3C x HW
        h = QKV_Attention(qkv, self.num_heads)
        h = self.proj(h)
        return (x + h).reshape(b, c, *spatial)  # additive (residual) attention
def get_model_size(model):
    """
    Report the memory footprint of a module's parameters and buffers.

    The size is printed as ``model size: <N>MB`` (three decimals) and also
    returned.

    :param model: a torch.nn.Module.
    :return: combined size of all parameters and buffers, in MiB (float).
    """
    def _byte_count(tensors):
        # Bytes occupied by each tensor = element count * bytes per element.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _byte_count(model.parameters()) + _byte_count(model.buffers())
    size_in_mb = total_bytes / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_in_mb))
    return size_in_mb


if __name__ == '__main__':
    block = AttentionBlock(in_channels=256, num_heads=8)
    inp = torch.randn(5, 256, 32, 32, dtype=torch.float32)
    out = block(inp)
    print('{}, {}'.format(inp.shape, out.shape))
    get_model_size(block)