Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| import math | |
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with per-head weight tensors.

    Queries, keys and values are produced by separate learned projections
    for every head; the per-head results are concatenated and mapped to
    ``embed_dim`` by a final output projection.

    Args:
        n_heads: Number of attention heads.
        input_dim: Size of the last dimension of the inputs to ``forward``.
        embed_dim: Size of the last dimension of the output. When omitted,
            it defaults to ``n_heads * val_dim``.
        val_dim: Per-head value size. When omitted, it is derived as
            ``embed_dim // n_heads``.
        key_dim: Per-head query/key size. Defaults to ``val_dim``.

    Raises:
        AssertionError: If neither ``embed_dim`` nor ``val_dim`` is given.
    """

    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
        super().__init__()

        if val_dim is None:
            assert embed_dim is not None, "Provide either embed_dim or val_dim"
            val_dim = embed_dim // n_heads
        if embed_dim is None:
            # Only val_dim was supplied: the output projection still needs a
            # concrete width, so fall back to the concatenated head size.
            embed_dim = n_heads * val_dim
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        # Scaling applied to the raw query-key dot products.
        self.norm_factor = 1 / math.sqrt(key_dim)

        # NOTE: the registration order below determines the order in which
        # init_parameters() draws random numbers; keep it stable so seeded
        # runs stay reproducible.
        self.W_query = nn.Parameter(torch.empty(n_heads, input_dim, key_dim))
        self.W_key = nn.Parameter(torch.empty(n_heads, input_dim, key_dim))
        self.W_val = nn.Parameter(torch.empty(n_heads, input_dim, val_dim))
        self.W_out = nn.Parameter(torch.empty(n_heads * val_dim, embed_dim))

        self.init_parameters()

    def init_parameters(self):
        """Fill every parameter uniformly in ``[-s, s]``, ``s = 1/sqrt(last dim)``."""
        with torch.no_grad():
            for param in self.parameters():
                stdv = 1.0 / math.sqrt(param.size(-1))
                param.uniform_(-stdv, stdv)

    def forward(self, q, h=None, mask=None):
        """Attend from the queries ``q`` over the elements of ``h``.

        Args:
            q: Query tensor of shape ``(batch, n_query, input_dim)``.
            h: Tensor of shape ``(batch, graph_size, input_dim)`` from which
                keys and values are computed. Defaults to ``q``
                (self-attention).
            mask: Optional tensor with ``batch * n_query * graph_size``
                elements (it is reshaped to
                ``(1, batch, n_query, graph_size)``). Non-zero / ``True``
                entries mark positions that must not be attended to. Any
                dtype is accepted; it is converted to bool.

        Returns:
            Tensor of shape ``(batch, n_query, embed_dim)``.
        """
        if h is None:
            h = q  # self-attention

        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)

        # Flatten the batch and sequence axes so a single broadcasted matmul
        # against the (n_heads, input_dim, dim) weights covers all heads.
        hflat = h.reshape(-1, input_dim)
        qflat = q.reshape(-1, input_dim)

        K = torch.matmul(hflat, self.W_key).view(
            self.n_heads, batch_size, graph_size, self.key_dim
        )
        V = torch.matmul(hflat, self.W_val).view(
            self.n_heads, batch_size, graph_size, self.val_dim
        )
        Q = torch.matmul(qflat, self.W_query).view(
            self.n_heads, batch_size, n_query, self.key_dim
        )

        # Scaled attention scores: (n_heads, batch, n_query, graph_size).
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        if mask is not None:
            # reshape() tolerates non-contiguous masks, and masked_fill needs
            # a boolean mask, so normalise the dtype here. The leading
            # singleton axis broadcasts over the heads.
            mask = mask.reshape(1, batch_size, n_query, graph_size).to(torch.bool)
            compatibility = compatibility.masked_fill(mask, -1e9)

        attn = F.softmax(compatibility, dim=-1)

        # Weighted sum of values: (n_heads, batch, n_query, val_dim).
        heads = torch.matmul(attn, V)

        # Move the head axis next to the feature axis and concatenate heads:
        # (batch, n_query, n_heads * val_dim), then project to embed_dim.
        heads = heads.permute(1, 2, 0, 3).contiguous().view(batch_size, n_query, -1)
        out = torch.matmul(heads, self.W_out)
        return out