Spaces:
Build error
File size: 2,830 Bytes
7e3e85d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import unittest
from dataset import SUPPORTED_SUMM_DATASETS, list_all_datasets
from dataset.st_dataset import SummDataset, SummInstance
from dataset.dataset_loaders import ArxivDataset
from helpers import print_with_color
class TestDatasets(unittest.TestCase):
    """Smoke tests: every supported summarization dataset loads and has well-formed splits."""

    def _test_instance(
        self,
        ins: SummInstance,
        is_query: bool = False,
        is_multi_document: bool = False,
        is_dialogue: bool = False,
    ):
        """Check the basic shape of a single dataset instance.

        Args:
            ins: The instance to check.
            is_query: If True, ``ins.query`` must be a ``str``.
            is_multi_document: If True, ``ins.source`` must be a ``list``.
            is_dialogue: If True, ``ins.source`` must be a ``list``.

        When neither ``is_multi_document`` nor ``is_dialogue`` is set,
        ``ins.source`` may be either a ``list`` or a ``str``.
        """
        if is_multi_document or is_dialogue:
            self.assertIsInstance(ins.source, list)
        else:
            self.assertIsInstance(ins.source, (list, str))

        if is_query:
            self.assertIsInstance(ins.query, str)

    def _check_split(self, ds_cls, ds: SummDataset, split, split_label: str):
        """Materialize one split, report its size, and sanity-check its first instance.

        Args:
            ds_cls: The dataset class (used only in messages).
            ds: The dataset instance; supplies the type flags for the check.
            split: The split's iterable of instances, e.g. ``ds.train_set``.
            split_label: Name used in messages ("training", "validation", "test").
        """
        # list() materializes the split so it can be both measured and indexed,
        # whatever iterable type the loader returns.
        examples = list(split)
        print(f"{ds_cls} has a {split_label} set of {len(examples)} examples")
        # Fail with a clear message instead of an IndexError on examples[0].
        self.assertGreater(len(examples), 0, f"{ds_cls} {split_label} set is empty")
        self._test_instance(
            examples[0],
            # NOTE(review): assumes datasets expose `is_query_based`; getattr keeps
            # this a no-op (is_query=False) if the attribute does not exist — confirm.
            is_query=getattr(ds, "is_query_based", False),
            is_multi_document=ds.is_multi_document,
            is_dialogue=ds.is_dialogue_based,
        )

    def test_all_datasets(self):
        """Instantiate each supported dataset (minus skipped ones) and check every split it provides."""
        print_with_color(f"{'#' * 10} Testing all datasets... {'#' * 10}\n\n", "35")
        print(list_all_datasets())

        # TODO: Temporarily skipping Arxiv (size/time), > 30min download time for Travis-CI
        skipped = (ArxivDataset,)

        num_datasets = 0
        for ds_cls in SUPPORTED_SUMM_DATASETS:
            if ds_cls in skipped:
                continue

            print_with_color(f"Testing {ds_cls} dataset...", "35")
            ds: SummDataset = ds_cls()
            ds.show_description()

            splits = (
                ("training", ds.train_set),
                ("validation", ds.validation_set),
                ("test", ds.test_set),
            )
            # Must have at least one of train/dev/test set. Unlike a bare
            # `assert`, self.assertTrue is not stripped when running under -O.
            self.assertTrue(
                any(split is not None for _, split in splits),
                f"{ds_cls} provides no train/validation/test split",
            )

            for split_label, split in splits:
                if split is not None:
                    self._check_split(ds_cls, ds, split, split_label)

            print_with_color(f"{ds.dataset_name} dataset test complete\n", "32")
            num_datasets += 1

        print_with_color(
            f"{'#' * 10} test_all_datasets {__name__} complete ({num_datasets} datasets) {'#' * 10}",
            "32",
        )
if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line runner.
    unittest.main()
|