import os
from collections import deque

import torch
from torch.utils.data import Dataset


# ------------
# Data loading
# ------------
class CNNDMDataset(Dataset):
    """Abstracts the dataset used to train seq2seq models.

    The class will process the documents that are located in the specified
    folder. The preprocessing will work on any document that is reasonably
    formatted. On the CNN/DailyMail dataset it will extract both the story
    and the summary.

    CNN/Daily News:

    The CNN/Daily News raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """
    def __init__(self, path="", prefix="train"):
        """We initialize the class by listing all the documents to summarize.
        Files are not read into memory because of the size of some datasets
        (like CNN/DailyMail).
        """
        assert os.path.isdir(path)
        self.documents = []
        story_filenames_list = os.listdir(path)
        for story_filename in story_filenames_list:
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            if not os.path.isfile(path_to_story):
                continue
            self.documents.append(path_to_story)

    def __len__(self):
        """Returns the number of documents."""
        return len(self.documents)

    def __getitem__(self, idx):
        document_path = self.documents[idx]
        # os.path.basename is portable, unlike splitting on "/"
        document_name = os.path.basename(document_path)
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
        story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines
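
# Example usage (a sketch; the stories directory below is hypothetical):
#
#   dataset = CNNDMDataset(path="/data/cnn_dm/stories")
#   name, story_lines, summary_lines = dataset[0]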


def process_story(raw_story):
    """Extract the story and summary from a story file.

    Arguments:
        raw_story (str): content of the story file as a utf-8 encoded string.

    Returns:
        A tuple (story_lines, summary_lines). If the story is empty or
        contains no `@highlight` lines, summary_lines is an empty list.
    """
    nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))

    # for some unknown reason some lines are missing a trailing period; add it
    nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]

    # gather article lines
    story_lines = []
    lines = deque(nonempty_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError:
            # if "@highlight" is absent from the file we pop every element
            # until the deque is empty, which raises the IndexError we catch
            # here: the story has no summary.
            return story_lines, []

    # gather summary lines
    summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
    return story_lines, summary_lines
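
# Example (a small sketch; note the missing periods being restored):
#
#   raw = "First story sentence\n\n@highlight\n\nSummary sentence"
#   story, summary = process_story(raw)
#   # story == ["First story sentence."]; summary == ["Summary sentence."]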


def _add_missing_period(line):
    # \u2019 and \u201d are the unicode single and double close quotes,
    # matching the formatting code in abisee/cnn-dailymail (the original
    # list repeated "\u2019" twice)
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."
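
# Example:
#
#   _add_missing_period("No trailing punctuation")  # -> "No trailing punctuation."
#   _add_missing_period("Ends with a period.")      # -> "Ends with a period." (unchanged)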


# --------------------------
# Encoding and preprocessing
# --------------------------
def truncate_or_pad(sequence, block_size, pad_token_id):
    """Adapt the source and target sequences' lengths to the block size.

    If the sequence is shorter than the block size we append the padding token
    to its right; if it is longer we truncate it. Note that padding mutates
    `sequence` in place.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    else:
        sequence.extend([pad_token_id] * (block_size - len(sequence)))
        return sequence
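
# Example (toy values; pad_token_id 0 is arbitrary):
#
#   truncate_or_pad([5, 6, 7], block_size=5, pad_token_id=0)            # -> [5, 6, 7, 0, 0]
#   truncate_or_pad([5, 6, 7, 8, 9, 10], block_size=5, pad_token_id=0)  # -> [5, 6, 7, 8, 9]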


def build_mask(sequence, pad_token_id):
    """Builds the mask. The attention mechanism will only attend to positions
    with value 1."""
    mask = torch.ones_like(sequence)
    idx_pad_tokens = sequence == pad_token_id
    mask[idx_pad_tokens] = 0
    return mask
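
# Example (toy tensor; pad_token_id 0 is arbitrary):
#
#   sequence = torch.tensor([12, 48, 0, 0])
#   build_mask(sequence, pad_token_id=0)  # -> tensor([1, 1, 0, 0])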


def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """Encode the story and summary lines, and join them as specified in
    Liu & Lapata (2019), using `[SEP] [CLS]` tokens to separate sentences.
    """
    story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

    return story_token_ids, summary_token_ids
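
# Example (a sketch; assumes the BertTokenizer from the `transformers` library,
# whose `encode` wraps each sentence in [CLS] ... [SEP] by default, which is
# what produces the sentence separators described above):
#
#   from transformers import BertTokenizer
#   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#   story_ids, summary_ids = encode_for_summarization(
#       ["The first sentence.", "The second sentence."], ["A summary."], tokenizer
#   )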


def compute_token_type_ids(batch, separator_token_id):
    """Segment embeddings as described in [1].

    The values {0, 1} were found in the repository [2].

    Arguments:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained
        encoders." arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """
    batch_embeddings = []
    for sequence in batch:
        sentence_num = -1
        embeddings = []
        for s in sequence:
            if s == separator_token_id:
                sentence_num += 1
            embeddings.append(sentence_num % 2)
        batch_embeddings.append(embeddings)
    return torch.tensor(batch_embeddings)
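
# Example (toy batch; 101 stands in for BERT's [CLS] id, used here as the
# sentence separator, and 0 for padding):
#
#   batch = torch.tensor([[101, 7, 8, 101, 9, 0]])
#   compute_token_type_ids(batch, separator_token_id=101)
#   # -> tensor([[0, 0, 0, 1, 1, 1]])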