from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear
from modules.dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (section 3.2.2)

    Args:
        - d_model (int): dimension of model
        - num_heads (int): number of heads
        - dropout_p (float): probability of dropout

    Inputs:
        - query (batch, seq_len, d_model):
        - key   (batch, seq_len, d_model):
        - value (batch, seq_len, d_model):
        - mask ():

    Output: (Tensor, Tensor):
        - context ()
        - attn (): Attention matrix for visualization.
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout_p: float,
    ) -> None:
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        
        self.W_query  = Linear(d_model, d_model)
        self.W_key    = Linear(d_model, d_model)
        self.W_value  = Linear(d_model, d_model)
        # self.W_output = Linear(d_model, d_model)

        self.scaled_dot_attn = ScaledDotProductAttention(d_model, dropout_p)


    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = query.shape[0]

        # original: (batch, seq_len, d_model) 
        # --forward--> (batch, seq_len, d_model) 
        # --view--> (batch, seq_len, num_heads, d_head) 
        # --transpose--> (batch, num_heads, seq_len, d_head)
        query = self.W_query(query).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        key = self.W_key(key).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        value = self.W_value(value).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        
        context, attn = self.scaled_dot_attn(query, key, value, mask)
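        # attn is assumed to hold the per-head attention weights,
        # shaped (batch, num_heads, seq_len, seq_len)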

        # (batch, num_heads, seq_len, d_head)
        # --transpose--> (batch, seq_len, num_heads, d_head)
        # --view--> (batch, seq_len, d_model)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # context = self.W_output(context)

        return context, attn
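

if __name__ == "__main__":
    # Minimal shape-check sketch, assuming modules.wrapper.Linear and
    # modules.dot_product_attention.ScaledDotProductAttention behave like
    # nn.Linear and standard scaled dot-product attention (returning both
    # the context and the attention weights), as their use above suggests.
    batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=0.1)

    x = torch.randn(batch_size, seq_len, d_model)
    context, attn = mha(x, x, x)  # self-attention: query == key == value

    print(context.shape)  # expected: (batch_size, seq_len, d_model)
    print(attn.shape)     # expected: (batch_size, num_heads, seq_len, seq_len)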