Spaces:
Build error
Build error
import unittest | |
from dataset import SUPPORTED_SUMM_DATASETS, list_all_datasets | |
from dataset.st_dataset import SummDataset, SummInstance | |
from dataset.dataset_loaders import ArxivDataset | |
from helpers import print_with_color | |
class TestDatasets(unittest.TestCase):
    """Smoke tests for every dataset class in ``SUPPORTED_SUMM_DATASETS``.

    Each dataset (except the ones skipped explicitly) is instantiated, its
    description is shown, and the first instance of every available split is
    type-checked.
    """

    def _test_instance(
        self,
        ins: SummInstance,
        is_query: bool = False,
        is_multi_document: bool = False,
        is_dialogue: bool = False,
    ):
        """Check that the fields of one instance have the expected types.

        Args:
            ins: Instance to inspect; only ``source`` and, when ``is_query``
                is set, ``query`` are read.
            is_query: When true, ``ins.query`` must be a ``str``.
            is_multi_document: When true, ``ins.source`` must be a ``list``.
            is_dialogue: When true, ``ins.source`` must be a ``list``.
        """
        if is_multi_document or is_dialogue:
            self.assertIsInstance(ins.source, list)
        else:
            self.assertIsInstance(ins.source, (list, str))
        if is_query:
            self.assertIsInstance(ins.query, str)

    def _check_split(self, ds_cls, ds: SummDataset, split_name: str, split) -> None:
        """Report the size of one split and type-check its first instance.

        Args:
            ds_cls: Dataset class, used only in the printed/failure messages.
            ds: Dataset instance whose flags select the expected types.
            split_name: Human-readable split label ("training", ...).
            split: Iterable of instances, or ``None`` when the dataset does
                not provide this split (in which case nothing is checked).
        """
        if split is None:
            return

        # The split is materialized so that its size can be reported.
        examples = list(split)
        print(f"{ds_cls} has a {split_name} set of {len(examples)} examples")

        # Fail with a readable message instead of an IndexError on [0].
        self.assertTrue(examples, f"{ds_cls} has an empty {split_name} set")

        self._test_instance(
            examples[0],
            # NOTE(review): `is_query_based` is assumed to be the query
            # counterpart of `is_dialogue_based`; when the attribute is
            # absent the query check is skipped. Confirm against SummDataset.
            is_query=getattr(ds, "is_query_based", False),
            is_multi_document=ds.is_multi_document,
            is_dialogue=ds.is_dialogue_based,
        )

    def test_all_datasets(self):
        """Instantiate each supported dataset and validate its splits."""
        print_with_color(f"{'#' * 10} Testing all datasets... {'#' * 10}\n\n", "35")
        print(list_all_datasets())

        num_datasets = 0
        for ds_cls in SUPPORTED_SUMM_DATASETS:
            # TODO: Temporarily skipping Arxiv (size/time), > 30min download time for Travis-CI
            if ds_cls in (ArxivDataset,):
                continue

            print_with_color(f"Testing {ds_cls} dataset...", "35")
            ds: SummDataset = ds_cls()
            ds.show_description()

            splits = (
                ("training", ds.train_set),
                ("validation", ds.validation_set),
                ("test", ds.test_set),
            )

            # must have at least one of train/dev/test set
            # (unittest assertion: a bare `assert` is stripped under `-O`)
            self.assertTrue(
                any(split is not None for _, split in splits),
                f"{ds_cls} provides no train/validation/test set",
            )

            for split_name, split in splits:
                self._check_split(ds_cls, ds, split_name, split)

            print_with_color(f"{ds.dataset_name} dataset test complete\n", "32")
            num_datasets += 1

        print_with_color(
            f"{'#' * 10} test_all_datasets {__name__} complete ({num_datasets} datasets) {'#' * 10}",
            "32",
        )
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()