#!/usr/bin/env python3
"""Shared fixtures and base ``unittest.TestCase`` classes for fairseq tests.

The module provides:

* factories for a dummy dictionary, task, parser, forward input and encoder
  output;
* validators for the structure of encoder / decoder outputs;
* abstract test-case bases (skipped when run directly) that concrete tests
  derive from to exercise models, encoders, decoders and criteria.
"""

import argparse
import os
import unittest
from inspect import currentframe, getframeinfo

import numpy as np
import torch
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.dictionary import Dictionary
from fairseq.models import (
    BaseFairseqModel,
    FairseqDecoder,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqEncoderModel,
    FairseqModel,
)
from fairseq.tasks.fairseq_task import LegacyFairseqTask


DEFAULT_TEST_VOCAB_SIZE = 100


# ///////////////////////////////////////////////////////////////////////////
# utility function to setup dummy dict/task/input
# ///////////////////////////////////////////////////////////////////////////


def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    """Return a ``Dictionary`` with ``vocab_size`` dummy symbols added.

    The symbols are the decimal strings ``"0"`` .. ``str(vocab_size - 1)``,
    each added with a count of 1000.
    """
    dummy_dict = Dictionary()
    # add dummy symbols to satisfy the requested vocab size
    for index in range(vocab_size):
        dummy_dict.add_symbol(str(index), 1000)
    return dummy_dict


class DummyTask(LegacyFairseqTask):
    """Minimal task exposing a dummy dictionary as its target dictionary."""

    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            # NOTE(review): the extra symbol added for CTC is an empty string
            # here -- confirm this is the intended blank symbol.
            self.dictionary.add_symbol("")

        self.tgt_dict = self.dictionary

    @property
    def target_dictionary(self):
        """The dummy dictionary created in ``__init__``."""
        return self.dictionary


def get_dummy_task_and_parser():
    """Create a dummy task and argument parser.

    To build a fairseq model we need a parser and a task. This function
    creates both to facilitate model/criterion tests.

    Note: ``DummyTask`` is used as the task. You may want to use another task
    by providing another function.

    Returns:
        A ``(task, parser)`` tuple. The parser has ``DummyTask``'s arguments
        registered and suppresses defaults for arguments that are not given.
    """
    parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
    )
    DummyTask.add_args(parser)
    args = parser.parse_args([])
    task = DummyTask.setup_task(args)
    return task, parser


def get_dummy_input(T=100, D=80, B=5, K=100):
    """Build a random forward-input dict for a sequence-to-sequence model.

    Args:
        T: max sequence length.
        D: feature vector dimension.
        B: batch size.
        K: target dimension size (token ids are drawn from ``[0, K)``).

    Returns:
        A dict with keys ``"src_tokens"`` (float tensor of shape ``(B, T, D)``),
        ``"src_lengths"`` (int64 tensor of length ``B``, sorted in descending
        order) and ``"prev_output_tokens"`` (right-padded token tensor).
    """
    forward_input = {}
    feature = torch.randn(B, T, D)
    # this (B, T, D) layout is just a convention, you can override it by
    # writing your own _prepare_forward_input function
    src_lengths = torch.from_numpy(
        np.random.randint(low=1, high=T, size=B, dtype=np.int64)
    )
    src_lengths[0] = T  # make sure the maximum length matches

    prev_output_tokens = []
    for b in range(B):
        # each target is at most as long as its source sequence
        token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
        tokens = np.random.randint(low=0, high=K, size=token_length, dtype=np.int64)
        prev_output_tokens.append(torch.from_numpy(tokens))

    prev_output_tokens = fairseq_data_utils.collate_tokens(
        prev_output_tokens,
        pad_idx=1,
        eos_idx=2,
        left_pad=False,
        move_eos_to_beginning=False,
    )

    # NOTE(review): only the source side is reordered by length below;
    # prev_output_tokens stays in its original order. Harmless for random
    # data, but confirm if row alignment ever matters.
    src_lengths, sorted_order = src_lengths.sort(descending=True)
    forward_input["src_tokens"] = feature.index_select(0, sorted_order)
    forward_input["src_lengths"] = src_lengths
    forward_input["prev_output_tokens"] = prev_output_tokens

    return forward_input


