from typing import Optional

import torch
from torch import Tensor

from examples.simultaneous_translation.utils.functions import (
    exclusive_cumprod,
    prob_check,
    moving_sum,
)


def expected_alignment_from_p_choose(
    p_choose: Tensor,
    padding_mask: Optional[Tensor] = None,
    eps: float = 1e-6
):
    """
    Calculating the expected alignment from the stepwise choice probability

    Reference:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments
    https://arxiv.org/pdf/1704.00784.pdf

    q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
    a_{i,j} = p_{i,j} * q_{i,j}

    Parallel solution:
    a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz, tgt_len, src_len
    """
    prob_check(p_choose)

    # p_choose: bsz, tgt_len, src_len
    bsz, tgt_len, src_len = p_choose.size()
    dtype = p_choose.dtype

    p_choose = p_choose.float()

    if padding_mask is not None:
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # cumprod_1mp: bsz, tgt_len, src_len
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)

    alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
    alpha_0[:, :, 0] = 1.0

    previous_alpha = [alpha_0]

    for i in range(tgt_len):
        # p_choose: bsz, tgt_len, src_len
        # cumprod_1mp_clamp: bsz, tgt_len, src_len
        # previous_alpha[i]: bsz, 1, src_len
        # alpha_i: bsz, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1
            )
        ).clamp(0, 1.0)

        previous_alpha.append(alpha_i.unsqueeze(1))

    # alpha: bsz, tgt_len, src_len
    alpha = torch.cat(previous_alpha[1:], dim=1)

    # Mixed precision: the computation above runs in fp32 to prevent
    # overflow for fp16; cast back to the original dtype here.
    alpha = alpha.type(dtype)

    prob_check(alpha)

    return alpha
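

# Illustrative usage sketch (assumed example values; shapes follow the
# docstring above, i.e. p_choose is bsz x tgt_len x src_len):
#
#   >>> bsz, tgt_len, src_len = 2, 3, 5
#   >>> p_choose = torch.rand(bsz, tgt_len, src_len)
#   >>> alpha = expected_alignment_from_p_choose(p_choose)
#   >>> alpha.shape
#   torch.Size([2, 3, 5])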


def expected_soft_attention(
    alpha: Tensor,
    soft_energy: Tensor,
    padding_mask: Optional[Tensor] = None,
    chunk_size: Optional[int] = None,
    eps: float = 1e-10
):
    """
    Function to compute the expected soft attention for
    monotonic infinite lookback attention from the
    expected alignment and the soft energy.

    Reference:
    Monotonic Chunkwise Attention
    https://arxiv.org/abs/1712.05382

    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    soft_energy: bsz, tgt_len, src_len
    padding_mask: bsz, src_len
    chunk_size: int
    """
    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), -float("inf")
        )

    prob_check(alpha)

    dtype = alpha.dtype

    alpha = alpha.float()
    soft_energy = soft_energy.float()

    # Subtract the row-wise max for numerical stability before exponentiating
    soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
    exp_soft_energy = torch.exp(soft_energy) + eps

    if chunk_size is not None:
        # Chunkwise attention (MoChA)
        beta = (
            exp_soft_energy
            * moving_sum(
                alpha / (eps + moving_sum(exp_soft_energy, chunk_size, 1)),
                1, chunk_size
            )
        )
    else:
        # Infinite lookback (MILk)
        # Notice that infinite lookback is a special case of chunkwise
        # attention where chunk_size = inf
        inner_items = alpha / (eps + torch.cumsum(exp_soft_energy, dim=2))

        beta = (
            exp_soft_energy
            * torch.cumsum(inner_items.flip(dims=[2]), dim=2)
            .flip(dims=[2])
        )

    if padding_mask is not None:
        beta = beta.masked_fill(
            padding_mask.unsqueeze(1).to(torch.bool), 0.0
        )

    # Mixed precision: cast back to the original dtype (fp32 is used above
    # to prevent overflow for fp16)
    beta = beta.type(dtype)

    beta = beta.clamp(0, 1)

    prob_check(beta)

    return beta
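

# Illustrative usage sketch, continuing the example above. soft_energy is an
# assumed random tensor of unnormalized attention scores with the same shape
# as alpha. Passing chunk_size selects the chunkwise branch; leaving it as
# None selects infinite lookback:
#
#   >>> soft_energy = torch.randn(bsz, tgt_len, src_len)
#   >>> beta = expected_soft_attention(alpha, soft_energy)
#   >>> beta_chunk = expected_soft_attention(alpha, soft_energy, chunk_size=2)
#   >>> beta.shape
#   torch.Size([2, 3, 5])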


def mass_preservation(
    alpha: Tensor,
    padding_mask: Optional[Tensor] = None,
    left_padding: bool = False
):
    """
    Function to compute mass preservation for alpha.
    This means that the residual weights of alpha will be assigned
    to the last token.

    Reference:
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    padding_mask: bsz, src_len
    left_padding: bool
    """
    prob_check(alpha)

    if padding_mask is not None:
        if not left_padding:
            assert not padding_mask[:, 0].any(), (
                "Found padding at the beginning of the sequence."
            )
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)

    if left_padding or padding_mask is None:
        residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0, 1)
        alpha[:, :, -1] = residuals
    else:
        # Right padding: the last non-padded source position differs per sentence
        _, tgt_len, src_len = alpha.size()
        residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0, 1)
        src_lens = src_len - padding_mask.sum(dim=1, keepdim=True)
        src_lens = src_lens.expand(-1, tgt_len).contiguous()

        # Add the residual back to the last non-padded position
        residuals += alpha.gather(2, src_lens.unsqueeze(2) - 1)
        alpha = alpha.scatter(2, src_lens.unsqueeze(2) - 1, residuals)

    prob_check(alpha)

    return alpha
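

# Illustrative usage sketch, continuing the example above. Without a
# padding_mask, the residual probability mass of each target row is moved onto
# the last source position, so each row of alpha sums to (approximately) one:
#
#   >>> alpha = mass_preservation(alpha)
#   >>> torch.allclose(alpha.sum(dim=-1), torch.ones(bsz, tgt_len))
#   True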