JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
13.5 kB
import argparse
import unittest
from typing import Any, Dict
import torch
from examples.simultaneous_translation.models import (
transformer_monotonic_attention
)
from tests.test_roberta import FakeTask
# Monotonic-attention settings shared by every test model. Tests layer their
# own overrides on top of a copy of this dict via ``generate_config``.
DEFAULT_CONFIG = dict(
    attention_eps=1e-6,
    mass_preservation=True,
    noise_type="flat",
    noise_mean=0.0,
    noise_var=1.0,
    energy_bias_init=-2,
    energy_bias=True,
)

# Token id treated as padding in the hand-built samples and padding masks.
PAD_INDEX = 1
def generate_config(overrides_kv):
    """Return a new config dict: ``DEFAULT_CONFIG`` with ``overrides_kv`` applied.

    ``DEFAULT_CONFIG`` itself is left untouched (a shallow copy is updated).
    Keys present in both take the value from ``overrides_kv``; new keys are
    appended after the default keys.
    """
    merged = dict(DEFAULT_CONFIG)
    merged.update(overrides_kv.items())
    return merged
def make_sample_with_padding(longer_src=False) -> Dict[str, Any]:
    """Build a 2-sentence batch whose source and decoder input are right-padded.

    Args:
        longer_src: when True the source comes from the 14-token sequences
            and the decoder input from the 8-token ones; when False the roles
            are swapped.

    Returns:
        ``{"net_input": {...}, "target": ...}``. The source drops the leading
        token 2; ``src_lengths`` counts non-``PAD_INDEX`` source tokens; the
        target is the decoder input shifted left by one position.
    """
    long_seqs = torch.LongTensor([
        [2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2],
        [2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2, PAD_INDEX, PAD_INDEX],
    ])
    short_seqs = torch.LongTensor([
        [2, 11, 12, 13, 14, 2, PAD_INDEX, PAD_INDEX],
        [2, 11, 22, 33, 2, PAD_INDEX, PAD_INDEX, PAD_INDEX],
    ])

    if longer_src:
        source, decoder_in = long_seqs[:, 1:], short_seqs
    else:
        source, decoder_in = short_seqs[:, 1:8], long_seqs

    net_input = {
        "src_tokens": source,
        "prev_output_tokens": decoder_in,
        "src_lengths": source.ne(PAD_INDEX).sum(dim=1).long(),
    }
    return {"net_input": net_input, "target": decoder_in[:, 1:]}
def build_transformer_monotonic_attention(**extra_args: Any):
    """Build a small ``TransformerModelSimulTrans`` for unit tests.

    Starts from 12-dim embeddings / 14-dim FFNs with dropout and encoder
    layerdrop disabled, then applies ``extra_args`` on top (caller values
    win). ``monotonic_tiny_architecture`` is then applied to the namespace
    (presumably filling any remaining architecture defaults — confirm in
    ``transformer_monotonic_attention``), and torch's RNG is seeded with 0
    right before the task and model are constructed.
    """
    tiny_dims = {
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
    }
    # Dropout is disabled so results of repeated runs are comparable.
    no_dropout = {
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    args = argparse.Namespace(**{**tiny_dims, **no_dropout, **extra_args})

    transformer_monotonic_attention.monotonic_tiny_architecture(args)
    torch.manual_seed(0)
    model_cls = transformer_monotonic_attention.TransformerModelSimulTrans
    return model_cls.build_model(args, FakeTask(args))
def expected_alignment_formula(
    p_choose,
    mass_perservation=True,
    padding_mask=None
):
    """Loop-based reference for the expected monotonic alignment ``alpha``.

    Implements Eq. 18-19 of "Online and Linear-Time Attention by Enforcing
    Monotonic Alignments" (https://arxiv.org/pdf/1704.00784.pdf) element by
    element, so it can serve as an oracle for a vectorized implementation:

        alpha[i, j] = p[i, j] * sum_{k <= j} alpha[i-1, k] * prod_{l=k}^{j-1} (1 - p[i, l])

    with the first target step starting from source position 0.

    Args:
        p_choose: tensor ``(bsz, tgt_len, src_len)`` of selection
            probabilities. ``bsz`` may be batch * heads, flattened batch-major.
        mass_perservation: if True, pass the result through
            ``mass_perservation_formula``.
        padding_mask: optional bool tensor ``(bsz_pad, src_len)``, True at
            padded source positions. Each row is repeated ``bsz // bsz_pad``
            times to line up with ``p_choose``. May be None (no padding).

    Returns:
        A new tensor shaped like ``p_choose``; the input is not modified.
    """
    bsz, tgt_len, src_len = p_choose.size()
    alpha = torch.zeros_like(p_choose)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = bsz // bsz_pad
        # (bsz_pad, src_len) -> (bsz_pad * num_heads, src_len), batch-major.
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        # Padded source positions can never be chosen.
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if i == 0:
                    # First target step: skip every earlier source token,
                    # then stop at j. prod() of an empty slice (j == 0) is 1.
                    alpha[bsz_i, i, j] = (
                        p_choose[bsz_i, i, j]
                        * torch.prod(1 - p_choose[bsz_i, i, :j])
                    )
                    continue
                # Stay at j (k == j term, empty product) ...
                alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j]
                # ... or move forward from an earlier position k < j.
                for k in range(j):
                    alpha[bsz_i, i, j] += (
                        alpha[bsz_i, i - 1, k]
                        * torch.prod(1 - p_choose[bsz_i, i, k:j])
                    )
                alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j]

    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    if mass_perservation:
        alpha = mass_perservation_formula(alpha, False, padding_mask)

    return alpha
