Spaces:
Build error
Build error
import random
from typing import List, Tuple

from dataset.st_dataset import SummDataset, SummInstance
def print_with_color(s: str, color: str) -> None:
    """
    Print a string wrapped in an ANSI color escape sequence.

    The text is written to standard output preceded by ``ESC[<color>m`` and
    followed by the reset sequence ``ESC[0m``, so anything printed afterwards
    is not affected by the chosen color.

    :param str `s`: String to print.
    :param str `color`: ANSI color code inserted between ``ESC[`` and ``m``.
    :see https://gist.github.com/RabaDabaDoba/145049536f815903c79944599c6f952a
    """
    # \033 is the ESC character; the trailing "\033[0m" resets all attributes.
    print(f"\033[{color}m{s}\033[0m")
def retrieve_random_test_instances(
    dataset_instances: List[SummInstance], num_instances=3
) -> List[SummInstance]:
    """
    Retrieve random test instances from a dataset training set.

    Every pick is made independently (sampling with replacement), so the same
    instance may appear more than once in the result, and `num_instances` may
    exceed the number of available instances.

    :param List[SummInstance] `dataset_instances`: Instances from a dataset `train_set` to pull random examples from.
    :param int `num_instances`: Number of random instances to pull. Defaults to `3`.
    :return List of SummInstance to summarize.
    :raises ValueError: If instances are requested from an empty `dataset_instances`.
    """
    # Fail with a clear message instead of the opaque "empty range" error
    # that the random module would otherwise raise for an empty sequence.
    if num_instances > 0 and len(dataset_instances) == 0:
        raise ValueError(
            "Cannot retrieve random instances from an empty list of dataset instances."
        )
    return [random.choice(dataset_instances) for _ in range(num_instances)]
def get_summarization_set(dataset: SummDataset, size=1) -> Tuple[List, List]:
    """
    Return instances from given summarization dataset, in the format of (sources, targets).

    The instances are pulled from `dataset.train_set` with `next()`, so the
    training-set iterator is advanced by `size` items on every call.

    :param SummDataset `dataset`: Dataset whose `train_set` iterator is consumed.
    :param int `size`: Number of instances to pull. Defaults to `1`; `0` yields two empty lists.
    :return Tuple of (list of `source` values, list of `summary` values), in the order pulled.
    :raises StopIteration: If `dataset.train_set` runs out before `size` instances were pulled.
    """
    # A list comprehension (not a generator expression) is used on purpose so
    # that an exhausted iterator still surfaces as StopIteration.
    subset = [next(dataset.train_set) for _ in range(size)]
    src = [instance.source for instance in subset]
    tgt = [instance.summary for instance in subset]
    return src, tgt
def get_query_based_summarization_set(
    dataset: SummDataset, size=1
) -> Tuple[List, List, List]:
    """
    Return instances from given query-based summarization dataset, in the format of (sources, targets, queries).

    The instances are pulled from `dataset.train_set` with `next()`, so the
    training-set iterator is advanced by `size` items on every call.

    :param SummDataset `dataset`: Dataset whose `train_set` iterator is consumed.
    :param int `size`: Number of instances to pull. Defaults to `1`; `0` yields three empty lists.
    :return Tuple of (list of `source` values, list of `summary` values, list of `query` values), in the order pulled.
    :raises StopIteration: If `dataset.train_set` runs out before `size` instances were pulled.
    """
    # A list comprehension (not a generator expression) is used on purpose so
    # that an exhausted iterator still surfaces as StopIteration.
    subset = [next(dataset.train_set) for _ in range(size)]
    src = [instance.source for instance in subset]
    tgt = [instance.summary for instance in subset]
    queries = [instance.query for instance in subset]
    return src, tgt, queries