JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
5.49 kB
from typing import Optional
import torch
from torch import Tensor
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
prob_check,
moving_sum,
)
def expected_alignment_from_p_choose(
    p_choose: Tensor,
    padding_mask: Optional[Tensor] = None,
    eps: float = 1e-6
):
    """
    Calculate the expected monotonic alignment from stepwise
    selection probabilities.

    Reference:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments
    https://arxiv.org/pdf/1704.00784.pdf

    Recurrence (i: target step, j: source position):
        q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + alpha_{i-1,j}
        alpha_{i,j} = p_{i,j} * q_{i,j}

    Parallel solution for one target step i:
        alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / cumprod(1 - p_i))
    where cumprod is exclusive. alpha_0 puts all mass on the first
    source position.

    Args:
        p_choose: selection probabilities, shape (bsz, tgt_len, src_len).
        padding_mask: optional mask broadcastable to (bsz, 1, src_len)
            after ``unsqueeze(1)``, i.e. (bsz, src_len). True / nonzero
            entries mark padded source positions; their selection
            probability is set to 0. Integer masks are accepted and
            converted to bool.
        eps: passed to ``exclusive_cumprod``. Also the lower clamp bound
            for the cumprod used as a denominator.

    Returns:
        Expected alignment of shape (bsz, tgt_len, src_len), cast back to
        the original dtype of ``p_choose``.
    """
    prob_check(p_choose)

    # p_choose: bsz, tgt_len, src_len
    bsz, tgt_len, src_len = p_choose.size()
    dtype = p_choose.dtype
    # Run the recurrence in fp32; the result is cast back to `dtype` at the end.
    p_choose = p_choose.float()

    if padding_mask is not None:
        # masked_fill requires a bool mask; normalize so integer masks work too.
        p_choose = p_choose.masked_fill(
            padding_mask.to(torch.bool).unsqueeze(1), 0.0
        )

    # cumprod_1mp: bsz, tgt_len, src_len (exclusive cumprod of 1 - p)
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    # Clamped copy, used only as the denominator so the division stays finite.
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)

    # Initial alignment: all mass on the first source position.
    alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
    alpha_0[:, :, 0] = 1.0

    previous_alpha = [alpha_0]
    for i in range(tgt_len):
        # p_choose: bsz, tgt_len, src_len
        # cumprod_1mp_clamp: bsz, tgt_len, src_len
        # previous_alpha[i]: bsz, 1, src_len
        # alpha_i: bsz, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1
            )
        ).clamp(0, 1.0)
        previous_alpha.append(alpha_i.unsqueeze(1))

    # alpha: bsz, tgt_len, src_len (alpha_0 is dropped)
    alpha = torch.cat(previous_alpha[1:], dim=1)

    # Cast back to the caller's precision (e.g. fp16) after the fp32 compute.
    alpha = alpha.type(dtype)

    prob_check(alpha)

    return alpha