def mass_perservation_formula(alpha, left_padding=False, padding_mask=None):
    """Make every target step's alignment sum to one.

    The probability mass not assigned to earlier source positions is put on
    the last real (non-padded) source position.

    Args:
        alpha: tensor ``(bsz, tgt_len, src_len)`` of alignment weights.
        left_padding: True if padding sits at the start of each row, so the
            last column is always a real token.
        padding_mask: optional bool tensor ``(bsz, src_len)``, True at
            padded positions.

    Returns:
        A new tensor. The caller's ``alpha`` is never modified in place.
        With a single source column it is returned unchanged.
    """
    if padding_mask is None or alpha.size(-1) == 1:
        if alpha.size(-1) > 1:
            # Copy first: writing into the caller's tensor would surprise
            # callers and fails on leaf tensors that require grad.
            alpha = alpha.clone()
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1)
        return alpha

    src_lens = padding_mask.logical_not().sum(dim=1).long()
    bsz, _, src_len = alpha.size()

    # With left padding the mass goes into the final column, which must
    # therefore be a real token in every row.
    assert not left_padding or not padding_mask[:, -1].any()

    alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)
    for bsz_i in range(bsz):
        if left_padding:
            last = src_len - 1
        else:
            last = int(src_lens[bsz_i]) - 1
        alpha[bsz_i, :, last] = 1 - alpha[bsz_i, :, :last].sum(dim=-1)
    return alpha
def expected_soft_attention_formula(
    alpha,
    soft_energy,
    padding_mask=None,
    chunksize=int(1e10),
):
    """Loop-based reference for the soft attention ``beta``.

    Monotonic Infinite Lookback Attention (https://arxiv.org/pdf/1906.05218.pdf,
    Eq. 14) and Monotonic Chunkwise Attention (https://arxiv.org/abs/1712.05382,
    Eq. 17):

        beta[i, j] = sum_{k=j}^{j+w-1} alpha[i, k] * exp(u[i, j])
                     / sum_{l=k-w+1}^{k} exp(u[i, l])

    where ``w = chunksize``. A huge ``chunksize`` gives infinite lookback.

    Args:
        alpha: tensor ``(bsz, tgt_len, src_len)`` of expected alignments.
        soft_energy: tensor with the same shape as ``alpha``.
        padding_mask: optional bool tensor ``(bsz_pad, src_len)``, True at
            padded positions. It is repeated ``bsz // bsz_pad`` times per row.
            May be None (no padding).
        chunksize: window width ``w``; floats are accepted and truncated.

    Returns:
        Tensor shaped like ``alpha``; padded source positions stay 0.
    """
    bsz, tgt_len, src_len = alpha.size()
    # range() and slicing need integers; the historical default was 1e10.
    chunksize = int(chunksize)
    beta = torch.zeros_like(alpha)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = bsz // bsz_pad
        # Expanding for potential head dimension (batch-major).
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), float('-inf')
        )

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if padding_mask is not None and padding_mask[bsz_i, j]:
                    continue
                for k in range(j, min(src_len, j + chunksize)):
                    window = soft_energy[
                        bsz_i, i, max(0, k - chunksize + 1):k + 1
                    ]
                    # Subtracting the window max leaves the ratio unchanged
                    # but keeps exp() from overflowing for large energies.
                    peak = window.max()
                    beta[bsz_i, i, j] += (
                        alpha[bsz_i, i, k]
                        * torch.exp(soft_energy[bsz_i, i, j] - peak)
                        / torch.sum(torch.exp(window - peak))
                    )
    return beta
