# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for building an ``argparse`` parser from fairseq dataclasses.

The dataclasses below are small fixtures: ``A`` is a leaf config, ``B`` and
``D`` each nest an ``A``, and ``C`` nests both a ``D`` and an ``A``. The test
cases pass instances of them to ``gen_parser_from_dataclass`` and check the
attribute names and values found on the parsed namespace.
"""

import unittest
from argparse import ArgumentParser
from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass


@dataclass
class A(FairseqDataclass):
    """Leaf fixture config with a string field and an integer field."""

    data: str = field(default="test", metadata={"help": "the data input"})
    num_layers: int = field(default=200, metadata={"help": "more layers is better?"})


@dataclass
class B(FairseqDataclass):
    """Fixture config nesting a single ``A`` under ``bar``."""

    # default_factory (rather than default=A()) gives each instance its own A
    # and avoids the ValueError that dataclasses raises on Python 3.11+ for
    # unhashable default values.
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})


@dataclass
class D(FairseqDataclass):
    """Fixture config nesting a single ``A`` under ``arch``."""

    # See B.bar for why default_factory is used here.
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})


@dataclass
class C(FairseqDataclass):
    """Root fixture config with two levels of nesting (``encoder.arch``)."""

    data: str = field(default="test", metadata={"help": "root level data input"})
    # default_factory keeps nested defaults per-instance and is required for
    # unhashable dataclass defaults on Python 3.11+.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})


class TestDataclassUtils(unittest.TestCase):
    """Checks the argparse arguments generated from the fixture dataclasses."""

    # NOTE(review): the third positional argument passed to
    # gen_parser_from_dataclass below is presumably ``delete_default`` and the
    # fourth ``with_prefix`` -- confirm against fairseq.dataclass.utils.

    def test_argparse_convert_basic(self):
        """A flat dataclass yields ``--num-layers`` and a positional ``data``."""
        parser = ArgumentParser()
        gen_parser_from_dataclass(parser, A(), True)

        args = parser.parse_args(["--num-layers", "10", "the/data/path"])

        # The option value is expected back as an int, not the raw string.
        self.assertEqual(args.num_layers, 10)
        self.assertEqual(args.data, "the/data/path")

    def test_argparse_recursive(self):
        """Without a prefix, nested fields are expected under their bare names."""
        parser = ArgumentParser()
        gen_parser_from_dataclass(parser, B(), True)

        args = parser.parse_args(
            ["--num-layers", "10", "--foo", "10", "the/data/path"]
        )

        # num_layers and data come from the nested A; foo belongs to B itself.
        self.assertEqual(args.num_layers, 10)
        self.assertEqual(args.foo, 10)
        self.assertEqual(args.data, "the/data/path")

    def test_argparse_recursive_prefixing(self):
        """With an empty prefix, nested fields are expected under path-prefixed names."""
        self.maxDiff = None  # show full diffs if an assertion fails
        parser = ArgumentParser()
        gen_parser_from_dataclass(parser, C(), True, "")

        args = parser.parse_args(
            [
                "--encoder-arch-data",
                "ENCODER_ARCH_DATA",
                "--encoder-arch-num-layers",
                "10",
                "--encoder-foo",
                "10",
                "--decoder-data",
                "DECODER_DATA",
                "--decoder-num-layers",
                "10",
                "--lr",
                "10",
                "the/data/path",
            ]
        )

        # Fields two levels deep carry both parent names as a prefix.
        self.assertEqual(args.encoder_arch_data, "ENCODER_ARCH_DATA")
        self.assertEqual(args.encoder_arch_num_layers, 10)
        self.assertEqual(args.encoder_foo, 10)
        self.assertEqual(args.decoder_data, "DECODER_DATA")
        self.assertEqual(args.decoder_num_layers, 10)
        # Root-level fields keep their unprefixed names; data stays positional.
        self.assertEqual(args.lr, 10)
        self.assertEqual(args.data, "the/data/path")


if __name__ == "__main__":
    unittest.main()