|
"""Module containing Dataset functionality""" |
|
|
|
import logging |
|
from typing import List |
|
|
|
import torch |
|
from datasets import IterableDataset |
|
|
|
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenizedPromptDataset(IterableDataset):
    """
    Lazily tokenize every example of an underlying dataset.

    Each raw example is handed to ``prompt_tokenizer.tokenize_prompt`` and the
    result is yielded. Examples for which the tokenizer raises
    ``InvalidDataException`` are silently left out of the stream.

    Args:
        prompt_tokenizer (PromptTokenizingStrategy): Strategy that turns one
            raw example into a tokenized prompt.
        dataset (dataset.Dataset): Source of raw examples to tokenize.
    """

    def __init__(
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.dataset = dataset

    def __iter__(self):
        for raw_example in self.dataset:
            try:
                tokenized = self.prompt_tokenizer.tokenize_prompt(raw_example)
                yield tokenized
            except InvalidDataException:
                # Malformed examples are intentionally dropped from the stream.
                continue
|
|
|
|
|
|
|
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that packs tokenized examples into chunks of at most
    ``seq_length`` tokens.

    Examples from each dataset are concatenated, with the tokenizer's EOS
    token appended to any example that does not already end with it. When
    the next example would overflow ``seq_length``, the accumulated chunk is
    emitted. Any remaining partial chunk is emitted when a dataset is
    exhausted, so chunks never span two datasets. Examples longer than
    ``seq_length`` and examples with no tokens are skipped.

    Args:
        tokenizer (Tokenizer): Tokenizer providing ``eos_token_id`` and
            ``get_vocab()``.
        datasets (List[IterableDataset]): Datasets yielding dicts with list
            values under ``input_ids``, ``attention_mask`` and ``labels``.
        seq_length (int): Maximum length of each emitted sequence.
    """

    def __init__(
        self,
        tokenizer,
        datasets,
        seq_length=2048,
    ):
        self.tokenizer = tokenizer
        # Token used to separate concatenated examples.
        # NOTE(review): assumes the tokenizer defines an EOS token; a None
        # eos_token_id would fail when the buffers are converted to tensors.
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length

        # Use the narrowest integer dtype that fits the vocabulary size, to
        # keep the packing buffers small (assumes token ids are < vocab size).
        vocab_size = len(tokenizer.get_vocab())
        if vocab_size <= torch.iinfo(torch.int16).max:
            self.tokens_dtype = torch.int16
        elif vocab_size <= torch.iinfo(torch.int32).max:
            self.tokens_dtype = torch.int32
        else:
            self.tokens_dtype = torch.int64

    @staticmethod
    def _empty_buffer():
        """Return a fresh buffer holding per-example tensors for each field."""
        return {"input_ids": [], "attention_mask": [], "labels": []}

    def _emit(self, buffer):
        """Concatenate the buffered tensors into one packed sample.

        Yields at most one dict truncated to ``seq_length``. Yields nothing,
        and logs a warning, if the three fields end up with different sizes.
        """
        input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
            : self.seq_length
        ]
        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
        if labels.size() == input_ids.size() and (
            attention_mask.size() == input_ids.size()
        ):
            yield {
                "input_ids": input_ids,
                "labels": labels,
                "attention_mask": attention_mask,
            }
        else:
            logging.warning(
                "dropping batch due to tensor size mismatch input_ids: %s, labels: %s, attention_mask: %s",
                input_ids.size(),
                labels.size(),
                attention_mask.size(),
            )

    def __iter__(self):
        for dataset in self.datasets:
            buffer = self._empty_buffer()
            buffer_len = 0

            for example in dataset:
                # Empty examples carry no tokens and have no last token to
                # inspect, so they are skipped.
                if not example or not example["input_ids"]:
                    continue

                example_len = len(example["input_ids"])
                # Examples that can never fit are dropped up front. This keeps
                # them from flushing a partially filled buffer early.
                if example_len > self.seq_length:
                    continue

                add_concat_token = example["input_ids"][-1] != self.concat_token_id

                if buffer_len + int(add_concat_token) + example_len > self.seq_length:
                    if buffer["input_ids"]:
                        yield from self._emit(buffer)
                    buffer = self._empty_buffer()
                    buffer_len = 0

                # Copy the lists so the caller's example is never mutated.
                input_ids = list(example["input_ids"])
                attention_mask = list(example["attention_mask"])
                labels = list(example["labels"])

                if add_concat_token:
                    input_ids.append(self.concat_token_id)
                    attention_mask.append(1)
                    labels.append(self.concat_token_id)

                buffer["input_ids"].append(
                    torch.tensor(input_ids, dtype=self.tokens_dtype)
                )
                buffer["attention_mask"].append(
                    torch.tensor(attention_mask, dtype=self.tokens_dtype)
                )
                buffer["labels"].append(
                    torch.tensor(labels, dtype=self.tokens_dtype)
                )
                # May reach seq_length + 1 when a full-length example needs an
                # EOS appended; _emit truncates to seq_length.
                buffer_len += len(input_ids)

            # Flush whatever is left once this dataset is exhausted.
            if buffer["input_ids"]:
                yield from self._emit(buffer)
|
|