class MonotonicAttentionTestAbstractClass(object):
    """Test mixin shared by every monotonic-attention variant.

    Concrete subclasses also inherit from ``unittest.TestCase`` and assign
    the model under test to ``self.model`` in ``setUp``. Because this mixin
    is not itself a ``TestCase``, unittest does not run it on its own.
    """

    def test_forward(self):
        """A forward and backward pass on a padded batch must succeed."""
        sample = make_sample_with_padding()
        out, _ = self.model.forward(**sample["net_input"])
        loss = out.sum()
        loss.backward()

    def test_p_choose(self):
        """Every selection probability ``p_choose`` must lie in [0, 1]."""
        sample = make_sample_with_padding()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            p_choose = item["p_choose"]
            self.assertTrue(p_choose.le(1.0).all())
            self.assertTrue(p_choose.ge(0.0).all())

    def test_expected_alignment(self):
        """The model's ``alpha`` must match the loop-based reference (tol 5e-5)."""
        mass_preservation = (
            self.model.decoder.layers[0].encoder_attn.mass_preservation
        )
        for longer_src in [True, False]:
            with self.subTest(longer_src=longer_src):
                sample = make_sample_with_padding(longer_src)
                padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                _, extra_out = self.model.forward(**sample["net_input"])
                for item in extra_out.attn_list:
                    p_choose = item["p_choose"]
                    alpha_system = item["alpha"]
                    self.assertEqual(p_choose.size(), alpha_system.size())
                    # (bsz, num_heads, tgt_len, src_len) ->
                    # (bsz * num_heads, tgt_len, src_len)
                    _, _, tgt_len, src_len = alpha_system.size()
                    alpha_system = alpha_system.view(-1, tgt_len, src_len)
                    p_choose = p_choose.view(-1, tgt_len, src_len)
                    alpha_real = expected_alignment_formula(
                        p_choose, mass_preservation, padding_mask
                    )
                    self.assertTrue(
                        torch.abs(alpha_system - alpha_real).le(5e-5).all(),
                    )
class HardMonotonicAttentionTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass,
):
    """Run the shared monotonic-attention checks on the ``hard_aligned`` variant."""

    def setUp(self):
        hard_config = generate_config({"simul_type": "hard_aligned"})
        self.model = build_transformer_monotonic_attention(**hard_config)
class InfiniteLookbackTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    """Shared checks plus soft-attention (``beta``) checks for ``infinite_lookback``.

    Subclasses override ``setUp`` to rerun these tests on other variants.
    """

    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "infinite_lookback"
                }
            )
        )
        self.model.train()

    @unittest.skipUnless(
        torch.cuda.is_available(), "fp16 long-input test requires CUDA"
    )
    def test_fp16_for_long_input(self):
        """No NaNs in p_choose/alpha/beta/soft_energy for a 1001-token fp16 input."""
        src_tokens = torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0)
        sample = {
            "net_input": {
                "src_tokens": src_tokens,
                "prev_output_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                # The source is unpadded, so its length is all 1001 tokens
                # (1000 copies of 7 followed by 2).
                "src_lengths": torch.LongTensor([src_tokens.size(1)]).cuda(),
            },
            "target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda()
        }
        self.model.cuda().half()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            for key in ["p_choose", "alpha", "beta", "soft_energy"]:
                self.assertFalse(torch.isnan(item[key]).any())

    def test_expected_attention(self):
        """The model's ``beta`` must match the loop-based reference (tol 1e-5)."""
        encoder_attn = self.model.decoder.layers[0].encoder_attn
        mass_preservation = encoder_attn.mass_preservation
        # Without a ``chunk_size`` attribute, fall back to an effectively
        # unbounded window.
        chunksize = getattr(encoder_attn, "chunk_size", int(1e10))
        for longer_src in [True, False]:
            with self.subTest(longer_src=longer_src):
                sample = make_sample_with_padding(longer_src)
                padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                _, extra_out = self.model.forward(**sample["net_input"])
                for item in extra_out.attn_list:
                    p_choose = item["p_choose"]
                    alpha_system = item["alpha"]
                    beta_system = item["beta"]
                    soft_energy_system = item["soft_energy"]
                    self.assertEqual(beta_system.size(), alpha_system.size())
                    self.assertEqual(p_choose.size(), alpha_system.size())
                    # (bsz, num_heads, tgt_len, src_len) ->
                    # (bsz * num_heads, tgt_len, src_len)
                    _, _, tgt_len, src_len = alpha_system.size()
                    beta_system = beta_system.view(-1, tgt_len, src_len)
                    p_choose = p_choose.view(-1, tgt_len, src_len)
                    soft_energy_system = soft_energy_system.view(-1, tgt_len, src_len)
                    alpha_real = expected_alignment_formula(
                        p_choose, mass_preservation, padding_mask
                    )
                    beta_real = expected_soft_attention_formula(
                        alpha_real,
                        soft_energy_system,
                        padding_mask,
                        chunksize=chunksize,
                    )
                    self.assertTrue(
                        torch.abs(beta_system - beta_real).le(1e-5).all(),
                    )