def get_dummy_encoder_output(encoder_out_shape=(100, 80, 5)):
    """Generate an example dummy encoder output.

    Args:
        encoder_out_shape: a 3-tuple ``(T, B, D)``.

    Returns:
        A dict with ``"encoder_out"`` (random float32 tensor of the given
        shape) and ``"encoder_padding_mask"`` (a ``(T, B)`` boolean tensor).
    """
    T, B, _ = encoder_out_shape
    encoder_out = {}

    encoder_out["encoder_out"] = torch.from_numpy(
        np.random.randn(*encoder_out_shape).astype(np.float32)
    )
    seq_lengths = torch.from_numpy(np.random.randint(low=1, high=T, size=B))

    # some dummy mask: position t of sequence b is padding when t >= length[b].
    # Broadcasting (1, T) against (B, 1) yields the (B, T) mask.
    batch_first_mask = torch.arange(T).view(1, T) >= seq_lengths.view(B, 1)

    # encoder_padding_mask is a (T, B) tensor, with the (t, b)-th element
    # indicating whether encoder_out[t, b] is valid (=0) or not (=1)
    encoder_out["encoder_padding_mask"] = batch_first_mask.t()
    return encoder_out


def _current_postion_info():
    """Return a ``" (at <file>:<line>)"`` suffix describing the caller.

    Both the file name and the line number are taken from the caller's frame.
    An empty string is returned when frame introspection is unavailable
    (``inspect.currentframe()`` may return ``None``).
    """
    frame = currentframe()
    caller = frame.f_back if frame is not None else None
    if caller is None:
        return ""
    return " (at {}:{})".format(
        os.path.basename(getframeinfo(caller).filename), caller.f_lineno
    )


def check_encoder_output(encoder_output, batch_size=None):
    """Validate the structure of an encoder output.

    We expect ``encoder_output`` to be a dict with the following key/value
    pairs:

    - ``encoder_out``: a float32 ``torch.Tensor``
    - ``encoder_padding_mask``: ``None`` or a 2-d uint8/bool ``torch.Tensor``
      in shape ``(T, B)``

    Args:
        encoder_output: the object to validate.
        batch_size: when given, ``encoder_padding_mask.size(1)`` must equal it.

    Returns:
        ``(True, None)`` when valid, otherwise ``(False, message)``.
    """
    # _current_postion_info() is called inline at every failure site so the
    # reported line number identifies the check that failed.
    if not isinstance(encoder_output, dict):
        msg = (
            "FairseqEncoderModel.forward(...) must be a dict" + _current_postion_info()
        )
        return False, msg

    if "encoder_out" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_out"
            + _current_postion_info()
        )
        return False, msg

    if "encoder_padding_mask" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_padding_mask"
            + _current_postion_info()
        )
        return False, msg

    if not isinstance(encoder_output["encoder_out"], torch.Tensor):
        msg = "encoder_out must be a torch.Tensor" + _current_postion_info()
        return False, msg

    if encoder_output["encoder_out"].dtype != torch.float32:
        msg = "encoder_out must have float32 dtype" + _current_postion_info()
        return False, msg

    mask = encoder_output["encoder_padding_mask"]
    if mask is not None:
        if not isinstance(mask, torch.Tensor):
            msg = (
                "encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
            )
            return False, msg
        # torch.bool does not exist on old torch releases, hence the hasattr
        if mask.dtype != torch.uint8 and (
            not hasattr(torch, "bool") or mask.dtype != torch.bool
        ):
            msg = (
                "encoder_padding_mask must have dtype of uint8 or bool"
                + _current_postion_info()
            )
            return False, msg

        if mask.dim() != 2:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, in shape (T, B)"
                + _current_postion_info()
            )
            return False, msg

        if batch_size is not None and mask.size(1) != batch_size:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, with size(1)"
                + " being the batch size"
                + _current_postion_info()
            )
            return False, msg
    return True, None


