"""Tests for the simultaneous-translation monotonic attention models
(hard aligned, infinite lookback, chunkwise, and wait-k).

Model outputs (p_choose, alpha, beta) are checked against brute-force
reference implementations of the alignment and soft-attention formulas
from the corresponding papers.
"""

import argparse
import unittest
from typing import Any, Dict

import torch

from examples.simultaneous_translation.models import (
    transformer_monotonic_attention
)

from tests.test_roberta import FakeTask


DEFAULT_CONFIG = {
    "attention_eps": 1e-6,
    "mass_preservation": True,
    "noise_type": "flat",
    "noise_mean": 0.0,
    "noise_var": 1.0,
    "energy_bias_init": -2,
    "energy_bias": True,
}


PAD_INDEX = 1


def generate_config(overrides_kv):
    # Start from a copy of the defaults and apply the overrides on top.
    new_dict = {key: value for key, value in DEFAULT_CONFIG.items()}
    for key, value in overrides_kv.items():
        new_dict[key] = value
    return new_dict


def make_sample_with_padding(longer_src=False) -> Dict[str, Any]:
    tokens_1 = torch.LongTensor(
        [
            [2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2],
            [
                2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2,
                PAD_INDEX, PAD_INDEX
            ],
        ]
    )
    tokens_2 = torch.LongTensor(
        [
            [2, 11, 12, 13, 14, 2, PAD_INDEX, PAD_INDEX],
            [2, 11, 22, 33, 2, PAD_INDEX, PAD_INDEX, PAD_INDEX],
        ]
    )
    if longer_src:
        src_tokens = tokens_1[:, 1:]
        prev_output_tokens = tokens_2
    else:
        src_tokens = tokens_2[:, 1:8]
        prev_output_tokens = tokens_1

    src_lengths = src_tokens.ne(PAD_INDEX).sum(dim=1).long()

    sample = {
        "net_input": {
            "src_tokens": src_tokens,
            "prev_output_tokens": prev_output_tokens,
            "src_lengths": src_lengths,
        },
        "target": prev_output_tokens[:, 1:],
    }
    return sample


def build_transformer_monotonic_attention(**extra_args: Any):
    overrides = {
        # Use tiny model dimensions to keep the tests fast.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so the tests are deterministic and comparable.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    # Override the defaults from the parser.
    args = argparse.Namespace(**overrides)
    transformer_monotonic_attention.monotonic_tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return (
        transformer_monotonic_attention
        .TransformerModelSimulTrans
        .build_model(args, task)
    )


def expected_alignment_formula(
    p_choose,
    mass_perservation=True,
    padding_mask=None
):
    # Brute-force reference for the expected alignment.
    # Online and Linear-Time Attention by Enforcing Monotonic Alignments
    # https://arxiv.org/pdf/1704.00784.pdf
    # Eq. 18, 19
    bsz, tgt_len, src_len = p_choose.size()
    alpha = torch.zeros_like(p_choose)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)
        # Expand the padding mask over the attention heads.
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )

        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if i == 0:
                    if j == 0:
                        # First target step, first source position.
                        alpha[bsz_i, i, j] = p_choose[bsz_i, i, j]
                    else:
                        # First target step, later source positions.
                        alpha[bsz_i, i, j] = (
                            p_choose[bsz_i, i, j]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, :j]
                            )
                        )
                else:
                    alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j]
                    for k in range(j):
                        alpha[bsz_i, i, j] += (
                            alpha[bsz_i, i - 1, k]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, k:j]
                            )
                        )
                    alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j]

    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    if mass_perservation:
        alpha = mass_perservation_formula(alpha, False, padding_mask)

    return alpha
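
# The triple loop in expected_alignment_formula is deliberately literal. As an
# illustration only, the helper below sketches a vectorized form of the same
# recurrence (Eq. 18, 19 above) using an exclusive cumprod and a cumsum:
#     alpha[i, j] = p[i, j] * C[i, j] * sum_{k <= j} alpha[i - 1, k] / C[i, k],
# where C[i, j] = prod_{l < j} (1 - p[i, l]). It assumes no padding and clamps
# C with an epsilon for stability. The name `vectorized_alignment_reference`
# is ours; the helper is not used by the tests below.
def vectorized_alignment_reference(p_choose, eps=1e-6):
    bsz, tgt_len, src_len = p_choose.size()
    alpha = torch.zeros_like(p_choose)
    # Exclusive cumulative product:
    # cumprod_1mp[b, i, j] = prod_{l < j} (1 - p_choose[b, i, l]).
    cumprod_1mp = torch.cumprod(
        torch.cat(
            [torch.ones_like(p_choose[:, :, :1]), 1 - p_choose[:, :, :-1]],
            dim=2,
        ),
        dim=2,
    ).clamp(min=eps)
    # alpha_{-1} puts all probability mass on the first source position.
    prev_alpha = torch.zeros_like(p_choose[:, 0])
    prev_alpha[:, 0] = 1.0
    for i in range(tgt_len):
        alpha[:, i] = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(prev_alpha / cumprod_1mp[:, i], dim=1)
        )
        prev_alpha = alpha[:, i]
    return alpha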
def mass_perservation_formula(alpha, left_padding=False, padding_mask=None):
    if padding_mask is None or alpha.size(-1) == 1:
        if alpha.size(-1) > 1:
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1)
        return alpha

    src_lens = (padding_mask.logical_not()).sum(dim=1).long()

    bsz, tgt_len, src_len = alpha.size()

    assert (
        not left_padding
        or (left_padding and (not padding_mask[:, 0].any()))
    )

    alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        if left_padding:
            alpha[bsz_i, :, -1] = (
                1 - alpha[bsz_i, :, :-1].sum(dim=-1)
            )
        else:
            # Move the residual probability mass onto the last non-padding
            # source position so that every row of alpha sums to one.
            alpha[bsz_i, :, src_lens[bsz_i] - 1] = (
                1 - alpha[bsz_i, :, :src_lens[bsz_i] - 1].sum(dim=-1)
            )

    return alpha


def expected_soft_attention_formula(
    alpha,
    soft_energy,
    padding_mask=None,
    chunksize=int(1e10),
):
    # Brute-force reference for the expected (soft) attention.
    # Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    # https://arxiv.org/pdf/1906.05218.pdf
    # Eq. 14

    # Monotonic Chunkwise Attention
    # https://arxiv.org/abs/1712.05382
    # Eq. 17
    bsz, tgt_len, src_len = alpha.size()
    beta = torch.zeros_like(alpha)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)
        # Expanding for potential head dimension.
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), float("-inf")
        )

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                for k in range(j, min([src_len, j + chunksize])):
                    if padding_mask is None or not padding_mask[bsz_i, j]:
                        beta[bsz_i, i, j] += (
                            alpha[bsz_i, i, k]
                            * torch.exp(soft_energy[bsz_i, i, j])
                            / torch.sum(
                                torch.exp(
                                    soft_energy[
                                        bsz_i, i,
                                        max([0, k - chunksize + 1]):k + 1
                                    ]
                                )
                            )
                        )
    return beta
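
# For the infinite-lookback case (unbounded chunk) without padding, Eq. 14
# above reduces to
#     beta[i, j] = exp(u[i, j]) * sum_{k >= j} alpha[i, k] / sum_{l <= k} exp(u[i, l]).
# The helper below is a minimal vectorized sketch of that identity; the name
# `vectorized_infinite_lookback_reference` is ours and it is not used by the
# tests.
def vectorized_infinite_lookback_reference(alpha, soft_energy):
    # Subtract the per-row max before exponentiating; the shift cancels in
    # the ratio, so this only improves numerical stability.
    stable_energy = soft_energy - soft_energy.max(dim=-1, keepdim=True).values
    exp_u = torch.exp(stable_energy)
    inner = alpha / torch.cumsum(exp_u, dim=-1)
    # A reversed cumulative sum over the source dimension implements sum_{k >= j}.
    rev_cumsum = torch.flip(
        torch.cumsum(torch.flip(inner, dims=[-1]), dim=-1), dims=[-1]
    )
    return exp_u * rev_cumsum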
class MonotonicAttentionTestAbstractClass(object):
    def test_forward(self):
        sample = make_sample_with_padding()
        out, _ = self.model.forward(**sample["net_input"])
        loss = out.sum()
        loss.backward()

    def test_p_choose(self):
        sample = make_sample_with_padding()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            p_choose = item["p_choose"]
            self.assertTrue(p_choose.le(1.0).all())
            self.assertTrue(p_choose.ge(0.0).all())

    def test_expected_alignment(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                self.assertTrue(p_choose.size() == alpha_system.size())
                bsz, num_head, tgt_len, src_len = alpha_system.size()
                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                self.assertTrue(
                    torch.abs(alpha_system - alpha_real).le(5e-5).all(),
                )


class HardMonotonicAttentionTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config({"simul_type": "hard_aligned"})
        )


class InfiniteLookbackTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "infinite_lookback"
                }
            )
        )
        self.model.train()

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_fp16_for_long_input(self):
        sample = {
            "net_input": {
                "src_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "prev_output_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "src_lengths": torch.LongTensor([1000]).cuda(),
            },
            "target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda()
        }
        self.model.cuda().half()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            for key in ["p_choose", "alpha", "beta", "soft_energy"]:
                self.assertFalse(torch.isnan(item[key]).any())

    def test_expected_attention(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                beta_system = item["beta"]
                soft_energy_system = item["soft_energy"]
                self.assertTrue(beta_system.size() == alpha_system.size())
                self.assertTrue(p_choose.size() == alpha_system.size())
                bsz, num_head, tgt_len, src_len = alpha_system.size()
                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                beta_system = beta_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)
                soft_energy_system = soft_energy_system.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                beta_real = expected_soft_attention_formula(
                    alpha_real,
                    soft_energy_system,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX),
                    chunksize=getattr(
                        self.model.decoder.layers[0].encoder_attn,
                        "chunk_size",
                        int(1e10)
                    )
                )

                self.assertTrue(
                    torch.abs(beta_system - beta_real).le(1e-5).all(),
                )


class ChunkwiswTestCase(
    InfiniteLookbackTestCase
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "chunkwise",
                    "mocha_chunk_size": 3
                }
            )
        )


class WaitkTestCase(InfiniteLookbackTestCase):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "waitk",
                    "waitk_lagging": 3,
                }
            )
        )

    def check_waitk(self, p_choose, lagging, padding_mask):
        # A wait-k model reads exactly `lagging` source tokens before writing
        # each target token, so p_choose should be a shifted diagonal.
        bsz, tgt_len, src_len = p_choose.size()
        for bsz_i in range(bsz):
            for i in range(tgt_len):
                for j in range(src_len):
                    if not padding_mask[bsz_i, j]:
                        if j - i == lagging - 1:
                            self.assertTrue(p_choose[bsz_i, i, j] == 1)
                        else:
                            self.assertTrue(p_choose[bsz_i, i, j] == 0)

    def test_waitk_p_choose(self):
        for longer_src in [True, False]:
            for k in [1, 3, 10, 20, 100]:
                sample = make_sample_with_padding(longer_src)
                model = build_transformer_monotonic_attention(
                    **generate_config(
                        {
                            "simul_type": "waitk",
                            "waitk_lagging": k,
                        }
                    )
                )
                model.train()
                _, extra_out = model.forward(**sample["net_input"])
                for item in extra_out.attn_list:
                    p_choose = item["p_choose"]
                    bsz, num_heads, tgt_len, src_len = p_choose.size()
                    padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                    padding_mask = (
                        padding_mask
                        .unsqueeze(1)
                        .expand([bsz, num_heads, src_len])
                        .contiguous()
                        .view(-1, src_len)
                    )
                    p_choose = p_choose.view(bsz * num_heads, tgt_len, src_len)
                    self.check_waitk(p_choose, k, padding_mask)
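
if __name__ == "__main__":
    # Convenience entry point so the module can be run directly; this is an
    # optional addition, assuming the suite is normally discovered and run
    # via pytest or `python -m unittest`.
    unittest.main()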