def expected_soft_attention(
    alpha: Tensor,
    soft_energy: Tensor,
    padding_mask: Optional[Tensor] = None,
    chunk_size: Optional[int] = None,
    eps: float = 1e-10
):
    """
    Compute expected soft attention (beta) from an expected monotonic
    alignment and soft attention energies. This covers monotonic chunkwise
    attention (``chunk_size`` given) and monotonic infinite lookback
    attention (``chunk_size`` is None).

    Reference:
    Monotonic Chunkwise Attention
    https://arxiv.org/abs/1712.05382
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    Args:
        alpha: expected alignment, shape (bsz, tgt_len, src_len).
        soft_energy: soft attention energies, shape (bsz, tgt_len, src_len).
        padding_mask: optional (bsz, src_len) mask; True / nonzero entries
            mark padded source positions. Integer masks are accepted and
            converted to bool.
        chunk_size: chunk width for chunkwise attention. None selects
            infinite lookback.
        eps: small constant added to the exponentiated energies and to the
            normalizing denominators.

    Returns:
        beta of shape (bsz, tgt_len, src_len), in alpha's original dtype,
        clamped to [0, 1], with padded positions set to 0.
    """
    if padding_mask is not None:
        # masked_fill only accepts bool masks. Normalize once so every masking
        # step below (including the final one on beta) works for integer masks.
        padding_mask = padding_mask.to(torch.bool)
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), -float("inf")
        )

    prob_check(alpha)

    dtype = alpha.dtype
    alpha = alpha.float()
    soft_energy = soft_energy.float()

    # Subtract the per-row max before exp() so the exponent cannot overflow.
    soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
    exp_soft_energy = torch.exp(soft_energy) + eps

    if chunk_size is not None:
        # Chunkwise: window sums computed by `moving_sum`.
        # NOTE(review): the (x, a, b) argument order follows moving_sum's own
        # window convention; confirm there before changing it.
        beta = (
            exp_soft_energy
            * moving_sum(
                alpha / (eps + moving_sum(exp_soft_energy, chunk_size, 1)),
                1, chunk_size
            )
        )
    else:
        # Infinite lookback is the chunkwise case with chunk_size = inf.
        # The inner normalizer is a prefix sum; the outer sum runs from j to
        # the end, computed as flip -> cumsum -> flip (a reverse cumsum).
        inner_items = alpha / (eps + torch.cumsum(exp_soft_energy, dim=2))
        beta = (
            exp_soft_energy
            * torch.cumsum(inner_items.flip(dims=[2]), dim=2)
            .flip(dims=[2])
        )

    if padding_mask is not None:
        beta = beta.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # Cast back to the caller's precision (e.g. fp16) after the fp32 compute.
    beta = beta.type(dtype)

    beta = beta.clamp(0, 1)

    prob_check(beta)

    return beta
def mass_preservation(
    alpha: Tensor,
    padding_mask: Optional[Tensor] = None,
    left_padding: bool = False
):
    """
    Apply mass preservation to alpha: the residual weight of each row
    (1 minus the mass on all earlier source positions, with that sum
    clamped to [0, 1]) is assigned to the last non-padded source token.

    The input tensor is never modified in place; a new tensor is returned.

    Reference:
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    Args:
        alpha: expected alignment, shape (bsz, tgt_len, src_len).
        padding_mask: optional (bsz, src_len) mask; True / nonzero entries
            mark padded source positions. Integer masks are accepted and
            converted to bool.
        left_padding: True if padding sits at the start of each sequence,
            so the last source position is always a real token.

    Returns:
        The adjusted alignment, shape (bsz, tgt_len, src_len).
    """
    prob_check(alpha)

    if padding_mask is not None:
        # masked_fill requires a bool mask; normalize so integer masks work too.
        padding_mask = padding_mask.to(torch.bool)
        if not left_padding:
            assert not padding_mask[:, 0].any(), (
                "Find padding on the beginning of the sequence."
            )
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)

    if left_padding or padding_mask is None:
        # The last source position is a real token: it takes the leftover mass.
        residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0, 1)
        # Build a new tensor instead of assigning into alpha[:, :, -1]. When no
        # mask is given, `alpha` is still the caller's tensor, and an in-place
        # write would silently modify it.
        alpha = torch.cat([alpha[:, :, :-1], residuals.unsqueeze(-1)], dim=-1)
    else:
        # Right padding: the last real token sits at index src_lens - 1.
        _, tgt_len, src_len = alpha.size()
        residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0, 1)
        src_lens = src_len - padding_mask.sum(dim=1, keepdim=True)
        src_lens = src_lens.expand(-1, tgt_len).contiguous()
        # The total sum above included the last real token, so add it back.
        residuals += alpha.gather(2, src_lens.unsqueeze(2) - 1)
        # scatter (not scatter_) returns a new tensor.
        alpha = alpha.scatter(2, src_lens.unsqueeze(2) - 1, residuals)

    prob_check(alpha)

    return alpha