#!/usr/bin/env python3
# coding=utf-8
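"""Transformer decoder: stacked self-attention + feed-forward layers (cross-attention
is currently disabled) with optional pre-norm and gradient checkpointing."""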

import torch
import torch.nn as nn
import torch.utils.checkpoint  # explicit import; not guaranteed to be exposed by `import torch` alone


def checkpoint(module, *args, **kwargs):
    # The dummy tensor with requires_grad=True guarantees that the checkpointed
    # segment has at least one differentiable input, so recomputation is triggered
    # in the backward pass even when none of the real inputs require grad.
    dummy = torch.empty(1, requires_grad=True)
    return torch.utils.checkpoint.checkpoint(lambda d, *a, **k: module(*a, **k), dummy, *args, **kwargs)


class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attention = nn.MultiheadAttention(args.hidden_size, args.n_attention_heads, args.dropout_transformer_attention)
        self.dropout = nn.Dropout(args.dropout_transformer)

    def forward(self, q_input, kv_input, mask=None):
        # `mask` is passed as a key padding mask (positions marked True are ignored).
        output, _ = self.attention(q_input, kv_input, kv_input, key_padding_mask=mask, need_weights=False)
        output = self.dropout(output)
        return output


class FeedForward(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size_ff),
            self._get_activation_f(args.activation),
            nn.Dropout(args.dropout_transformer),
            nn.Linear(args.hidden_size_ff, args.hidden_size),
            nn.Dropout(args.dropout_transformer),
        )

    def forward(self, x):
        return self.f(x)

    def _get_activation_f(self, activation: str):
        return {"relu": nn.ReLU, "gelu": nn.GELU}[activation]()


class DecoderLayer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.self_f = Attention(args)
        #self.cross_f = Attention(args)
        self.feedforward_f = FeedForward(args)

        self.pre_self_norm = nn.LayerNorm(args.hidden_size) if args.pre_norm else nn.Identity()
        #self.pre_cross_norm = nn.LayerNorm(args.hidden_size) if args.pre_norm else nn.Identity()
        self.pre_feedforward_norm = nn.LayerNorm(args.hidden_size) if args.pre_norm else nn.Identity()
        self.post_self_norm = nn.Identity() if args.pre_norm else nn.LayerNorm(args.hidden_size)
        #self.post_cross_norm = nn.Identity() if args.pre_norm else nn.LayerNorm(args.hidden_size)
        self.post_feedforward_norm = nn.Identity() if args.pre_norm else nn.LayerNorm(args.hidden_size)

    def forward(self, x, encoder_output, x_mask, encoder_mask):
        x_ = self.pre_self_norm(x)
        x = self.post_self_norm(x + self.self_f(x_, x_, x_mask))

        #x_ = self.pre_cross_norm(x)
        #x = self.post_cross_norm(x + self.cross_f(x_, encoder_output, encoder_mask))

        x_ = self.pre_feedforward_norm(x)
        x = self.post_feedforward_norm(x + self.feedforward_f(x_))

        return x


class Decoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layers)])

    def forward(self, target, encoder, target_mask, encoder_mask):
        target = target.transpose(0, 1)  # (B, T, D) -> (T, B, D); nn.MultiheadAttention expects sequence-first
        encoder = encoder.transpose(0, 1)  # (B, S, D) -> (S, B, D)

        for layer in self.layers[:-1]:
            target = checkpoint(layer, target, encoder, target_mask, encoder_mask)
        target = self.layers[-1](target, encoder, target_mask, encoder_mask)  # don't checkpoint due to grad_norm
        target = target.transpose(0, 1)  # shape: (B, T, D)

        return target
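

# Usage sketch (illustrative only): the hyperparameter names below mirror the
# `args` attributes read by the modules above, but the concrete values and the
# SimpleNamespace container are assumptions for demonstration purposes.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        hidden_size=64,
        hidden_size_ff=256,
        n_attention_heads=4,
        n_layers=2,
        dropout_transformer=0.1,
        dropout_transformer_attention=0.1,
        activation="gelu",
        pre_norm=True,
    )
    decoder = Decoder(args)

    batch_size, target_len, source_len = 2, 5, 7
    target = torch.randn(batch_size, target_len, args.hidden_size)
    encoder_output = torch.randn(batch_size, source_len, args.hidden_size)
    target_mask = torch.zeros(batch_size, target_len, dtype=torch.bool)  # True = padded position
    encoder_mask = torch.zeros(batch_size, source_len, dtype=torch.bool)

    out = decoder(target, encoder_output, target_mask, encoder_mask)
    print(out.shape)  # torch.Size([2, 5, 64])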