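"""Integration tests for SummerTime.

Runs every compatible dataset + model + evaluation-metric pipeline
end to end on a small random sample of instances.
"""
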
import random
import re
import sys
import time
import unittest
from typing import List, Tuple, Union

from model.base_model import SummModel
from model import SUPPORTED_SUMM_MODELS

from pipeline import assemble_model_pipeline

from evaluation.base_metric import SummMetric
from evaluation import SUPPORTED_EVALUATION_METRICS

from dataset.st_dataset import SummInstance, SummDataset
from dataset import SUPPORTED_SUMM_DATASETS
from dataset.dataset_loaders import ScisummnetDataset, ArxivDataset

from helpers import print_with_color, retrieve_random_test_instances


class IntegrationTests(unittest.TestCase):
    def get_prediction(
        self, model: SummModel, dataset: SummDataset, test_instances: List[SummInstance]
    ) -> Tuple[Union[List[str], List[List[str]]], Union[List[str], List[List[str]]]]:
        """
        Get summary prediction given model and dataset instances.

        :param SummModel `model`: Model for summarization task.
        :param SummDataset `dataset`: Dataset for summarization task.
        :param List[SummInstance] `test_instances`: Instances from `dataset` to summarize.
        :returns: Tuple containing a list of summary predictions and a list of target summaries, one per instance in `test_instances`.
        """

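        # ScisummnetDataset stores each source as a list; take its first
        # element. Other datasets expose the source directly.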
        src = (
            [ins.source[0] for ins in test_instances]
            if isinstance(dataset, ScisummnetDataset)
            else [ins.source for ins in test_instances]
        )
        tgt = [ins.summary for ins in test_instances]
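        # Only query-based datasets supply a per-instance query; otherwise pass None.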
        query = (
            [ins.query for ins in test_instances] if dataset.is_query_based else None
        )
        prediction = model.summarize(src, query)
        return prediction, tgt

    def get_eval_dict(self, metric: SummMetric, prediction: List[str], tgt: List[str]):
        """
        Run evaluation metric on summary prediction.

        :param SummMetric `metric`: Evaluation metric.
        :param List[str] `prediction`: Summary predictions to evaluate.
        :param List[str] `tgt`: Target summaries from the dataset.
        :returns: Dictionary of scores computed by `metric`.
        """
        score_dict = metric.evaluate(prediction, tgt)
        return score_dict

    def test_all(self):
        """
        Runs integration tests on all compatible dataset + model + evaluation metric pipelines supported by SummerTime.
        """

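        # ANSI SGR color codes: "35" prints magenta status text, "32" green success text.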
        print_with_color("\nInitializing all evaluation metrics...", "35")
        evaluation_metrics = []
        for eval_cls in SUPPORTED_EVALUATION_METRICS:
            # TODO: Temporarily skipping Rouge/RougeWE metrics to avoid a local bug.
            # if eval_cls in [Rouge, RougeWe]:
            #     continue
            print(eval_cls)
            evaluation_metrics.append(eval_cls())

        print_with_color("\n\nBeginning integration tests...", "35")
        for dataset_cls in SUPPORTED_SUMM_DATASETS:
            # TODO: Temporarily skipping ArxivDataset; it is too large/slow to test routinely.
            if dataset_cls in [ArxivDataset]:
                continue
            dataset = dataset_cls()
            if dataset.train_set is not None:
                dataset_instances = list(dataset.train_set)
                print(
                    f"\n{dataset.dataset_name} has a training set of {len(dataset_instances)} examples"
                )
                print_with_color(
                    f"Initializing all matching model pipelines for {dataset.dataset_name} dataset...",
                    "35",
                )
                # matching_model_instances = assemble_model_pipeline(dataset_cls, list(filter(lambda m: m != PegasusModel, SUPPORTED_SUMM_MODELS)))
                matching_model_instances = assemble_model_pipeline(
                    dataset_cls, SUPPORTED_SUMM_MODELS
                )
                for model, model_name in matching_model_instances:
                    test_instances = retrieve_random_test_instances(
                        dataset_instances=dataset_instances, num_instances=1
                    )
                    print_with_color(
                        f"{'#' * 20} Testing: {dataset.dataset_name} dataset, {model_name} model {'#' * 20}",
                        "35",
                    )
                    prediction, tgt = self.get_prediction(
                        model, dataset, test_instances
                    )
                    print(f"Prediction: {prediction}\nTarget: {tgt}\n")
                    for metric in evaluation_metrics:
                        print_with_color(f"{metric.metric_name} metric", "35")
                        score_dict = self.get_eval_dict(metric, prediction, tgt)
                        print(score_dict)

                    print_with_color(
                        f"{'#' * 20} Test for {dataset.dataset_name} dataset, {model_name} model COMPLETE {'#' * 20}\n\n",
                        "32",
                    )


if __name__ == "__main__":
    if len(sys.argv) > 2 or (
        len(sys.argv) == 2 and not re.match(r"^\d+$", sys.argv[1])
    ):
        print("Usage: python tests/integration_test.py [seed]", file=sys.stderr)
        sys.exit(1)

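    # Seed the RNG so retrieve_random_test_instances yields a reproducible sample.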
    seed = int(time.time()) if len(sys.argv) == 1 else int(sys.argv.pop())
    random.seed(seed)
    print_with_color(f"(to reproduce) random seeded with {seed}\n", "32")
    unittest.main()
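
# Example invocations (file path assumed from the usage message above):
#   python tests/integration_test.py        # seed derived from the current time
#   python tests/integration_test.py 42     # reproducible run with seed 42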