|
from model import SUPPORTED_SUMM_MODELS |
|
from model.base_model import SummModel |
|
from model.single_doc import LexRankModel |
|
|
|
from dataset.st_dataset import SummDataset |
|
from dataset.non_huggingface_datasets import ScisummnetDataset |
|
|
|
from typing import List, Tuple |
|
|
|
|
|
def get_lxr_train_set(dataset: SummDataset, size: int = 100) -> List[str]: |
|
|
|
""" |
|
return some dummy summarization examples, in the format of a list of sources |
|
""" |
|
subset = [] |
|
for i in range(size): |
|
subset.append(next(iter(dataset.train_set))) |
|
|
|
src = list( |
|
map( |
|
lambda x: " ".join(x.source) |
|
if dataset.is_dialogue_based or dataset.is_multi_document |
|
else x.source[0] |
|
if isinstance(dataset, ScisummnetDataset) |
|
else x.source, |
|
subset, |
|
) |
|
) |
|
|
|
return src |
|
|
|
|
|
def assemble_model_pipeline( |
|
dataset: SummDataset, model_list: List[SummModel] = SUPPORTED_SUMM_MODELS |
|
) -> List[Tuple[SummModel, str]]: |
|
|
|
""" |
|
Return initialized list of all model pipelines that match the summarization task of given dataset. |
|
|
|
:param SummDataset `dataset`: Dataset to retrieve model pipelines for. |
|
:param List[SummModel] `model_list`: List of candidate model classes (uninitialized). Defaults to `model.SUPPORTED_SUMM_MODELS`. |
|
:returns List of tuples, where each tuple contains an initialized model and the name of that model as `(model, name)`. |
|
""" |
|
|
|
dataset = dataset if isinstance(dataset, SummDataset) else dataset() |
|
|
|
single_doc_model_list = list( |
|
filter( |
|
lambda model_cls: not ( |
|
model_cls.is_dialogue_based |
|
or model_cls.is_query_based |
|
or model_cls.is_multi_document |
|
), |
|
model_list, |
|
) |
|
) |
|
single_doc_model_instances = [ |
|
model_cls(get_lxr_train_set(dataset)) |
|
if model_cls == LexRankModel |
|
else model_cls() |
|
for model_cls in single_doc_model_list |
|
] |
|
|
|
multi_doc_model_list = list( |
|
filter(lambda model_cls: model_cls.is_multi_document, model_list) |
|
) |
|
|
|
query_based_model_list = list( |
|
filter(lambda model_cls: model_cls.is_query_based, model_list) |
|
) |
|
|
|
dialogue_based_model_list = list( |
|
filter(lambda model_cls: model_cls.is_dialogue_based, model_list) |
|
) |
|
dialogue_based_model_instances = ( |
|
[model_cls() for model_cls in dialogue_based_model_list] |
|
if dataset.is_dialogue_based |
|
else [] |
|
) |
|
|
|
matching_models = [] |
|
if dataset.is_query_based: |
|
if dataset.is_dialogue_based: |
|
for query_model_cls in query_based_model_list: |
|
for dialogue_model in dialogue_based_model_list: |
|
full_query_dialogue_model = query_model_cls( |
|
model_backend=dialogue_model |
|
) |
|
matching_models.append( |
|
( |
|
full_query_dialogue_model, |
|
f"{query_model_cls.model_name} ({dialogue_model.model_name})", |
|
) |
|
) |
|
else: |
|
for query_model_cls in query_based_model_list: |
|
for single_doc_model in single_doc_model_list: |
|
full_query_model = ( |
|
query_model_cls( |
|
model_backend=single_doc_model, |
|
data=get_lxr_train_set(dataset), |
|
) |
|
if single_doc_model == LexRankModel |
|
else query_model_cls(model_backend=single_doc_model) |
|
) |
|
matching_models.append( |
|
( |
|
full_query_model, |
|
f"{query_model_cls.model_name} ({single_doc_model.model_name})", |
|
) |
|
) |
|
return matching_models |
|
|
|
if dataset.is_multi_document: |
|
for multi_doc_model_cls in multi_doc_model_list: |
|
for single_doc_model in single_doc_model_list: |
|
full_multi_doc_model = ( |
|
multi_doc_model_cls( |
|
model_backend=single_doc_model, data=get_lxr_train_set(dataset) |
|
) |
|
if single_doc_model == LexRankModel |
|
else multi_doc_model_cls(model_backend=single_doc_model) |
|
) |
|
matching_models.append( |
|
( |
|
full_multi_doc_model, |
|
f"{multi_doc_model_cls.model_name} ({single_doc_model.model_name})", |
|
) |
|
) |
|
return matching_models |
|
|
|
if dataset.is_dialogue_based: |
|
return list( |
|
map( |
|
lambda db_model: (db_model, db_model.model_name), |
|
dialogue_based_model_instances, |
|
) |
|
) |
|
|
|
return list( |
|
map(lambda s_model: (s_model, s_model.model_name), single_doc_model_instances) |
|
) |
|
|