import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset


class TokenizedDataset(Dataset):
    """
    Converts a dataset of text samples into a dataset of token sequences,
    as converted by a supplied tokenizer.  The tokens come with position
    ids and attention masks, so they can be supplied directly to the model.
    """

    def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"):
        self.text_dataset = text_dataset
        self.field = field
        self.tokenizer = tokenizer
        self.maxlen = maxlen
        if hasattr(text_dataset, "info"):
            self.info = text_dataset.info

    def __len__(self):
        return len(self.text_dataset)

    def __getitem__(self, i):
        text = self.text_dataset[i]
        if self.field is not None:
            text = text[self.field]
        token_list = self.tokenizer.encode(
            text, truncation=True, max_length=self.maxlen
        )
        position_ids = list(range(len(token_list)))
        attention_mask = [1] * len(token_list)
        return dict(
            input_ids=torch.tensor(token_list),
            position_ids=torch.tensor(position_ids),
            attention_mask=torch.tensor(attention_mask),
        )


def dict_to_(data, device):
    """
    Moves a dictionary of tensors to the specified device.
    """
    for k in data:
        data[k] = data[k].to(device)
    return data


def length_collation(token_size):
    """
    Sorts a batch of sequences by length and breaks it up into subbatches
    of similarly sized sequences, padding as needed.  Each subbatch has no
    more than token_size total tokens (or a single sequence, if that
    sequence alone is larger).
    """

    def collate_fn(items):
        # Sort longest-first so each subbatch pads to the width of its
        # first element.
        items = sorted(items, key=lambda x: -len(x["input_ids"]))
        batches = []
        batch = []
        batch_width = 0
        for item in items:
            item_width = len(item["input_ids"])
            if item_width == 0:
                break
            # Flush the current subbatch if adding one more padded row
            # would exceed the token budget.
            if batch_width * (len(batch) + 1) > token_size:
                batches.append(make_padded_batch(batch))
                batch = []
                batch_width = 0
            if not batch:
                batch_width = item_width
            batch.append(item)
        if len(batch):
            batches.append(make_padded_batch(batch))
        return batches

    return collate_fn


def make_padded_batch(items):
    """
    Pads sequences in a batch so they are all the same length as the longest.
    """
    max_len = max(len(d["input_ids"]) for d in items)
    if max_len == 0:
        return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]}
    return {
        k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True)
        for k, v in items[0].items()
    }


def flatten_masked_batch(data, mask):
    """
    Flattens feature data, ignoring items that are masked out of attention.
    """
    flat_data = data.view(-1, data.size(-1))
    attended_tokens = mask.view(-1).nonzero()[:, 0]
    return flat_data[attended_tokens]
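

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way these
# pieces could be wired into a DataLoader.  The dataset ("wikitext"), the
# tokenizer ("gpt2"), and the 2048-token budget are assumptions chosen for
# the example; the model call itself is left as a stub.
if __name__ == "__main__":
    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    raw = load_dataset("wikitext", "wikitext-103-raw-v1", split="train[:1%]")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ds = TokenizedDataset(raw, tokenizer=tokenizer, maxlen=1024)

    # length_collation returns a collate_fn, so every item yielded by the
    # DataLoader is a *list* of padded subbatches, each capped at roughly
    # 2048 total tokens.
    loader = DataLoader(ds, batch_size=16, collate_fn=length_collation(2048))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for subbatches in loader:
        for subbatch in subbatches:
            subbatch = dict_to_(subbatch, device)
            # A model forward pass would run here, e.g. model(**subbatch);
            # subbatch contains input_ids, position_ids, and attention_mask.
        break  # only demonstrate one outer batch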