# homemade_lo_vi/modules/multi_head_attention.py
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from modules.wrapper import Linear
from modules.dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (section 3.2.2 of "Attention Is All You Need").

    Args:
        d_model (int): dimension of the model
        num_heads (int): number of attention heads
        dropout_p (float): dropout probability

    Inputs:
        query (batch, seq_len, d_model): input to the query projection
        key (batch, seq_len, d_model): input to the key projection
        value (batch, seq_len, d_model): input to the value projection
        mask (optional): attention mask, broadcastable to
            (batch, num_heads, seq_len, seq_len)

    Output: (Tensor, Tensor):
        context (batch, seq_len, d_model): attended values
        attn (batch, num_heads, seq_len, seq_len): attention matrix for visualization
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout_p: float,
    ) -> None:
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_query = Linear(d_model, d_model)
        self.W_key = Linear(d_model, d_model)
        self.W_value = Linear(d_model, d_model)
        # self.W_output = Linear(d_model, d_model)
        # The paper scales scores by sqrt(d_k); after the heads are split,
        # each query/key vector has dimension d_head (d_k), not d_model.
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head, dropout_p)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = query.shape[0]
        # original: (batch, seq_len, d_model)
        # --projection--> (batch, seq_len, d_model)
        # --view--> (batch, seq_len, num_heads, d_head)
        # --transpose--> (batch, num_heads, seq_len, d_head)
        query = self.W_query(query).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        key = self.W_key(key).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        value = self.W_value(value).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        context, attn = self.scaled_dot_attn(query, key, value, mask)
        # (batch, num_heads, seq_len, d_head)
        # --transpose--> (batch, seq_len, num_heads, d_head)
        # --view--> (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # context = self.W_output(context)
        return context, attn
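

# A minimal smoke test sketch, assuming the companion modules behave as their
# names suggest: modules.wrapper.Linear acts like nn.Linear, and
# ScaledDotProductAttention returns a (context, attn) pair with attn of shape
# (batch, num_heads, seq_len, seq_len). The shapes and hyperparameters below
# are illustrative, not part of the original file.
if __name__ == "__main__":
    batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=0.1)
    x = torch.randn(batch_size, seq_len, d_model)

    # Self-attention: query, key, and value are the same tensor.
    context, attn = mha(x, x, x)
    assert context.shape == (batch_size, seq_len, d_model)
    assert attn.shape == (batch_size, num_heads, seq_len, seq_len)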