import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch

from fairseq.models import transformer
from tests.test_roberta import FakeTask


def mk_sample(
    tok: Optional[Sequence[int]] = None, batch_size: int = 2
) -> Dict[str, Any]:
    """Build a small dummy batch in the format expected by fairseq models."""
    if tok is None:
        tok = [10, 11, 12, 13, 14, 15, 2]
    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_transformer(**extra_args: Any):
    """Build a tiny TransformerModel, with hyperparameters overridable via kwargs."""
    overrides = {
        # Use small, distinct dimensions so shape mismatches are easy to spot.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so the tests are deterministic and comparable.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    # Override the defaults from the argument parser.
    args = argparse.Namespace(**overrides)
    transformer.tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return transformer.TransformerModel.build_model(args, task)


class TransformerTestCase(unittest.TestCase):
    def test_forward_backward(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

    def test_different_encoder_decoder_embed_dim(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()
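

# Minimal sketch of a standard unittest entry point so the module can also be
# run directly with `python <path-to-this-test-file>`; assumes the default
# unittest runner (the file's actual path and runner setup are not specified
# in the original).
if __name__ == "__main__":
    unittest.main()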