def check_decoder_output(decoder_output):
    """Validate the structure of a decoder output.

    We expect the output from a decoder to be a tuple with the following
    constraints:

    - it has exactly two elements
    - the first element is a ``torch.Tensor``
    - the second element can be anything (reserved for future use)

    Returns:
        ``(True, None)`` when valid, otherwise ``(False, message)``.
    """
    if not isinstance(decoder_output, tuple):
        msg = "FariseqDecoder output must be a tuple" + _current_postion_info()
        return False, msg

    if len(decoder_output) != 2:
        msg = "FairseqDecoder output must be 2-elem tuple" + _current_postion_info()
        return False, msg

    if not isinstance(decoder_output[0], torch.Tensor):
        msg = (
            "FariseqDecoder output[0] must be a torch.Tensor" + _current_postion_info()
        )
        return False, msg

    return True, None


def _build_model_with_dummy_task(model_cls, extra_args_setters=None):
    """Build ``model_cls`` from default args and the dummy task.

    Each callable in ``extra_args_setters`` (if any) receives the parsed
    ``args`` namespace and may mutate it before the model is built.
    """
    task, parser = get_dummy_task_and_parser()
    model_cls.add_args(parser)
    args = parser.parse_args([])
    for args_setter in extra_args_setters or ():
        args_setter(args)
    return model_cls.build_model(args, task)


# ///////////////////////////////////////////////////////////////////////////
# Base Test class
# ///////////////////////////////////////////////////////////////////////////


class TestBaseFairseqModelBase(unittest.TestCase):
    """
    This class is used to facilitate writing unittest for any class derived from
    `BaseFairseqModel`.
    """

    @classmethod
    def setUpClass(cls):
        # the base itself carries no runnable tests; only subclasses do
        if cls is TestBaseFairseqModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model):
        """Register an already-built model as the model under test."""
        self.assertIsInstance(model, BaseFairseqModel)
        self.model = model

    def setUpInput(self, input=None):
        """Hook for subclasses to prepare ``self.forward_input``; no-op here."""

    # Backward-compatible alias: the hook was originally spelled
    # ``setupInput`` here while every subclass defines ``setUpInput``.
    setupInput = setUpInput

    def setUp(self):
        self.model = None
        self.forward_input = None

    def _check_normalized_probs(self):
        """Shared body of ``test_get_normalized_probs`` for model test bases.

        Does nothing unless both ``self.model`` and ``self.forward_input`` are
        set up.
        """
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
            prob = self.model.get_normalized_probs(forward_output, log_probs=False)

            # in order for different models/criterion to play with each other
            # we need to know whether the logprob or prob output is batch_first
            # or not. We assume an additional attribute will be attached to logprob
            # or prob. If you find your code failed here, simply override
            # FairseqModel.get_normalized_probs, see example at
            # https://fburl.com/batch_first_example
            self.assertTrue(hasattr(logprob, "batch_first"))
            self.assertTrue(hasattr(prob, "batch_first"))

            self.assertTrue(torch.is_tensor(logprob))
            self.assertTrue(torch.is_tensor(prob))


