# SummerTime: tests/integration_test.py
import random
import re
import sys
import time
import unittest
from typing import List, Tuple, Union

from dataset import SUPPORTED_SUMM_DATASETS
from dataset.dataset_loaders import ArxivDataset, ScisummnetDataset
from dataset.st_dataset import SummDataset, SummInstance
from evaluation import SUPPORTED_EVALUATION_METRICS
from evaluation.base_metric import SummMetric
from helpers import print_with_color, retrieve_random_test_instances
from model import SUPPORTED_SUMM_MODELS
from model.base_model import SummModel
from pipeline import assemble_model_pipeline


class IntegrationTests(unittest.TestCase):
def get_prediction(
self, model: SummModel, dataset: SummDataset, test_instances: List[SummInstance]
) -> Tuple[Union[List[str], List[List[str]]], Union[List[str], List[List[str]]]]:
"""
Get summary prediction given model and dataset instances.
:param SummModel `model`: Model for summarization task.
:param SummDataset `dataset`: Dataset for summarization task.
:param List[SummInstance] `test_instances`: Instances from `dataset` to summarize.
        :returns Tuple containing a list of summary predictions and a list of target summaries, one per instance in `test_instances`.
"""
src = (
[ins.source[0] for ins in test_instances]
if isinstance(dataset, ScisummnetDataset)
else [ins.source for ins in test_instances]
)
tgt = [ins.summary for ins in test_instances]
query = (
[ins.query for ins in test_instances] if dataset.is_query_based else None
)
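        # Queries are forwarded only for query-based datasets; `summarize`
        # receives None otherwise.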
prediction = model.summarize(src, query)
return prediction, tgt

    def get_eval_dict(self, metric: SummMetric, prediction: List[str], tgt: List[str]):
        """
        Run evaluation metric on summary prediction.
        :param SummMetric `metric`: Evaluation metric.
        :param List[str] `prediction`: Summary prediction instances.
        :param List[str] `tgt`: Target summaries from the dataset.
        :returns Dictionary of evaluation scores for `prediction` against `tgt`.
        """
score_dict = metric.evaluate(prediction, tgt)
return score_dict

    def test_all(self):
"""
Runs integration test on all compatible dataset + model + evaluation metric pipelines supported by SummerTime.
"""
print_with_color("\nInitializing all evaluation metrics...", "35")
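        # Instantiate every supported metric once; the same instances are
        # reused across all dataset/model pipelines below.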
evaluation_metrics = []
for eval_cls in SUPPORTED_EVALUATION_METRICS:
            # TODO: Temporarily skipping the Rouge/RougeWE metrics to avoid a local bug.
            # if eval_cls in [Rouge, RougeWe]:
            #     continue
print(eval_cls)
evaluation_metrics.append(eval_cls())
print_with_color("\n\nBeginning integration tests...", "35")
for dataset_cls in SUPPORTED_SUMM_DATASETS:
            # TODO: Temporarily skipping ArxivDataset (large download / long runtime).
if dataset_cls in [ArxivDataset]:
continue
dataset = dataset_cls()
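            # Only datasets with a training split are exercised; test instances
            # are sampled from that split below.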
if dataset.train_set is not None:
dataset_instances = list(dataset.train_set)
print(
f"\n{dataset.dataset_name} has a training set of {len(dataset_instances)} examples"
)
print_with_color(
f"Initializing all matching model pipelines for {dataset.dataset_name} dataset...",
"35",
)
# matching_model_instances = assemble_model_pipeline(dataset_cls, list(filter(lambda m: m != PegasusModel, SUPPORTED_SUMM_MODELS)))
matching_model_instances = assemble_model_pipeline(
dataset_cls, SUPPORTED_SUMM_MODELS
)
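                # assemble_model_pipeline pairs the dataset with every compatible
                # model, yielding (model_instance, model_name) tuples.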
for model, model_name in matching_model_instances:
test_instances = retrieve_random_test_instances(
dataset_instances=dataset_instances, num_instances=1
)
print_with_color(
f"{'#' * 20} Testing: {dataset.dataset_name} dataset, {model_name} model {'#' * 20}",
"35",
)
prediction, tgt = self.get_prediction(
model, dataset, test_instances
)
print(f"Prediction: {prediction}\nTarget: {tgt}\n")
for metric in evaluation_metrics:
print_with_color(f"{metric.metric_name} metric", "35")
score_dict = self.get_eval_dict(metric, prediction, tgt)
print(score_dict)
print_with_color(
f"{'#' * 20} Test for {dataset.dataset_name} dataset, {model_name} model COMPLETE {'#' * 20}\n\n",
"32",
)


if __name__ == "__main__":
    if len(sys.argv) > 2 or (
        len(sys.argv) == 2 and not re.match(r"^\d+$", sys.argv[1])
    ):
print("Usage: python tests/integration_test.py [seed]", file=sys.stderr)
sys.exit(1)
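    # Seed from the clock unless a seed was given on the command line; pop()
    # removes the seed from sys.argv so unittest.main() does not try to parse
    # it as a test name.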
seed = int(time.time()) if len(sys.argv) == 1 else int(sys.argv.pop())
random.seed(seed)
print_with_color(f"(to reproduce) random seeded with {seed}\n", "32")
unittest.main()