PolyFormer / fairseq /examples /simultaneous_translation /modules /monotonic_multihead_attention.py
jiang
init commit
650c5f6
raw
history blame
16.9 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import Tensor
import torch.nn as nn
from examples.simultaneous_translation.utils.p_choose_strategy import (
learnable_p_choose,
waitk_p_choose
)
from examples.simultaneous_translation.utils.monotonic_attention import (
expected_alignment_from_p_choose,
expected_soft_attention,
mass_preservation,
)
from fairseq.modules import MultiheadAttention
from . import register_monotonic_attention
from typing import Dict, Optional
@register_monotonic_attention("hard_aligned")
class MonotonicAttention(MultiheadAttention):
"""
Abstract class of monotonic attentions
"""
k_in_proj: Dict[str, nn.Linear]
q_in_proj: Dict[str, nn.Linear]
def __init__(self, args):
super().__init__(
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.soft_attention = False
self.eps = getattr(args, "attention_eps", True)
self.mass_preservation = getattr(args, "mass_preservation", True)
self.noise_type = args.noise_type
self.noise_mean = args.noise_mean
self.noise_var = args.noise_var
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
if args.energy_bias is True
else 0
)
self.k_in_proj = {"monotonic": self.k_proj}
self.q_in_proj = {"monotonic": self.q_proj}
self.chunk_size = None
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false",
dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true",
dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
parser.add_argument('--noise-var', type=float, default=1.0,
help='Variance of discretness noise')
parser.add_argument('--noise-mean', type=float, default=0.0,
help='Mean of discretness noise')
parser.add_argument('--noise-type', type=str, default="flat",
help='Type of discretness noise')
parser.add_argument('--energy-bias', action="store_true",
default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
def energy_from_qk(
self,
query: Tensor,
key: Tensor,
energy_type: str,
key_padding_mask: Optional[Tensor] = None,
bias: int = 0
):
"""
Compute energy from query and key
q_func_value is a tuple looks like
(q_proj_func, q_tensor)
q_tensor size: bsz, tgt_len, emb_dim
k_tensor size: bsz, src_len, emb_dim
key_padding_mask size: bsz, src_len
attn_mask: bsz, src_len
"""
length, bsz, _ = query.size()
q = self.q_in_proj[energy_type].forward(query)
q = (
q.contiguous()
.view(length, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
q = q * self.scaling
length, bsz, _ = key.size()
k = self.k_in_proj[energy_type].forward(key)
k = (
k.contiguous()
.view(length, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
energy = torch.bmm(q, k.transpose(1, 2)) + bias
if key_padding_mask is not None:
energy = energy.masked_fill(
key_padding_mask.unsqueeze(1).to(torch.bool),
- float("inf")
)
return energy
def p_choose_from_qk(self, query, key, key_padding_mask, incremental_states=None):
monotonic_energy = self.energy_from_qk(
query,
key,
"monotonic",
key_padding_mask=key_padding_mask,
bias=self.energy_bias,
)
p_choose = learnable_p_choose(
monotonic_energy,
self.noise_mean,
self.noise_var,
self.training
)
return p_choose
def p_choose(self, query, key, key_padding_mask, incremental_states=None):
return self.p_choose_from_qk(self, query, key, key_padding_mask)
def monotonic_attention_process_infer(
self,
query: Optional[Tensor],
key: Optional[Tensor],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
):
"""
Monotonic attention at inference time
Notice that this function is designed for simuleval not sequence_generator
"""
assert query is not None
assert key is not None
if query.size(1) != 1:
raise RuntimeError(
"Simultaneous translation models don't support batch decoding."
)
# 1. compute stepwise probability
p_choose = self.p_choose(
query, key, None, incremental_state
).squeeze(1)
# 2. Compute the alpha
src_len = key.size(0)
# Maximum steps allows in this iteration
max_steps = src_len - 1 if self.mass_preservation else src_len
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# Step for each head
monotonic_step = monotonic_cache.get(
'head_step',
p_choose.new_zeros(1, self.num_heads).long()
)
assert monotonic_step is not None
finish_read = monotonic_step.eq(max_steps)
p_choose_i = torch.tensor(1)
while finish_read.sum().item() < self.num_heads:
# p_choose: self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: 1, self.num_heads
p_choose_i = (
p_choose.gather(
1,
monotonic_step
.clamp(0, src_len - 1),
)
)
read_one_step = (
(p_choose_i < 0.5)
.type_as(monotonic_step)
.masked_fill(finish_read, 0)
)
# 1 x bsz
# sample actions on unfinished seq
# 0 means stay, finish reading
# 1 means leave, continue reading
monotonic_step += read_one_step
finish_read = monotonic_step.eq(max_steps) | (read_one_step == 0)
# p_choose at last steps
p_choose_i = (
p_choose.gather(
1,
monotonic_step
.clamp(0, src_len - 1),
)
)
monotonic_cache["head_step"] = monotonic_step
# Whether a head is looking for new input
monotonic_cache["head_read"] = (
monotonic_step.eq(max_steps) & (p_choose_i < 0.5)
)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
# 2. Update alpha
alpha = (
p_choose
.new_zeros([self.num_heads, src_len])
.scatter(
1,
(monotonic_step)
.view(self.num_heads, 1).clamp(0, src_len - 1),
1
)
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(monotonic_step == max_steps)
.view(self.num_heads, 1),
0
)
# 4. Compute Beta
if self.soft_attention:
monotonic_step = monotonic_step.t()
beta_mask = torch.arange(src_len).expand_as(alpha).gt(monotonic_step).unsqueeze(1)
# If it's soft attention just do softmax on current context
soft_energy = self.energy_from_qk(
query,
key,
"soft"
)
beta = torch.nn.functional.softmax(
soft_energy.masked_fill(beta_mask, -float("inf")), dim=-1
)
# It could happen that a head doesn't move at all
beta = beta.masked_fill(monotonic_step.eq(0).unsqueeze(1), 0)
else:
# If it's hard attention just select the last state
beta = alpha
return p_choose, alpha, beta
def monotonic_attention_process_train(
self,
query: Optional[Tensor],
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
):
"""
Calculating monotonic attention process for training
Including:
stepwise probability: p_choose
expected hard alignment: alpha
expected soft attention: beta
"""
assert query is not None
assert key is not None
# 1. compute stepwise probability
p_choose = self.p_choose_from_qk(query, key, key_padding_mask)
# 2. compute expected_alignment
alpha = expected_alignment_from_p_choose(
p_choose,
key_padding_mask,
eps=self.eps,
)
if self.mass_preservation:
alpha = mass_preservation(
alpha, key_padding_mask
)
# 3. compute expected soft attention (soft aligned model only)
if self.soft_attention:
soft_energy = self.energy_from_qk(
query,
key,
"soft",
key_padding_mask=None,
)
beta = expected_soft_attention(
alpha,
soft_energy,
padding_mask=key_padding_mask,
chunk_size=self.chunk_size,
eps=self.eps,
)
else:
beta = alpha
soft_energy = alpha
return p_choose, alpha, beta, soft_energy
def forward(
self,
query: Optional[Tensor],
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True, static_kv: bool = False, need_head_weights: bool = False,
):
"""
query: tgt_len, bsz, embed_dim
key: src_len, bsz, embed_dim
value: src_len, bsz, embed_dim
"""
assert attn_mask is None
assert query is not None
assert key is not None
assert value is not None
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
if key_padding_mask is not None:
assert not key_padding_mask[:, 0].any(), (
"Only right padding is supported."
)
key_padding_mask = (
key_padding_mask
.unsqueeze(1)
.expand([bsz, self.num_heads, src_len])
.contiguous()
.view(-1, src_len)
)
if incremental_state is not None:
# Inference
(
p_choose, alpha, beta
) = self.monotonic_attention_process_infer(
query, key, incremental_state
)
soft_energy = beta
else:
# Train
(
p_choose, alpha, beta, soft_energy
) = self.monotonic_attention_process_train(
query, key, key_padding_mask
)
v = self.v_proj(value)
length, bsz, _ = v.size()
v = (
v.contiguous()
.view(length, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
attn = torch.bmm(beta.type_as(v), v)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {
"p_choose": p_choose,
"alpha": alpha,
"beta": beta,
}
def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
maybe_incremental_state = self.get_incremental_state(
incremental_state,
'monotonic',
)
if maybe_incremental_state is None:
typed_empty_dict: Dict[str, Optional[Tensor]] = {}
return typed_empty_dict
else:
return maybe_incremental_state
def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
self.set_incremental_state(
incremental_state,
'monotonic',
buffer,
)
@register_monotonic_attention("infinite_lookback")
class MonotonicInfiniteLookbackAttention(
MonotonicAttention
):
def __init__(self, args):
super().__init__(args)
self.soft_attention = True
self.init_soft_attention()
def init_soft_attention(self):
self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.k_in_proj["soft"] = self.k_proj_soft
self.q_in_proj["soft"] = self.q_proj_soft
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
@register_monotonic_attention("waitk")
class WaitKAttention(
MonotonicInfiniteLookbackAttention
):
"""
STACL: Simultaneous Translation with Implicit Anticipation and
Controllable Latency using Prefix-to-Prefix Framework
https://www.aclweb.org/anthology/P19-1289/
"""
def __init__(self, args):
super().__init__(args)
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
assert self.waitk_lagging > 0, (
f"Lagging has to been larger than 0, get {self.waitk_lagging}."
)
@staticmethod
def add_args(parser):
super(
MonotonicInfiniteLookbackAttention,
MonotonicInfiniteLookbackAttention
).add_args(parser)
parser.add_argument(
"--waitk-lagging", type=int, required=True, help="Wait K lagging"
)
def p_choose_from_qk(
self,
query: Optional[Tensor],
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
assert query is not None
assert key is not None
p_choose = waitk_p_choose(
tgt_len=query.size(0),
src_len=key.size(0),
bsz=query.size(1) * self.num_heads,
waitk_lagging=self.waitk_lagging,
key_padding_mask=key_padding_mask,
incremental_state=incremental_state,
)
return p_choose.to(query)
@register_monotonic_attention("chunkwise")
class ChunkwiseAttention(
MonotonicInfiniteLookbackAttention
):
def __init__(self, args):
super().__init__(args)
self.chunk_size = args.mocha_chunk_size
assert self.chunk_size > 1
@staticmethod
def add_args(parser):
super(
MonotonicInfiniteLookbackAttention
).add_args(parser)
parser.add_argument(
"--mocha-chunk-size", type=int,
required=True, help="Mocha chunk size"
)