class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
    """
    base code to test FairseqEncoderDecoderModel (formerly known as
    `FairseqModel`); tests must be derived from this base class
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderDecoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        """Build ``model_cls`` with the dummy task and store it in ``self.model``."""
        self.assertTrue(
            issubclass(model_cls, (FairseqEncoderDecoderModel, FairseqModel)),
            msg="This class only tests for FairseqModel subclasses",
        )
        self.model = _build_model_with_dummy_task(model_cls, extra_args_setters)

    def setUpInput(self, input=None):
        """Use ``input`` as the forward input, or a random dummy one if ``None``."""
        self.forward_input = get_dummy_input() if input is None else input

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            # for FairseqEncoderDecoderModel, forward returns a tuple of two
            # elements, the first one is a Torch.Tensor
            succ, msg = check_decoder_output(forward_output)
            self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

    def test_get_normalized_probs(self):
        self._check_normalized_probs()


class TestFairseqEncoderModelBase(TestBaseFairseqModelBase):
    """
    base class to test FairseqEncoderModel
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        """Build ``model_cls`` with the dummy task and store it in ``self.model``."""
        self.assertTrue(
            issubclass(model_cls, FairseqEncoderModel),
            msg="This class is only used for testing FairseqEncoderModel",
        )
        self.model = _build_model_with_dummy_task(model_cls, extra_args_setters)

    def setUpInput(self, input=None):
        """Use ``input`` (or a random dummy one) without the decoder-only entry."""
        self.forward_input = get_dummy_input() if input is None else input
        # get_dummy_input() is originally for s2s, here we delete extra dict
        # items, so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if self.forward_input and self.model:
            bsz = self.forward_input["src_tokens"].size(0)
            forward_output = self.model.forward(**self.forward_input)

            # we expect forward_output to be a dict with the following
            # key/value pairs:
            # - encoder_out: a Torch.Tensor
            # - encoder_padding_mask: a binary Torch.Tensor
            succ, msg = check_encoder_output(forward_output, batch_size=bsz)
            self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

    def test_get_normalized_probs(self):
        self._check_normalized_probs()


class TestFairseqEncoderBase(unittest.TestCase):
    """
    base class to test FairseqEncoder
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpEncoder(self, encoder):
        """Register an already-built encoder as the encoder under test."""
        self.assertTrue(
            isinstance(encoder, FairseqEncoder),
            msg="This class is only used for test FairseqEncoder",
        )
        self.encoder = encoder

    def setUpInput(self, input=None):
        """Use ``input`` (or a random dummy one) without the decoder-only entry."""
        self.forward_input = get_dummy_input() if input is None else input
        # get_dummy_input() is originally for s2s, here we delete extra dict
        # items, so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        self.encoder = None
        self.forward_input = None

    def test_forward(self):
        if self.encoder and self.forward_input:
            bsz = self.forward_input["src_tokens"].size(0)

            forward_output = self.encoder.forward(**self.forward_input)
            succ, msg = check_encoder_output(forward_output, batch_size=bsz)
            self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output


class TestFairseqDecoderBase(unittest.TestCase):
    """
    base class to test FairseqDecoder
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqDecoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpDecoder(self, decoder):
        """Register an already-built decoder as the decoder under test."""
        self.assertTrue(
            isinstance(decoder, FairseqDecoder),
            msg="This class is only used for test FairseqDecoder",
        )
        self.decoder = decoder

    def setUpInput(self, input=None):
        """Use ``input`` as the encoder output fed to the decoder, or a dummy one."""
        self.forward_input = get_dummy_encoder_output() if input is None else input

    def setUpPrevOutputTokens(self, tokens=None):
        """Use ``tokens`` as previous output tokens, or take them from a dummy input."""
        if tokens is None:
            self.encoder_input = get_dummy_input()
            self.prev_output_tokens = self.encoder_input["prev_output_tokens"]
        else:
            self.prev_output_tokens = tokens

    def setUp(self):
        self.decoder = None
        self.forward_input = None
        self.prev_output_tokens = None

    def test_forward(self):
        if (
            self.decoder is not None
            and self.forward_input is not None
            and self.prev_output_tokens is not None
        ):
            forward_output = self.decoder.forward(
                prev_output_tokens=self.prev_output_tokens,
                encoder_out=self.forward_input,
            )
            succ, msg = check_decoder_output(forward_output)
            self.assertTrue(succ, msg=msg)
            # keep the result separate from the input, like the sibling bases
            self.forward_output = forward_output


