File size: 4,422 Bytes
7e3e85d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import unittest
from typing import List

from dataset.dataset_loaders import CnndmDataset, MultinewsDataset, PubmedqaDataset
from model import SUPPORTED_SUMM_MODELS, list_all_models
from model.single_doc import LexRankModel, LongformerModel
from model.dialogue import HMNetModel

from helpers import (
    print_with_color,
    get_summarization_set,
    get_query_based_summarization_set,
)


class TestModels(unittest.TestCase):
    """Integration tests covering every supported summarization model.

    Datasets are instantiated once at class-definition time and shared
    across test methods (they may download/cache data on first use).
    """

    single_doc_dataset = CnndmDataset()
    multi_doc_dataset = MultinewsDataset()
    query_based_dataset = PubmedqaDataset()
    # TODO: temporarily skipping HMNet, so no dialogue-based dataset is loaded.
    # dialogue_based_dataset = SamsumDataset()

    def test_list_models(self):
        """list_all_models() must enumerate exactly the supported models."""
        print_with_color(f"{'#'*10} Testing test_list_models... {'#'*10}\n", "35")
        all_models = list_all_models()
        for model_class, model_description in all_models:
            print(f"{model_class} : {model_description}")
        # One (class, description) pair per supported model.
        self.assertEqual(len(all_models), len(SUPPORTED_SUMM_MODELS))
        print_with_color(
            f"{'#'*10} test_list_models {__name__} test complete {'#'*10}\n\n", "32"
        )

    def validate_prediction(self, prediction: List[str], src: List):
        """
        Verify that prediction instances match source instances.

        Asserts that ``prediction`` is a list of strings with exactly one
        entry per source instance in ``src``.
        """
        # assertIsInstance/assertEqual give informative failure messages,
        # unlike assertTrue(isinstance(...)) / assertTrue(a == b).
        self.assertIsInstance(prediction, list)
        for instance in prediction:
            self.assertIsInstance(instance, str)
        self.assertEqual(len(prediction), len(src))
        print("Prediction typing and length matches source instances!")

    def test_model_summarize(self):
        """
        Test all supported models on instances from datasets.

        Each model is instantiated, run on a one-instance slice of the
        appropriate dataset (query-based / multi-doc / dialogue / single-doc),
        and its prediction is validated against the source instances.
        """

        print_with_color(f"{'#'*10} Testing all models... {'#'*10}\n", "35")

        num_models = 0
        all_models = list_all_models()

        for model_class, _ in all_models:
            if model_class in [HMNetModel]:
                # TODO: Temporarily skip HMNet (requires large pre-trained model download + GPU)
                continue

            print_with_color(f"Testing {model_class.model_name} model...", "35")

            if model_class == LexRankModel:
                # current LexRankModel requires a training set
                training_src, training_tgt = get_summarization_set(
                    self.single_doc_dataset, 100
                )
                model = model_class(training_src)
            else:
                model = model_class()

            if model.is_query_based:
                test_src, test_tgt, test_query = get_query_based_summarization_set(
                    self.query_based_dataset, 1
                )
                prediction = model.summarize(test_src, test_query)
                print(
                    f"Query: {test_query}\nGold summary: {test_tgt}\nPredicted summary: {prediction}"
                )
                # Validate here too, consistent with the other branches.
                self.validate_prediction(prediction, test_src)
            elif model.is_multi_document:
                test_src, test_tgt = get_summarization_set(self.multi_doc_dataset, 1)
                prediction = model.summarize(test_src)
                print(f"Gold summary: {test_tgt} \nPredicted summary: {prediction}")
                self.validate_prediction(prediction, test_src)
            elif model.is_dialogue_based:
                # dialogue_based_dataset is commented out above (HMNet skip);
                # guard so a future dialogue model skips cleanly instead of
                # raising AttributeError on a missing class attribute.
                dialogue_dataset = getattr(self, "dialogue_based_dataset", None)
                if dialogue_dataset is None:
                    print_with_color(
                        f"Skipping {model_class.model_name}: no dialogue-based dataset loaded\n",
                        "35",
                    )
                    continue
                test_src, test_tgt = get_summarization_set(dialogue_dataset, 1)
                prediction = model.summarize(test_src)
                print(f"Gold summary: {test_tgt}\nPredicted summary: {prediction}")
                self.validate_prediction(prediction, test_src)
            else:
                test_src, test_tgt = get_summarization_set(self.single_doc_dataset, 1)
                # Longformer is exercised on a 5x-repeated document to cover
                # its long-input path; use the same input for validation.
                longformer_input = (
                    [test_src[0] * 5] if model_class == LongformerModel else test_src
                )
                prediction = model.summarize(longformer_input)
                print(f"Gold summary: {test_tgt} \nPredicted summary: {prediction}")
                self.validate_prediction(prediction, longformer_input)

            print_with_color(f"{model_class.model_name} model test complete\n", "32")
            num_models += 1

        print_with_color(
            f"{'#'*10} test_model_summarize complete ({num_models} models) {'#'*10}\n",
            "32",
        )


# Allow running this test module directly: `python <this file>`.
if __name__ == "__main__":
    unittest.main()