# SummerTime / tests / model_test.py
# aliabd
# full demo working with old graido
# 7e3e85d
# raw history blame
# No virus
# 4.42 kB
import unittest
from typing import List
from dataset.dataset_loaders import CnndmDataset, MultinewsDataset, PubmedqaDataset
from model import SUPPORTED_SUMM_MODELS, list_all_models
from model.single_doc import LexRankModel, LongformerModel
from model.dialogue import HMNetModel
from helpers import (
print_with_color,
get_summarization_set,
get_query_based_summarization_set,
)
class TestModels(unittest.TestCase):
    """
    Smoke tests for the model registry and for every supported summarization model.
    """

    # Datasets are instantiated once, when the class body is executed, and are
    # shared by all test methods.
    single_doc_dataset = CnndmDataset()
    multi_doc_dataset = MultinewsDataset()
    query_based_dataset = PubmedqaDataset()

    # TODO: temporarily skipping HMNet, no dialogue-based dataset needed
    # dialogue_based_dataset = SamsumDataset()
    # Explicit placeholder so the dialogue branch of `test_model_summarize` can
    # detect the missing dataset instead of failing with an AttributeError.
    dialogue_based_dataset = None

    def test_list_models(self):
        """
        Check that `list_all_models` reports one entry per supported model.
        """
        print_with_color(f"{'#'*10} Testing test_list_models... {'#'*10}\n", "35")
        all_models = list_all_models()
        for model_class, model_description in all_models:
            print(f"{model_class} : {model_description}")

        self.assertEqual(len(all_models), len(SUPPORTED_SUMM_MODELS))
        print_with_color(
            f"{'#'*10} test_list_models {__name__} test complete {'#'*10}\n\n", "32"
        )

    def validate_prediction(self, prediction: List[str], src: List):
        """
        Verify that prediction instances match source instances.

        :param prediction: model output; must be a list of strings.
        :param src: source instances the prediction was generated from; the
            prediction must contain exactly one entry per source instance.
        """
        self.assertIsInstance(prediction, list)
        for ins in prediction:
            self.assertIsInstance(ins, str)
        self.assertEqual(len(prediction), len(src))
        print("Prediction typing and length matches source instances!")

    def test_model_summarize(self):
        """
        Test all supported models on instances from datasets.

        Each model is instantiated, asked to summarize a single instance from
        the dataset matching its type (query-based, multi-document,
        dialogue-based or single-document), and its output is checked with
        `validate_prediction`.
        """
        print_with_color(f"{'#'*10} Testing all models... {'#'*10}\n", "35")

        num_models = 0
        all_models = list_all_models()

        for model_class, _ in all_models:
            if model_class in [HMNetModel]:
                # TODO: Temporarily skip HMNet (requires large pre-trained model download + GPU)
                continue

            print_with_color(f"Testing {model_class.model_name} model...", "35")

            if model_class == LexRankModel:
                # current LexRankModel requires a training set
                training_src, _ = get_summarization_set(self.single_doc_dataset, 100)
                model = model_class(training_src)
            else:
                model = model_class()

            if model.is_query_based:
                test_src, test_tgt, test_query = get_query_based_summarization_set(
                    self.query_based_dataset, 1
                )
                prediction = model.summarize(test_src, test_query)
                print(
                    f"Query: {test_query}\nGold summary: {test_tgt}\nPredicted summary: {prediction}"
                )
                # Validate query-based output the same way as every other branch.
                self.validate_prediction(prediction, test_src)
            elif model.is_multi_document:
                test_src, test_tgt = get_summarization_set(self.multi_doc_dataset, 1)
                prediction = model.summarize(test_src)
                print(f"Gold summary: {test_tgt} \nPredicted summary: {prediction}")
                self.validate_prediction(prediction, test_src)
            elif model.is_dialogue_based:
                if self.dialogue_based_dataset is None:
                    # No dialogue dataset is configured (see the class-level
                    # TODO), so this model cannot be exercised yet.
                    print(
                        f"Skipping {model_class.model_name}: "
                        "no dialogue-based dataset configured"
                    )
                    continue
                test_src, test_tgt = get_summarization_set(
                    self.dialogue_based_dataset, 1
                )
                prediction = model.summarize(test_src)
                print(f"Gold summary: {test_tgt}\nPredicted summary: {prediction}")
                self.validate_prediction(prediction, test_src)
            else:
                test_src, test_tgt = get_summarization_set(self.single_doc_dataset, 1)
                if model_class == LongformerModel:
                    # Longformer is fed the first source instance repeated five
                    # times as a single document (presumably to produce a long
                    # input -- TODO confirm).
                    test_src = [test_src[0] * 5]
                prediction = model.summarize(test_src)
                print(f"Gold summary: {test_tgt} \nPredicted summary: {prediction}")
                self.validate_prediction(prediction, test_src)

            print_with_color(f"{model_class.model_name} model test complete\n", "32")
            num_models += 1

        print_with_color(
            f"{'#'*10} test_model_summarize complete ({num_models} models) {'#'*10}\n",
            "32",
        )
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()