class DummyEncoderModel(FairseqEncoderModel):
    """Encoder model wrapping ``DummyEncoder``, used by the criterion tests."""

    def __init__(self, encoder):
        super().__init__(encoder)

    @classmethod
    def build_model(cls, args, task):
        """Return a model around a fresh ``DummyEncoder``; args/task are unused."""
        return cls(DummyEncoder())

    def get_logits(self, net_output):
        """Return ``log(p / (1 - p))`` for ``p = net_output["encoder_out"]``."""
        # Inverse of sigmoid to use with BinaryCrossEntropyWithLogitsCriterion as
        # F.binary_cross_entropy_with_logits combines sigmoid and CE
        probs = net_output["encoder_out"]
        return torch.log(torch.div(probs, 1 - probs))

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Delegate to the parent and tag the result with ``batch_first = True``."""
        lprobs = super().get_normalized_probs(net_output, log_probs, sample=sample)
        lprobs.batch_first = True
        return lprobs


class DummyEncoder(FairseqEncoder):
    """Encoder that passes ``src_tokens`` through unchanged as its output."""

    def __init__(self):
        super().__init__(None)

    def forward(self, src_tokens, src_lengths):
        """Return ``src_tokens`` plus a padding mask derived from ``src_lengths``."""
        mask, _ = lengths_to_encoder_padding_mask(src_lengths)
        return {"encoder_out": src_tokens, "encoder_padding_mask": mask}


class CrossEntropyCriterionTestBase(unittest.TestCase):
    """Base class for criterion tests.

    ``setUp`` reads a ``criterion_cls`` attribute that is not defined in this
    class -- presumably subclasses provide it.
    """

    @classmethod
    def setUpClass(cls):
        if cls is CrossEntropyCriterionTestBase:
            raise unittest.SkipTest("Skipping base class test case")
        super().setUpClass()

    def setUpArgs(self):
        """Return the argument namespace used to build the task and criterion."""
        args = argparse.Namespace()
        args.sentence_avg = False
        args.threshold = 0.1  # to use with BinaryCrossEntropyWithLogitsCriterion
        return args

    def setUp(self):
        args = self.setUpArgs()
        self.model = DummyEncoderModel(encoder=DummyEncoder())
        self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))

    def get_src_tokens(self, correct_prediction, aggregate):
        """
        correct_prediction: True if the net_output (src_tokens) should
        predict the correct target
        aggregate: True if the criterion expects net_output (src_tokens)
        aggregated across time axis

        Returns a one-hot float tensor of shape (2, 2) when ``aggregate`` is
        true, else (2, 10, 2); the hot index is 0 for a correct prediction
        and 1 otherwise.
        """
        predicted_idx = 0 if correct_prediction else 1
        shape = (2, 2) if aggregate else (2, 10, 2)
        src_tokens = torch.zeros(shape, dtype=torch.float)
        # set the chosen class to 1.0 along the last axis for every position
        src_tokens[..., predicted_idx] = 1.0
        return src_tokens

    def get_target(self, soft_target):
        """Return the target tensor.

        A (2, 2) float tensor with column 0 set to 1.0 when ``soft_target`` is
        true, otherwise a (2, 10) long tensor of zeros.
        """
        if soft_target:
            target = torch.zeros((2, 2), dtype=torch.float)
            target[:, 0] = 1.0
        else:
            target = torch.zeros((2, 10), dtype=torch.long)
        return target

    def get_test_sample(self, correct, soft_target, aggregate):
        """Assemble a sample dict with ``net_input``, ``target`` and ``ntokens``."""
        src_tokens = self.get_src_tokens(correct, aggregate)
        target = self.get_target(soft_target)
        L = src_tokens.size(1)
        return {
            # NOTE(review): src_lengths holds a single element although
            # src_tokens has a batch of 2 -- confirm this is intended.
            "net_input": {"src_tokens": src_tokens, "src_lengths": torch.tensor([L])},
            "target": target,
            "ntokens": src_tokens.size(0) * src_tokens.size(1),
        }