SummerTime / tests /helpers.py
aliabd
full demo working with old graido
7e3e85d
raw
history blame
1.91 kB
from dataset.st_dataset import SummDataset, SummInstance
import random
from typing import List, Tuple
def print_with_color(s: str, color: str):
    """
    Print a string wrapped in an ANSI escape sequence.

    The output is ``ESC[<color>m``, then the text, then the reset sequence
    ``ESC[0m``, followed by the newline that `print` appends.

    :param str `s`: Text to print.
    :param str `color`: ANSI code inserted between ``ESC[`` and ``m``.
    :see https://gist.github.com/RabaDabaDoba/145049536f815903c79944599c6f952a
    """
    start_sequence = f"\033[{color}m"
    reset_sequence = "\033[0m"
    print(f"{start_sequence}{s}{reset_sequence}")
def retrieve_random_test_instances(
    dataset_instances: List[SummInstance], num_instances=3
) -> List[SummInstance]:
    """
    Pick random instances out of a list of dataset instances.

    Each pick is an independent random index into `dataset_instances`, so
    the same instance may appear more than once in the result.

    :param List[SummInstance] `dataset_instances`: Instances to pick from (e.g. from a dataset `train_set`).
    :param int `num_instances`: How many picks to make. Defaults to `3`.
    :return List of the picked SummInstance objects, in pick order.
    """
    pool_size = len(dataset_instances)
    # randrange(n) draws from 0..n-1, the same range as randint(0, n - 1).
    return [
        dataset_instances[random.randrange(pool_size)]
        for _ in range(num_instances)
    ]
def get_summarization_set(dataset: SummDataset, size=1) -> Tuple[List, List]:
    """
    Return instances from given summarization dataset, in the format of (sources, targets).

    Instances are pulled by calling `next()` on `dataset.train_set`, so that
    iterator is advanced by `size` items.

    :param SummDataset `dataset`: Dataset whose `train_set` iterator supplies the instances.
    :param int `size`: Number of instances to pull. Defaults to `1`. With `0` or a negative value nothing is pulled and two empty lists are returned.
    :return Tuple `(sources, targets)`: two lists holding each instance's `source` and `summary`, in pull order.
    """
    subset = [next(dataset.train_set) for _ in range(size)]
    # Build each column directly instead of unpacking zip(*pairs): unpacking
    # an empty zip raises ValueError when no instances were requested.
    sources = [instance.source for instance in subset]
    targets = [instance.summary for instance in subset]
    return sources, targets
def get_query_based_summarization_set(
    dataset: SummDataset, size=1
) -> Tuple[List, List, List]:
    """
    Return instances from given query-based summarization dataset, in the format of (sources, targets, queries).

    Instances are pulled by calling `next()` on `dataset.train_set`, so that
    iterator is advanced by `size` items.

    :param SummDataset `dataset`: Dataset whose `train_set` iterator supplies the instances.
    :param int `size`: Number of instances to pull. Defaults to `1`. With `0` or a negative value nothing is pulled and three empty lists are returned.
    :return Tuple `(sources, targets, queries)`: three lists holding each instance's `source`, `summary` and `query`, in pull order.
    """
    subset = [next(dataset.train_set) for _ in range(size)]
    # Build each column directly instead of unpacking zip(*triples): unpacking
    # an empty zip raises ValueError when no instances were requested.
    sources = [instance.source for instance in subset]
    targets = [instance.summary for instance in subset]
    queries = [instance.query for instance in subset]
    return sources, targets, queries