#!/usr/bin/env python3

import argparse
import os
import unittest
from inspect import currentframe, getframeinfo

import numpy as np
import torch

from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.dictionary import Dictionary
from fairseq.models import (
    BaseFairseqModel,
    FairseqDecoder,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqEncoderModel,
    FairseqModel,
)
from fairseq.tasks.fairseq_task import LegacyFairseqTask

DEFAULT_TEST_VOCAB_SIZE = 100

# ///////////////////////////////////////////////////////////////////////////
# utility functions to set up a dummy dict/task/input
# ///////////////////////////////////////////////////////////////////////////


def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    dummy_dict = Dictionary()
    # add dummy symbols to satisfy the requested vocab size
    for i in range(vocab_size):
        dummy_dict.add_symbol("{}".format(i), 1000)
    return dummy_dict


class DummyTask(LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        self.tgt_dict = self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


def get_dummy_task_and_parser():
    """
    To build a fairseq model, we need a dummy parser and task. This function
    creates a dummy task and parser to facilitate model/criterion tests.

    Note: we use DummyTask as the dummy task here; you may want to use another
    task by providing a similar function for it.
    """
    parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
    )
    DummyTask.add_args(parser)
    args = parser.parse_args([])
    task = DummyTask.setup_task(args)
    return task, parser
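
# Illustrative use of the helper above (a sketch; `MyModel` is a hypothetical
# FairseqEncoderDecoderModel subclass, not something defined in this file):
#
#   task, parser = get_dummy_task_and_parser()
#   MyModel.add_args(parser)
#   args = parser.parse_args([])
#   model = MyModel.build_model(args, task)
#
# This is exactly the sequence the setUpModel() helpers below perform.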


def get_dummy_input(T=100, D=80, B=5, K=100):
    # T: max source sequence length
    # D: feature vector dimension
    # B: batch size
    # K: target vocabulary size
    forward_input = {}
    feature = torch.randn(B, T, D)
    # this (B, T, D) layout is just a convention; you can override it by
    # writing your own _prepare_forward_input function
    src_lengths = torch.from_numpy(
        np.random.randint(low=1, high=T, size=B, dtype=np.int64)
    )
    src_lengths[0] = T  # make sure the maximum length matches
    prev_output_tokens = []
    for b in range(B):
        token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
        tokens = np.random.randint(low=0, high=K, size=token_length, dtype=np.int64)
        prev_output_tokens.append(torch.from_numpy(tokens))
    prev_output_tokens = fairseq_data_utils.collate_tokens(
        prev_output_tokens,
        pad_idx=1,
        eos_idx=2,
        left_pad=False,
        move_eos_to_beginning=False,
    )
    src_lengths, sorted_order = src_lengths.sort(descending=True)
    forward_input["src_tokens"] = feature.index_select(0, sorted_order)
    forward_input["src_lengths"] = src_lengths
    forward_input["prev_output_tokens"] = prev_output_tokens
    return forward_input
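
# Expected shapes of the returned dict (a sketch; values are random):
#   forward_input["src_tokens"]         -> (B, T, D) float tensor
#   forward_input["src_lengths"]        -> (B,) int64 tensor, sorted descending
#   forward_input["prev_output_tokens"] -> (B, max_target_len) int64 tensor,
#                                          right-padded with pad_idx=1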


def get_dummy_encoder_output(encoder_out_shape=(100, 80, 5)):
    """
    This only provides an example of how to generate dummy encoder output
    """
    (T, B, D) = encoder_out_shape
    encoder_out = {}
    encoder_out["encoder_out"] = torch.from_numpy(
        np.random.randn(*encoder_out_shape).astype(np.float32)
    )
    seq_lengths = torch.from_numpy(np.random.randint(low=1, high=T, size=B))
    # some dummy mask
    encoder_out["encoder_padding_mask"] = torch.arange(T).view(1, T).expand(
        B, -1
    ) >= seq_lengths.view(B, 1).expand(-1, T)
    encoder_out["encoder_padding_mask"].t_()
    # encoder_padding_mask is a (T, B) tensor whose (t, b)-th element indicates
    # whether encoder_out[t, b] is valid (False) or padding (True)
    return encoder_out
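
# Example of the mask semantics above (a sketch): with T=4 and a sequence of
# length 2, that sequence's column of encoder_padding_mask reads
# [False, False, True, True], i.e. only the padded timesteps are flagged.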


def _current_position_info():
    cf = currentframe()
    frameinfo = " (at {}:{})".format(
        os.path.basename(getframeinfo(cf).filename), cf.f_back.f_lineno
    )
    return frameinfo


def check_encoder_output(encoder_output, batch_size=None):
    """we expect encoder_output to be a dict with the following
    key/value pairs:
    - encoder_out: a Torch.Tensor
    - encoder_padding_mask: a binary Torch.Tensor
    """
    if not isinstance(encoder_output, dict):
        msg = (
            "FairseqEncoderModel.forward(...) must return a dict"
            + _current_position_info()
        )
        return False, msg
    if "encoder_out" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_out"
            + _current_position_info()
        )
        return False, msg
    if "encoder_padding_mask" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_padding_mask"
            + _current_position_info()
        )
        return False, msg
    if not isinstance(encoder_output["encoder_out"], torch.Tensor):
        msg = "encoder_out must be a torch.Tensor" + _current_position_info()
        return False, msg
    if encoder_output["encoder_out"].dtype != torch.float32:
        msg = "encoder_out must have float32 dtype" + _current_position_info()
        return False, msg
    mask = encoder_output["encoder_padding_mask"]
    if mask is not None:
        if not isinstance(mask, torch.Tensor):
            msg = (
                "encoder_padding_mask must be a torch.Tensor"
                + _current_position_info()
            )
            return False, msg
        if mask.dtype != torch.uint8 and (
            not hasattr(torch, "bool") or mask.dtype != torch.bool
        ):
            msg = (
                "encoder_padding_mask must have dtype of uint8 or bool"
                + _current_position_info()
            )
            return False, msg
        if mask.dim() != 2:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, in shape (T, B)"
                + _current_position_info()
            )
            return False, msg
        if batch_size is not None and mask.size(1) != batch_size:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, with size(1)"
                + " being the batch size"
                + _current_position_info()
            )
            return False, msg
    return True, None
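
# Typical use inside a test method (a sketch):
#   succ, msg = check_encoder_output(model.forward(**forward_input), batch_size=B)
#   self.assertTrue(succ, msg=msg)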


def check_decoder_output(decoder_output):
    """we expect output from a decoder to be a tuple with the following constraints:
    - the first element is a torch.Tensor
    - the second element can be anything (reserved for future use)
    """
    if not isinstance(decoder_output, tuple):
        msg = "FairseqDecoder output must be a tuple" + _current_position_info()
        return False, msg
    if len(decoder_output) != 2:
        msg = (
            "FairseqDecoder output must be a 2-element tuple"
            + _current_position_info()
        )
        return False, msg
    if not isinstance(decoder_output[0], torch.Tensor):
        msg = (
            "FairseqDecoder output[0] must be a torch.Tensor"
            + _current_position_info()
        )
        return False, msg
    return True, None
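
# The same pattern applies to decoder outputs (a sketch):
#   succ, msg = check_decoder_output(model.forward(**forward_input))
#   self.assertTrue(succ, msg=msg)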


# ///////////////////////////////////////////////////////////////////////////
# Base test classes
# ///////////////////////////////////////////////////////////////////////////


class TestBaseFairseqModelBase(unittest.TestCase):
    """
    This class is used to facilitate writing unittests for any class derived
    from `BaseFairseqModel`.
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestBaseFairseqModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model):
        self.assertTrue(isinstance(model, BaseFairseqModel))
        self.model = model

    def setUpInput(self):
        pass

    def setUp(self):
        self.model = None
        self.forward_input = None


class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
    """
    Tests for a FairseqEncoderDecoderModel (formerly known as `FairseqModel`)
    must be derived from this base class.
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderDecoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        self.assertTrue(
            issubclass(model_cls, (FairseqEncoderDecoderModel, FairseqModel)),
            msg="This class only tests for FairseqModel subclasses",
        )
        task, parser = get_dummy_task_and_parser()
        model_cls.add_args(parser)
        args = parser.parse_args([])
        if extra_args_setters is not None:
            for args_setter in extra_args_setters:
                args_setter(args)
        model = model_cls.build_model(args, task)
        self.model = model

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_input() if input is None else input

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            # for FairseqEncoderDecoderModel, forward returns a tuple of two
            # elements, the first one being a torch.Tensor
            succ, msg = check_decoder_output(forward_output)
            self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

    def test_get_normalized_probs(self):
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
            prob = self.model.get_normalized_probs(forward_output, log_probs=False)

            # in order for different models/criterions to play with each other,
            # we need to know whether the logprob or prob output is batch_first
            # or not. We assume an additional attribute will be attached to the
            # logprob or prob. If you find your code fails here, simply override
            # FairseqModel.get_normalized_probs, see the example at
            # https://fburl.com/batch_first_example
            self.assertTrue(hasattr(logprob, "batch_first"))
            self.assertTrue(hasattr(prob, "batch_first"))

            self.assertTrue(torch.is_tensor(logprob))
            self.assertTrue(torch.is_tensor(prob))
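
# Sketch of a concrete test case derived from the base class above
# (`MyModel` is hypothetical; fairseq's own speech_recognition tests follow
# this pattern):
#
#   class TestMyModel(TestFairseqEncoderDecoderModelBase):
#       def setUp(self):
#           super().setUp()
#           self.setUpModel(MyModel)
#           self.setUpInput()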


class TestFairseqEncoderModelBase(TestBaseFairseqModelBase):
    """
    Base class to test FairseqEncoderModel.
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        self.assertTrue(
            issubclass(model_cls, FairseqEncoderModel),
            msg="This class is only used for testing FairseqEncoderModel",
        )
        task, parser = get_dummy_task_and_parser()
        model_cls.add_args(parser)
        args = parser.parse_args([])
        if extra_args_setters is not None:
            for args_setter in extra_args_setters:
                args_setter(args)
        model = model_cls.build_model(args, task)
        self.model = model

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_input() if input is None else input
        # get_dummy_input() is originally for s2s; here we delete the extra dict
        # item so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if self.forward_input and self.model:
            bsz = self.forward_input["src_tokens"].size(0)
            forward_output = self.model.forward(**self.forward_input)

            # we expect forward_output to be a dict with the following
            # key/value pairs:
            # - encoder_out: a Torch.Tensor
            # - encoder_padding_mask: a binary Torch.Tensor
            succ, msg = check_encoder_output(forward_output, batch_size=bsz)
            self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

    def test_get_normalized_probs(self):
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
            prob = self.model.get_normalized_probs(forward_output, log_probs=False)

            # in order for different models/criterions to play with each other,
            # we need to know whether the logprob or prob output is batch_first
            # or not. We assume an additional attribute will be attached to the
            # logprob or prob. If you find your code fails here, simply override
            # FairseqModel.get_normalized_probs, see the example at
            # https://fburl.com/batch_first_example
            self.assertTrue(hasattr(logprob, "batch_first"))
            self.assertTrue(hasattr(prob, "batch_first"))

            self.assertTrue(torch.is_tensor(logprob))
            self.assertTrue(torch.is_tensor(prob))


class TestFairseqEncoderBase(unittest.TestCase):
    """
    Base class to test FairseqEncoder.
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpEncoder(self, encoder):
        self.assertTrue(
            isinstance(encoder, FairseqEncoder),
            msg="This class is only used to test FairseqEncoder",
        )
        self.encoder = encoder

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_input() if input is None else input
        # get_dummy_input() is originally for s2s; here we delete the extra dict
        # item so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        self.encoder = None
        self.forward_input = None

    def test_forward(self):
        if self.encoder and self.forward_input:
            bsz = self.forward_input["src_tokens"].size(0)
            forward_output = self.encoder.forward(**self.forward_input)
            succ, msg = check_encoder_output(forward_output, batch_size=bsz)
            self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output


class TestFairseqDecoderBase(unittest.TestCase):
    """
    Base class to test FairseqDecoder.
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqDecoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpDecoder(self, decoder):
        self.assertTrue(
            isinstance(decoder, FairseqDecoder),
            msg="This class is only used to test FairseqDecoder",
        )
        self.decoder = decoder

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_encoder_output() if input is None else input

    def setUpPrevOutputTokens(self, tokens=None):
        if tokens is None:
            self.encoder_input = get_dummy_input()
            self.prev_output_tokens = self.encoder_input["prev_output_tokens"]
        else:
            self.prev_output_tokens = tokens

    def setUp(self):
        self.decoder = None
        self.forward_input = None
        self.prev_output_tokens = None

    def test_forward(self):
        if (
            self.decoder is not None
            and self.forward_input is not None
            and self.prev_output_tokens is not None
        ):
            forward_output = self.decoder.forward(
                prev_output_tokens=self.prev_output_tokens,
                encoder_out=self.forward_input,
            )
            succ, msg = check_decoder_output(forward_output)
            self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output


class DummyEncoderModel(FairseqEncoderModel):
    def __init__(self, encoder):
        super().__init__(encoder)

    @classmethod
    def build_model(cls, args, task):
        return cls(DummyEncoder())

    def get_logits(self, net_output):
        # Inverse of sigmoid, to use with BinaryCrossEntropyWithLogitsCriterion, as
        # F.binary_cross_entropy_with_logits combines sigmoid and CE
        return torch.log(
            torch.div(net_output["encoder_out"], 1 - net_output["encoder_out"])
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        lprobs = super().get_normalized_probs(net_output, log_probs, sample=sample)
        lprobs.batch_first = True
        return lprobs


class DummyEncoder(FairseqEncoder):
    def __init__(self):
        super().__init__(None)

    def forward(self, src_tokens, src_lengths):
        mask, _ = lengths_to_encoder_padding_mask(src_lengths)
        return {"encoder_out": src_tokens, "encoder_padding_mask": mask}


class CrossEntropyCriterionTestBase(unittest.TestCase):
    # subclasses must set criterion_cls to the criterion class under test
    criterion_cls = None

    @classmethod
    def setUpClass(cls):
        if cls is CrossEntropyCriterionTestBase:
            raise unittest.SkipTest("Skipping base class test case")
        super().setUpClass()

    def setUpArgs(self):
        args = argparse.Namespace()
        args.sentence_avg = False
        args.threshold = 0.1  # to use with BinaryCrossEntropyWithLogitsCriterion
        return args

    def setUp(self):
        args = self.setUpArgs()
        self.model = DummyEncoderModel(encoder=DummyEncoder())
        self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))

    def get_src_tokens(self, correct_prediction, aggregate):
        """
        correct_prediction: True if the net_output (src_tokens) should
        predict the correct target
        aggregate: True if the criterion expects net_output (src_tokens)
        aggregated across the time axis
        """
        predicted_idx = 0 if correct_prediction else 1
        if aggregate:
            src_tokens = torch.zeros((2, 2), dtype=torch.float)
            for b in range(2):
                src_tokens[b][predicted_idx] = 1.0
        else:
            src_tokens = torch.zeros((2, 10, 2), dtype=torch.float)
            for b in range(2):
                for t in range(10):
                    src_tokens[b][t][predicted_idx] = 1.0
        return src_tokens

    def get_target(self, soft_target):
        if soft_target:
            target = torch.zeros((2, 2), dtype=torch.float)
            for b in range(2):
                target[b][0] = 1.0
        else:
            target = torch.zeros((2, 10), dtype=torch.long)
        return target

    def get_test_sample(self, correct, soft_target, aggregate):
        src_tokens = self.get_src_tokens(correct, aggregate)
        target = self.get_target(soft_target)
        L = src_tokens.size(1)
        return {
            "net_input": {"src_tokens": src_tokens, "src_lengths": torch.tensor([L])},
            "target": target,
            "ntokens": src_tokens.size(0) * src_tokens.size(1),
        }
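

if __name__ == "__main__":
    # Minimal smoke check of the dummy-input helpers above (a sketch; this is
    # not part of the unittest flow, and the small shapes below are assumptions
    # made only for this demonstration):
    batch = get_dummy_input(T=20, D=8, B=3, K=10)
    assert batch["src_tokens"].shape == (3, 20, 8)  # (B, T, D) convention
    assert batch["src_lengths"].max().item() == 20  # max length pinned to T
    enc = get_dummy_encoder_output(encoder_out_shape=(20, 3, 8))
    assert enc["encoder_padding_mask"].shape == (20, 3)  # (T, B) after t_()
    print("dummy helpers OK")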