from dataset.st_dataset import SummDataset, SummInstance import random from typing import List, Tuple def print_with_color(s: str, color: str): """ Print formatted string. :param str `s`: String to print. :param str `color`: ANSI color code. :see https://gist.github.com/RabaDabaDoba/145049536f815903c79944599c6f952a """ print(f"\033[{color}m{s}\033[0m") def retrieve_random_test_instances( dataset_instances: List[SummInstance], num_instances=3 ) -> List[SummInstance]: """ Retrieve random test instances from a dataset training set. :param List[SummInstance] `dataset_instances`: Instances from a dataset `train_set` to pull random examples from. :param int `num_instances`: Number of random instances to pull. Defaults to `3`. :return List of SummInstance to summarize. """ test_instances = [] for i in range(num_instances): test_instances.append( dataset_instances[random.randint(0, len(dataset_instances) - 1)] ) return test_instances def get_summarization_set(dataset: SummDataset, size=1) -> Tuple[List, List]: """ Return instances from given summarization dataset, in the format of (sources, targets). """ subset = [] for i in range(size): subset.append(next(dataset.train_set)) src, tgt = zip(*(list(map(lambda x: (x.source, x.summary), subset)))) return list(src), list(tgt) def get_query_based_summarization_set( dataset: SummDataset, size=1 ) -> Tuple[List, List, List]: """ Return instances from given query-based summarization dataset, in the format of (sources, targets, queries). """ subset = [] for i in range(size): subset.append(next(dataset.train_set)) src, tgt, queries = zip( *(list(map(lambda x: (x.source, x.summary, x.query), subset))) ) return list(src), list(tgt), list(queries)