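"""Monotonic attention with a fixed pre-decision (summary inferred from the
code below): encoder states are first pooled with a fixed stride
(``--fixed-pre-decision-ratio``), the monotonic read/write probabilities are
computed on the pooled sequence, and the result is upsampled back to the
original source length.
"""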
from functools import partial

import math
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from . import register_monotonic_attention
from .monotonic_multihead_attention import (
    MonotonicAttention,
    MonotonicInfiniteLookbackAttention,
    WaitKAttention,
)


def fixed_pooling_monotonic_attention(monotonic_attention):
    def create_model(monotonic_attention, klass):
        class FixedStrideMonotonicAttention(monotonic_attention):
            def __init__(self, args):
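                # Placeholder defaults set before the parent constructor runs;
                # the wrapped attention class may overwrite them in its own
                # __init__ (e.g. a wait-k variant setting waitk_lagging).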
                self.waitk_lagging = 0
                self.num_heads = 0
                self.noise_mean = 0.0
                self.noise_var = 0.0
                super().__init__(args)
                self.pre_decision_type = args.fixed_pre_decision_type
                self.pre_decision_ratio = args.fixed_pre_decision_ratio
                self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold
                assert self.pre_decision_ratio > 1
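
                # Downsample encoder states along time with a fixed stride, so
                # read/write decisions are made once every pre_decision_ratio
                # encoder steps rather than at every step.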
                if args.fixed_pre_decision_type == "average":
                    self.pooling_layer = torch.nn.AvgPool1d(
                        kernel_size=self.pre_decision_ratio,
                        stride=self.pre_decision_ratio,
                        ceil_mode=True,
                    )
                elif args.fixed_pre_decision_type == "last":
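
                    # "last" pooling: keep the final encoder state of each
                    # stride window; if the length is not divisible by the
                    # stride, append the very last state for the tail window.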
                    def last(key):
                        if key.size(2) < self.pre_decision_ratio:
                            return key
                        else:
                            k = key[
                                :,
                                :,
                                self.pre_decision_ratio - 1 :: self.pre_decision_ratio,
                            ].contiguous()
                            if key.size(-1) % self.pre_decision_ratio != 0:
                                k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous()
                            return k

                    self.pooling_layer = last
                else:
                    raise NotImplementedError

            @staticmethod
            def add_args(parser):
                super(
                    FixedStrideMonotonicAttention, FixedStrideMonotonicAttention
                ).add_args(parser)
                parser.add_argument(
                    "--fixed-pre-decision-ratio",
                    type=int,
                    required=True,
                    help=(
                        "Ratio for the fixed pre-decision: how many encoder "
                        "steps trigger one simultaneous decision-making step."
                    ),
                )
                parser.add_argument(
                    "--fixed-pre-decision-type",
                    default="average",
                    choices=["average", "last"],
                    help="Pooling type",
                )
                parser.add_argument(
                    "--fixed-pre-decision-pad-threshold",
                    type=float,
                    default=0.3,
                    help="If a pooled window contains more than this fraction "
                    "of padding, treat the whole pooled step as padding.",
                )

            def insert_zeros(self, x):
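                # Upsample pooled attention back to the original source length
                # by interleaving zeros: the transposed-convolution kernel is
                # all zeros except a trailing 1, so each pooled value lands at
                # the last frame of its stride window. E.g. with stride 3 the
                # row [a, b] becomes [0, 0, a, 0, 0, b].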
                bsz_num_heads, tgt_len, src_len = x.size()
                stride = self.pre_decision_ratio
                weight = F.pad(torch.ones(1, 1, 1).to(x), (stride - 1, 0))
                x_upsample = F.conv_transpose1d(
                    x.view(-1, src_len).unsqueeze(1),
                    weight,
                    stride=stride,
                    padding=0,
                )
                return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1)

            def p_choose(
                self,
                query: Optional[Tensor],
                key: Optional[Tensor],
                key_padding_mask: Optional[Tensor] = None,
                incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
            ):
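                # Pipeline: (1) pool keys (and the padding mask) with the
                # fixed stride, (2) compute selection probabilities on the
                # pooled sequence via p_choose_from_qk, (3) upsample back to
                # src_len so downstream code sees per-frame probabilities.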
                assert key is not None
                assert query is not None
                src_len = key.size(0)
                tgt_len = query.size(0)
                batch_size = query.size(1)
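
                # key is (src_len, batch, dim); move time to the last axis,
                # where the pooling layer operates, then restore the layout.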
                key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2)

                if key_padding_mask is not None:
                    key_padding_mask_pool = (
                        self.pooling_layer(key_padding_mask.unsqueeze(0).float())
                        .squeeze(0)
                        .gt(self.pre_decision_pad_threshold)
                    )
                    # Make sure at least one element is not pad
                    key_padding_mask_pool[:, 0] = 0
                else:
                    key_padding_mask_pool = None

                if incremental_state is not None:
                    # At inference time, floor rather than ceil the pooled
                    # length (the last window may not be complete yet), but
                    # keep at least one pooled step.
                    if (
                        max(1, math.floor(key.size(0) / self.pre_decision_ratio))
                    ) < key_pool.size(0):
                        key_pool = key_pool[:-1]
                        if key_padding_mask_pool is not None:
                            key_padding_mask_pool = key_padding_mask_pool[:-1]

                p_choose_pooled = self.p_choose_from_qk(
                    query,
                    key_pool,
                    key_padding_mask_pool,
                    incremental_state=incremental_state,
                )

                # Upsample by interleaving zeros between pooled probabilities
                p_choose = self.insert_zeros(p_choose_pooled)

                if p_choose.size(-1) < src_len:
                    # Append zeros if the upsampled p_choose is shorter than src_len
                    p_choose = torch.cat(
                        [
                            p_choose,
                            torch.zeros(
                                p_choose.size(0),
                                tgt_len,
                                src_len - p_choose.size(-1)
                            ).to(p_choose)
                        ],
                        dim=2
                    )
                else:
                    # p_choose can be longer than src_len because ceil_mode
                    # pooling rounds the pooled length up; truncate it, then
                    # restore the final pooled probability (which truncation
                    # may have dropped) at the last source position.
                    p_choose = p_choose[:, :, :src_len]
                    p_choose[:, :, -1] = p_choose_pooled[:, :, -1]

                assert list(p_choose.size()) == [
                    batch_size * self.num_heads,
                    tgt_len,
                    src_len,
                ]

                return p_choose

        FixedStrideMonotonicAttention.__name__ = klass.__name__
        return FixedStrideMonotonicAttention

    return partial(create_model, monotonic_attention)


@register_monotonic_attention("waitk_fixed_pre_decision")
@fixed_pooling_monotonic_attention(WaitKAttention)
class WaitKAttentionFixedStride:
    pass


@register_monotonic_attention("hard_aligned_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicAttention)
class MonotonicAttentionFixedStride:
    pass


@register_monotonic_attention("infinite_lookback_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicInfiniteLookbackAttention)
class MonotonicInfiniteLookbackAttentionFixedStride:
    pass