da_en_translation / transformer_mt /modeling_attention.py
ftakelait
Add application files
b1c0f8d
raw
history blame
No virus
5.6 kB
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 Vladislav Lialin and Namrata Shivagunde
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, hidden, num_heads, causal=False):
        """Multi-head attention module which computes [softmax(xQ_h @ xK_h^T / sqrt(d_h)) @ xV_h : ...] @ U

        Can work as both self-attention and cross-attention (if kv is provided to .forward).

        Args:
            input_size: size of the last dimension of the tensors passed to .forward
            hidden: total size of the projected queries/keys/values and of the output;
                must be divisible by num_heads
            num_heads: number of attention heads; each head has size hidden // num_heads
            causal: use causal masking: query position i may attend only to key positions j <= i
                (future positions are masked, the current position is still visible)

        Raises:
            ValueError: if hidden is not divisible by num_heads
        """
        if hidden % num_heads:
            raise ValueError(f"hidden should be divisible by num_heads, "
                             f"but got hidden={hidden} and num_heads={num_heads}")
        super().__init__()
        self.k = nn.Linear(input_size, hidden)
        self.q = nn.Linear(input_size, hidden)
        self.v = nn.Linear(input_size, hidden)
        self.mix = nn.Linear(hidden, hidden)  # output projection U applied after concatenating heads

        self.num_heads = num_heads
        self.head_size = hidden // num_heads
        self.scale = self.head_size ** 0.5  # scores are divided by sqrt(head_size)
        self.causal = causal  # causal masking

    def _split_heads(self, x):
        """Split [batch, seq, hidden] into [batch * num_heads, seq, head_size].

        Rows are batch-major: row index = batch_index * num_heads + head_index, and head h
        owns hidden features [h * head_size, (h + 1) * head_size).
        """
        bs = x.shape[0]
        # [b, s, h] -> [b, h, s] -> [b * heads, h / heads, s] -> [b * heads, s, h / heads]
        return x.transpose(1, 2).reshape(bs * self.num_heads, self.head_size, -1).transpose(1, 2).contiguous()

    def _merge_heads(self, x, bs, seq_len):
        """Inverse of _split_heads: [batch * num_heads, seq, head_size] -> [batch, seq, hidden]."""
        # [b * heads, s, h / heads] -> [b * heads, h / heads, s] -> [b, h, s] -> [b, s, h]
        return x.transpose(1, 2).reshape(bs, -1, seq_len).transpose(1, 2).contiguous()

    def forward(self, q, kv=None, key_padding_mask=None, return_attention=False):
        """[Softmax(source Q_1 @ target K_1^T) @ target V_1 : ... ) @ x V_heads] @ U

        Performs self-attention if kv is not specified.
        In this case, kv = q and kv_seq_len = query_seq_len.

        Args:
            q: FloatTensor[batch_size, query_seq_len, input_size]
            kv (target): optional, FloatTensor[batch_size, kv_seq_len, input_size]
            key_padding_mask: BoolTensor[batch_size, kv_seq_len] 0 means unpadded, 1 means padded
            return_attention: if True, also return the attention probabilities

        Returns:
            FloatTensor[batch_size, query_seq_len, hidden]; if return_attention is True, a tuple
            (output, probs) where probs is FloatTensor[batch_size * num_heads, query_seq_len, kv_seq_len]
            with rows ordered batch-major (row = batch_index * num_heads + head_index).

        Note: if every key of a sequence is padded (or masked), its softmax row is all -inf
        and the resulting probabilities are NaN.
        """
        # Keys/values come from kv for cross-attention and from the query input for self-attention.
        # They must be projected from the *raw* input, before q is rebound to its projection.
        source = q if kv is None else kv  # [batch, kv_seq, input_size]
        k = self.k(source)  # [batch, kv_seq, hidden]
        v = self.v(source)  # [batch, kv_seq, hidden]
        q = self.q(q)  # [batch, q_seq, hidden]

        bs, attending_seq, _ = q.shape
        attended_seq = k.shape[1]

        k = self._split_heads(k)  # [batch * num_heads, kv_seq, head_size]
        q = self._split_heads(q)  # [batch * num_heads, q_seq, head_size]
        v = self._split_heads(v)  # [batch * num_heads, kv_seq, head_size]

        scores = q @ k.transpose(1, 2) / self.scale  # [batch * num_heads, attending_seq, attended_seq]
        assert scores.shape == (bs * self.num_heads, attending_seq, attended_seq)

        if key_padding_mask is not None:
            # Rows of `scores` are batch-major (b * num_heads + head), so view them as
            # [batch, heads, q_seq, kv_seq] to line each sequence's mask up with all of its heads.
            padded = (key_padding_mask == 1)[:, None, None, :]  # [batch, 1, 1, kv_seq]
            scores = scores.reshape(bs, self.num_heads, attending_seq, attended_seq)
            scores = scores.masked_fill(padded, float("-inf"))
            scores = scores.reshape(bs * self.num_heads, attending_seq, attended_seq)

        assert scores.size() == (bs * self.num_heads, attending_seq, attended_seq),\
            f"scores have wrong shape. Expected {(bs * self.num_heads, attending_seq, attended_seq)}, got {scores.size()}"

        if self.causal:
            # Mask strictly-future key positions (above the main diagonal); position i still sees itself.
            causal_mask = torch.triu(
                torch.ones(attending_seq, attended_seq, dtype=torch.bool, device=scores.device), diagonal=1
            )
            scores.masked_fill_(causal_mask.unsqueeze(0), float("-inf"))

        probs = torch.softmax(scores, dim=-1)  # [batch * num_heads, q_seq, kv_seq]
        att = probs @ v  # [batch * num_heads, q_seq, head_size]

        att = self._merge_heads(att, bs, attending_seq)  # [batch, q_seq, hidden]
        att = self.mix(att)

        if return_attention:
            return att, probs
        return att