class ChunkwiswTestCase(InfiniteLookbackTestCase):
    """Rerun the inherited checks on the ``chunkwise`` variant with chunk size 3.

    NOTE(review): the class name is misspelled ("Chunkwisw"). It is kept
    as-is so existing test IDs stay stable. Unlike the parent's ``setUp``,
    this one does not call ``self.model.train()``.
    """

    def setUp(self):
        chunkwise_overrides = {
            "simul_type": "chunkwise",
            "mocha_chunk_size": 3,
        }
        self.model = build_transformer_monotonic_attention(
            **generate_config(chunkwise_overrides)
        )
class WaitkTestCase(InfiniteLookbackTestCase):
    """Inherited checks plus an exact structural check of wait-k ``p_choose``."""

    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "waitk",
                    "waitk_lagging": 3,
                }
            )
        )

    def check_waitk(self, p_choose, lagging, padding_mask):
        """Assert ``p_choose`` is 1 on the wait-k diagonal and 0 elsewhere.

        Args:
            p_choose: tensor ``(bsz, tgt_len, src_len)``.
            lagging: lag ``k``; target step ``i`` must choose source
                position ``i + k - 1``.
            padding_mask: bool tensor ``(bsz, src_len)``, True at padded
                positions, which are not checked.
        """
        bsz, tgt_len, src_len = p_choose.size()
        for bsz_i in range(bsz):
            for j in range(src_len):
                if padding_mask[bsz_i, j]:
                    continue
                for i in range(tgt_len):
                    expected = 1 if j - i == lagging - 1 else 0
                    self.assertEqual(
                        p_choose[bsz_i, i, j].item(),
                        expected,
                        f"p_choose[{bsz_i}, {i}, {j}] with lagging={lagging}",
                    )

    def test_waitk_p_choose(self):
        """Check ``p_choose`` for several lags on both sample layouts."""
        for longer_src in [True, False]:
            # The sample does not depend on k, so build it once per layout.
            sample = make_sample_with_padding(longer_src)
            src_padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
            for k in [1, 3, 10, 20, 100]:
                with self.subTest(longer_src=longer_src, waitk_lagging=k):
                    model = build_transformer_monotonic_attention(
                        **generate_config(
                            {
                                "simul_type": "waitk",
                                "waitk_lagging": k,
                            }
                        )
                    )
                    model.train()
                    _, extra_out = model.forward(**sample["net_input"])
                    for item in extra_out.attn_list:
                        p_choose = item["p_choose"]
                        bsz, num_heads, tgt_len, src_len = p_choose.size()
                        # (bsz, src_len) -> (bsz * num_heads, src_len), batch-major.
                        padding_mask = (
                            src_padding_mask
                            .unsqueeze(1)
                            .expand([bsz, num_heads, src_len])
                            .contiguous()
                            .view(-1, src_len)
                        )
                        p_choose = p_choose.view(bsz * num_heads, tgt_len, src_len)
                        self.check_waitk(p_choose, k, padding_mask)