#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 Vladislav Lialin and Namrata Shivagunde
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, hidden, num_heads, causal=False):
        """Multi-head attention: [softmax(xQ_h @ xK_h^T / scale) @ xV_h : ...] @ U.

        Works as self-attention, or as cross-attention when ``kv`` is passed
        to ``forward``.

        Args:
            input_size: size of the last dimension of the inputs.
            hidden: total size of the projected queries/keys/values
                (all heads together); must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            causal: if True, a query at position i may only attend to keys at
                positions <= i (strictly future positions are masked out).

        Raises:
            ValueError: if ``hidden`` is not divisible by ``num_heads``.
        """
        if hidden % num_heads:
            raise ValueError(f"hidden should be divisible by num_heads, "
                             f"but got hidden={hidden} and num_heads={num_heads}")
        super().__init__()

        self.k = nn.Linear(input_size, hidden)
        self.q = nn.Linear(input_size, hidden)
        self.v = nn.Linear(input_size, hidden)
        self.mix = nn.Linear(hidden, hidden)

        self.num_heads = num_heads
        self.head_size = hidden // num_heads
        self.scale = self.head_size ** 0.5
        self.causal = causal  # causal masking

    def _split_heads(self, x):
        """[batch, seq, hidden] -> [batch * num_heads, seq, head_size] (batch-major)."""
        bs, seq_len, _ = x.shape
        x = x.reshape(bs, seq_len, self.num_heads, self.head_size)
        return x.permute(0, 2, 1, 3).reshape(bs * self.num_heads, seq_len, self.head_size)

    def _merge_heads(self, x, bs):
        """[batch * num_heads, seq, head_size] -> [batch, seq, hidden]."""
        seq_len = x.shape[1]
        x = x.reshape(bs, self.num_heads, seq_len, self.head_size)
        return x.permute(0, 2, 1, 3).reshape(bs, seq_len, self.num_heads * self.head_size)

    def forward(self, q, kv=None, key_padding_mask=None, return_attention=False):
        """Compute [softmax(Q_h @ K_h^T / scale) @ V_h : ...] @ U.

        Performs self-attention if ``kv`` is not specified; in that case
        kv = q and kv_seq_len = query_seq_len.

        Args:
            q: FloatTensor[batch_size, query_seq_len, input_size]
            kv: optional, FloatTensor[batch_size, kv_seq_len, input_size]
            key_padding_mask: optional, BoolTensor[batch_size, kv_seq_len];
                0 means unpadded, 1 means padded.
            return_attention: if True, also return the attention probabilities.

        Returns:
            FloatTensor[batch_size, query_seq_len, hidden], and additionally
            FloatTensor[batch_size * num_heads, query_seq_len, kv_seq_len]
            (batch-major over heads) when ``return_attention`` is True.

        Raises:
            ValueError: if ``key_padding_mask`` is not [batch_size, kv_seq_len].

        NOTE: if every key visible to a query is masked, the softmax row is
        all -inf and the output for that query is NaN.
        """
        # Keys and values always come from the *unprojected* source sequence.
        # The source must be chosen before q is overwritten by its projection.
        source = q if kv is None else kv

        k = self.k(source)  # [batch, kv_seq, hidden]
        v = self.v(source)  # [batch, kv_seq, hidden]
        q = self.q(q)       # [batch, query_seq, hidden]

        bs, attending_seq, _ = q.shape
        attended_seq = k.shape[1]

        k = self._split_heads(k)  # [batch * heads, kv_seq, head_size]
        q = self._split_heads(q)  # [batch * heads, query_seq, head_size]
        v = self._split_heads(v)  # [batch * heads, kv_seq, head_size]

        scores = q @ k.transpose(1, 2) / self.scale  # [batch * heads, query_seq, kv_seq]
        assert scores.shape == (bs * self.num_heads, attending_seq, attended_seq)

        if key_padding_mask is not None:
            if tuple(key_padding_mask.shape) != (bs, attended_seq):
                raise ValueError(
                    f"key_padding_mask should have shape {(bs, attended_seq)}, "
                    f"got {tuple(key_padding_mask.shape)}"
                )
            # The flattened first dimension is batch-major (index = b * heads + head),
            # so it must be unflattened as [batch, heads], not [heads, batch];
            # otherwise a sample's mask is applied to another sample's heads.
            pad = key_padding_mask.to(device=scores.device).bool()[:, None, None, :]  # [batch, 1, 1, kv_seq]
            scores = scores.reshape(bs, self.num_heads, attending_seq, attended_seq)
            scores = scores.masked_fill(pad, float("-inf"))
            scores = scores.reshape(bs * self.num_heads, attending_seq, attended_seq)

        assert scores.size() == (bs * self.num_heads, attending_seq, attended_seq),\
            f"scores have wrong shape. Expected {(bs * self.num_heads, attending_seq, attended_seq)}, got {scores.size()}"

        if self.causal:
            # True strictly above the diagonal: key position j > query position i.
            causal_mask = torch.triu(
                torch.ones(attending_seq, attended_seq, dtype=torch.bool, device=scores.device),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask.unsqueeze(0), float("-inf"))

        probs = torch.softmax(scores, dim=-1)  # [batch * heads, query_seq, kv_seq]
        att = probs @ v  # [batch * heads, query_seq, head_size]

        att = self._merge_heads(att, bs)  # [batch, query_seq, hidden]
        att = self.mix(att)

        if return_attention:
            return att, probs

        return att