import unittest

from dataset import SUPPORTED_SUMM_DATASETS, list_all_datasets
from dataset.st_dataset import SummDataset, SummInstance
from dataset.dataset_loaders import ArxivDataset
from helpers import print_with_color


class TestDatasets(unittest.TestCase):
    """Smoke tests that instantiate each supported dataset and check one
    instance from every split it provides."""

    def _test_instance(
        self,
        ins: SummInstance,
        is_query: bool = False,
        is_multi_document: bool = False,
        is_dialogue: bool = False,
    ):
        """Check the field types of a single instance.

        :param ins: the instance to check.
        :param is_query: when true, ``ins.query`` must be a ``str``.
        :param is_multi_document: when true, ``ins.source`` must be a ``list``.
        :param is_dialogue: when true, ``ins.source`` must be a ``list``.
        """
        if is_multi_document or is_dialogue:
            self.assertIsInstance(ins.source, list)
        else:
            # Otherwise the source may be either a list or a plain string.
            self.assertIsInstance(ins.source, (list, str))
        if is_query:
            self.assertIsInstance(ins.query, str)

    def _check_split(self, ds_cls, ds: SummDataset, split_name: str, split) -> None:
        """Materialize one split, report its size and check its first instance.

        :param ds_cls: the dataset class, used only in printed/failure messages.
        :param ds: the dataset object the split belongs to.
        :param split_name: label used in messages ("training", "validation", "test").
        :param split: the split itself; ``None`` means the split is absent and is skipped.
        """
        if split is None:
            return

        # The split is turned into a list so that it can be counted and indexed.
        examples = list(split)
        print(f"{ds_cls} has a {split_name} set of {len(examples)} examples")

        # Fail with a readable message instead of an IndexError on an empty split.
        self.assertTrue(examples, f"{ds_cls} has an empty {split_name} set")

        # NOTE(review): is_query is never passed, so the query check in
        # _test_instance is not exercised from here -- confirm whether the
        # dataset exposes a query-based flag that should be forwarded.
        self._test_instance(
            examples[0],
            is_multi_document=ds.is_multi_document,
            is_dialogue=ds.is_dialogue_based,
        )

    def test_all_datasets(self):
        """Instantiate every class in SUPPORTED_SUMM_DATASETS (except the
        skipped ones) and check the first instance of each available split."""
        print_with_color(f"{'#' * 10} Testing all datasets... {'#' * 10}\n\n", "35")
        print(list_all_datasets())

        num_datasets = 0
        for ds_cls in SUPPORTED_SUMM_DATASETS:
            # TODO: Temporarily skipping Arxiv (size/time), > 30min download time for Travis-CI
            if ds_cls in [ArxivDataset]:
                continue

            print_with_color(f"Testing {ds_cls} dataset...", "35")
            ds: SummDataset = ds_cls()
            ds.show_description()

            splits = (
                ("training", ds.train_set),
                ("validation", ds.validation_set),
                ("test", ds.test_set),
            )

            # must have at least one of train/dev/test set
            # (a unittest assertion is used because bare `assert` is removed under -O)
            self.assertTrue(
                any(split is not None for _, split in splits),
                f"{ds_cls} provides no train/validation/test set",
            )

            for split_name, split in splits:
                self._check_split(ds_cls, ds, split_name, split)

            print_with_color(f"{ds.dataset_name} dataset test complete\n", "32")
            num_datasets += 1

        print_with_color(
            f"{'#' * 10} test_all_datasets {__name__} "
            f"complete ({num_datasets} datasets) {'#' * 10}",
            "32",
        )


if __name__ == "__main__":
    unittest.main()