import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch
from fairseq.models import transformer

from tests.test_roberta import FakeTask


def mk_sample(
    tok: Optional[Sequence[int]] = None, batch_size: int = 2
) -> Dict[str, Any]:
    """Build a dummy batch of `batch_size` copies of the token sequence `tok`."""
    if tok is None:
        tok = [10, 11, 12, 13, 14, 15, 2]
    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_transformer(**extra_args: Any):
    overrides = {
        # Keep the model tiny so the test builds and runs quickly.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable all dropout so forward passes are deterministic.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)

    args = argparse.Namespace(**overrides)
    transformer.tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return transformer.TransformerModel.build_model(args, task)


class TransformerTestCase(unittest.TestCase):
    def test_forward_backward(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

    def test_different_encoder_decoder_embed_dim(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()
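

# A minimal convenience entry point (an addition, not part of the tests above):
# it lets this file be run directly, e.g. `python tests/test_transformer.py`,
# in addition to being collected by a test runner such as pytest.
if __name__ == "__main__":
    unittest.main()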