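"""Monotonic attention variants with a fixed pre-decision step.

The encoder states are pooled with a fixed stride (``--fixed-pre-decision-ratio``)
before the wrapped monotonic attention computes its selection probabilities
(``p_choose``); the result is then upsampled back to the original source length.
"""
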
import math
from functools import partial
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from . import register_monotonic_attention
from .monotonic_multihead_attention import (
    MonotonicAttention,
    MonotonicInfiniteLookbackAttention,
    WaitKAttention,
)


def fixed_pooling_monotonic_attention(monotonic_attention):
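    """Class decorator factory that adds fixed-stride pooling to a monotonic attention.

    ``fixed_pooling_monotonic_attention(Base)`` returns a decorator; applying it to a
    placeholder class ``klass`` builds a subclass of ``Base`` that pools the encoder
    states with a fixed stride and renames the subclass to ``klass.__name__``.
    """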
    def create_model(monotonic_attention, klass):
        class FixedStrideMonotonicAttention(monotonic_attention):
            def __init__(self, args):
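                # Defaults set before the parent __init__, which may overwrite
                # them from args.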
                self.waitk_lagging = 0
                self.num_heads = 0
                self.noise_mean = 0.0
                self.noise_var = 0.0
                super().__init__(args)
                self.pre_decision_type = args.fixed_pre_decision_type
                self.pre_decision_ratio = args.fixed_pre_decision_ratio
                self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold
                assert self.pre_decision_ratio > 1

                if args.fixed_pre_decision_type == "average":
                    self.pooling_layer = torch.nn.AvgPool1d(
                        kernel_size=self.pre_decision_ratio,
                        stride=self.pre_decision_ratio,
                        ceil_mode=True,
                    )
                elif args.fixed_pre_decision_type == "last":

                    def last(key):
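                        # Keep the last frame of every pre_decision_ratio-sized
                        # block; append the final frame if the length is not
                        # divisible by the ratio.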
                        if key.size(2) < self.pre_decision_ratio:
                            return key
                        else:
                            k = key[
                                :,
                                :,
                                self.pre_decision_ratio - 1:: self.pre_decision_ratio,
                            ].contiguous()
                            if key.size(-1) % self.pre_decision_ratio != 0:
                                k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous()
                            return k

                    self.pooling_layer = last
                else:
                    raise NotImplementedError

            @staticmethod
            def add_args(parser):
                super(
                    FixedStrideMonotonicAttention, FixedStrideMonotonicAttention
                ).add_args(parser)
                parser.add_argument(
                    "--fixed-pre-decision-ratio",
                    type=int,
                    required=True,
                    help=(
                        "Ratio for the fixed pre-decision, indicating how many "
                        "encoder steps are grouped together before each step of "
                        "the simultaneous decision making process."
                    ),
                )
                parser.add_argument(
                    "--fixed-pre-decision-type",
                    default="average",
                    choices=["average", "last"],
                    help="Pooling type",
                )
                parser.add_argument(
                    "--fixed-pre-decision-pad-threshold",
                    type=float,
                    default=0.3,
                    help="If a part of the sequence has pad"
                    ",the threshold the pooled part is a pad.",
                )

            def insert_zeros(self, x):
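                # Upsample x along the source dimension by a factor of
                # pre_decision_ratio: each pooled value lands on the last
                # position of its block, all other positions are zero.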
                bsz_num_heads, tgt_len, src_len = x.size()
                stride = self.pre_decision_ratio
                weight = F.pad(torch.ones(1, 1, 1).to(x), (stride - 1, 0))
                x_upsample = F.conv_transpose1d(
                    x.view(-1, src_len).unsqueeze(1),
                    weight,
                    stride=stride,
                    padding=0,
                )
                return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1)

            def p_choose(
                self,
                query: Optional[Tensor],
                key: Optional[Tensor],
                key_padding_mask: Optional[Tensor] = None,
                incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
            ):
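                # Pool the keys (and the padding mask) with the fixed stride,
                # let the wrapped attention compute p_choose on the pooled
                # sequence, then upsample the result back to src_len.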
                assert key is not None
                assert query is not None
                src_len = key.size(0)
                tgt_len = query.size(0)
                batch_size = query.size(1)

                key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2)

                if key_padding_mask is not None:
                    key_padding_mask_pool = (
                        self.pooling_layer(key_padding_mask.unsqueeze(0).float())
                        .squeeze(0)
                        .gt(self.pre_decision_pad_threshold)
                    )
                    # Make sure at least one element is not pad
                    key_padding_mask_pool[:, 0] = 0
                else:
                    key_padding_mask_pool = None

                if incremental_state is not None:
                    # The floor instead of ceil is used for inference,
                    # but make sure key_pool has a length of at least 1.
                    if (
                        max(1, math.floor(key.size(0) / self.pre_decision_ratio))
                    ) < key_pool.size(0):
                        key_pool = key_pool[:-1]
                        if key_padding_mask_pool is not None:
                            key_padding_mask_pool = key_padding_mask_pool[:-1]

                p_choose_pooled = self.p_choose_from_qk(
                    query,
                    key_pool,
                    key_padding_mask_pool,
                    incremental_state=incremental_state,
                )

                # Upsample back to the original source length, interleaving zeros
                p_choose = self.insert_zeros(p_choose_pooled)

                if p_choose.size(-1) < src_len:
                    # Append zeros if the upsampled p_choose is shorter than src_len
                    p_choose = torch.cat(
                        [
                            p_choose,
                            torch.zeros(
                                p_choose.size(0),
                                tgt_len,
                                src_len - p_choose.size(-1)
                            ).to(p_choose)
                        ],
                        dim=2
                    )
                else:
                    # Can be larger than src_len because ceil_mode was used when pooling
                    p_choose = p_choose[:, :, :src_len]
                    p_choose[:, :, -1] = p_choose_pooled[:, :, -1]

                assert list(p_choose.size()) == [
                    batch_size * self.num_heads,
                    tgt_len,
                    src_len,
                ]

                return p_choose

        FixedStrideMonotonicAttention.__name__ = klass.__name__
        return FixedStrideMonotonicAttention

    return partial(create_model, monotonic_attention)


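# Each placeholder class below only supplies a name: the
# fixed_pooling_monotonic_attention decorator replaces it with a subclass of the
# given monotonic attention variant, and register_monotonic_attention registers
# the result under the quoted key.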
@register_monotonic_attention("waitk_fixed_pre_decision")
@fixed_pooling_monotonic_attention(WaitKAttention)
class WaitKAttentionFixedStride:
    pass


@register_monotonic_attention("hard_aligned_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicAttention)
class MonotonicAttentionFixedStride:
    pass


@register_monotonic_attention("infinite_lookback_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicInfiniteLookbackAttention)
class MonotonicInfiniteLookbackAttentionFixedStride:
    pass