import math
from typing import Optional, Union

import torch as T
from torch import Tensor, nn
from torch.nn import functional as F

import opt_einsum as oe

einsum = oe.contract


def masked_softmax(xs: Tensor, mask: Tensor, dim: int = -1) -> Tensor:
    # Softmax along `dim`; positions where `mask` is False receive ~zero weight.
    xs = xs.masked_fill(~mask, -1e9)
    return F.softmax(xs, dim=dim)
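
# Usage sketch for masked_softmax (illustrative values, not from the original
# code): masked positions get ~zero probability.
#   >>> xs = T.tensor([[1.0, 2.0, 3.0]])
#   >>> mask = T.tensor([[True, True, False]])
#   >>> masked_softmax(xs, mask)
#   tensor([[0.2689, 0.7311, 0.0000]])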

class Attention(nn.Module):
    def __init__(
            self,
            kind: str,
            query_dim: int,
            input_dim: int,
            output_dim: Optional[int] = None,
            activation: str = 'auto',
            scaled: Union[bool, str] = True,  # True/False, or 'seqlen' for sequence-length scaling (dot kind only)
    ):
        super().__init__()
        assert kind in [
            'dot',
            'linear',
        ]

        self.kind = kind
        self.Dq = query_dim
        self.Din = input_dim
        self.Dout = output_dim or self.Din
        self.activation = activation
        self.scaled = scaled

        self.Wq_ = nn.Linear(self.Dq, self.Din)
        self.Wk_ = nn.Linear(self.Din, self.Din)
        self.Wv_ = nn.Linear(self.Din, self.Dout)
        self.Wz_ = nn.Linear(self.Dout, self.Dout)  # attention output has feature dim Dout

    def forward(
            self,
            query: Tensor,
            data: Tensor,
            content_mask: Optional[Tensor] = None,
            prejudice_mask: Optional[Tensor] = None,
    ):
        #^ query: [b, ts, tw, dq]
        #^ data: [b, ts, di]
        #^ content_mask: [b, ts, tw]
        #^ prejudice_mask: [b, ts, ts]
        #^ => output: [b, ts, tw, dz]

        dimB, dimS, _dimW, _dimQ = query.shape

        # TODO: Optimize out the [ts, ts, *] intermediate
        qs = self.Wq_(query)
        ks = self.Wk_(data)
        vs = self.Wv_(data)

        if content_mask is not None:
            words_mask = content_mask.any(2)
            #^ words_mask : [b, ts]
        else:
            words_mask = qs.new_ones((dimB, dimS))

        if self.kind == 'linear':
            # Ref: https://twitter.com/francoisfleuret/status/1267455240007188486
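            # Linear attention (kernelized with phi = ReLU) computes
            #   z = phi(q) (phi(k)^T v) / (phi(q) . sum_k phi(k)),
            # so the full [ts, ts] attention matrix is never formed and memory
            # grows linearly rather than quadratically in sequence length.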
            assert prejudice_mask is None, "Linear mode does not support prejudice_mask."
            assert content_mask is not None, "Linear mode requires a content_mask."
            qs = T.relu(qs) * content_mask.unsqueeze(3)
            #^ qs: [bswi]
            ks = T.relu(ks) * words_mask.unsqueeze(2)
            #^ ks: [bsi]
            vks = einsum("bsi, bsz -> bzi", ks, vs)
            #^ vks : [b, dz, di]
            zs = einsum("bswi, bzi -> bswz", qs, vks)
            #^ zs : [b, ts, tw, dz]
            if self.scaled:
                ks = ks.sum(1)
                #^ ks: [bi]
                denom = einsum("bswi, bi -> bsw", qs, ks) + 1e-9
                #^ denom : [b, ts, tw]
                zs = zs / denom.unsqueeze(3)

            att_map = None  # no explicit attention map in linear mode

        elif self.kind == 'dot':
            # Ref: https://arxiv.org/abs/1706.03762
            # s=ts in q
            # S=ts in ks,vs
            att_map = einsum("bqwi, bki -> bqkw", qs, ks)
            #^ [b, ts:q, ts:k, tw]
            if self.scaled == 'seqlen':
                att_map_ndim = len(att_map.shape) - 1
                norm_coeff = words_mask.sum(1).view(-1, *([1] * att_map_ndim))
                #^ [b, _, _, _]
                att_map = att_map / T.sqrt(norm_coeff.float())
            elif self.scaled:
                att_map = att_map / math.sqrt(self.Din)

            if content_mask is None and prejudice_mask is None:
                att_map = F.softmax(att_map, dim=2)
            else:
                if content_mask is None:
                    assert prejudice_mask is not None # !for mypy
                    qk_mask = prejudice_mask.unsqueeze(3)
                    #^ qk_mask : [b, ts:q, ts:k, tw^]
                elif prejudice_mask is None:
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
                    #^ qk_mask : [b, ts:q, ts:k^, tw]
                else:
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3)
                    # qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
                    qk_mask = qk_mask * prejudice_mask.unsqueeze(3)
                    #^ qk_mask : [b, ts:q^, ts:k, tw]

                att_map = masked_softmax(att_map, qk_mask.bool(), dim=2)

            #^ att_map : [b, ts:q, ts:k, tw]
            zs = einsum("bqkw, bkz -> bqwz", att_map, vs)

        zs = self.Wz_(zs)
        return zs, att_map
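

if __name__ == "__main__":
    # Minimal usage sketch (illustrative shapes and values only), following the
    # shape comments in Attention.forward():
    #   query [b, ts, tw, dq], data [b, ts, di], content_mask [b, ts, tw].
    b, ts, tw, dq, di, dz = 2, 5, 3, 8, 16, 16

    query = T.randn(b, ts, tw, dq)
    data = T.randn(b, ts, di)
    content_mask = T.ones(b, ts, tw, dtype=T.bool)

    dot_att = Attention('dot', query_dim=dq, input_dim=di, output_dim=dz)
    zs, att_map = dot_att(query, data, content_mask=content_mask)
    print(zs.shape)       # torch.Size([2, 5, 3, 16])
    print(att_map.shape)  # torch.Size([2, 5, 5, 3])

    lin_att = Attention('linear', query_dim=dq, input_dim=di, output_dim=dz)
    zs, att_map = lin_att(query, data, content_mask=content_mask)
    print(zs.shape)       # torch.Size([2, 5, 3, 16])
    print(att_map)        # None: linear mode never builds the attention map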