import math
from dataclasses import dataclass
from typing import Dict, List, Optional

import nltk
import numpy as np
from numpy.random import permutation, poisson
from transformers.data.data_collator import _torch_collate_batch
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

nltk.download("punkt")

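# The collators below implement the BART denoising objectives (text infilling and
# sentence permutation) described in https://arxiv.org/abs/1910.13461, following
# fairseq's denoising_dataset.py, with the span masking done in numpy.
# SentenceTokenize prepares raw documents so that sentence boundaries are marked
# with eos/bos tokens before tokenization.
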
@dataclass
class DataCollatorForTextInfilling:
    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15
    poisson_lambda: float = 3.0
    pad_to_multiple_of: Optional[int] = None

    def __post_init__(self):
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token, which is necessary for text infilling."
            )

    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
        batch = {}
        if isinstance(examples, (dict, BatchEncoding)):
            # A single (possibly un-batched) encoding was passed directly.
            examples_ids = examples["input_ids"]
            if "decoder_input_ids" in examples.keys():
                examples_dec = examples["decoder_input_ids"]
            else:
                examples_dec = examples_ids

            if type(examples_ids[0]) is int:
                examples_ids = [examples_ids]
            if type(examples_dec[0]) is int:
                examples_dec = [examples_dec]

            batch["input_ids"] = _torch_collate_batch(
                examples_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of
            )
            batch["decoder_input_ids"] = _torch_collate_batch(
                examples_dec, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of
            )
            batch["decoder_input_ids"] = batch["decoder_input_ids"].tolist()
        elif isinstance(examples[0], (dict, BatchEncoding)):
            # A list of encodings: let the tokenizer pad them into a batch.
            batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            # A list of plain input_ids sequences.
            batch["input_ids"] = _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            batch["decoder_input_ids"] = _torch_collate_batch(
                examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of
            ).tolist()

        special_tokens_mask = batch.pop("special_tokens_mask", None)
        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return batch

    def mask_tokens(self, inputs, special_tokens_mask=None):
        inputs_copy = np.array(inputs)
        labels = np.array(inputs)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
        special_tokens_mask = np.array(special_tokens_mask, dtype=bool)

        # Maskable positions: real tokens that are neither padding nor special tokens.
        is_token = ~(labels == self.tokenizer.pad_token_id) & ~special_tokens_mask
        num_to_mask = int(math.ceil(is_token.astype(float).sum() * self.mlm_probability))
        if num_to_mask == 0:
            return inputs, labels

        # Sample span lengths from a Poisson distribution until their sum covers the budget.
        lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask,))
        while np.cumsum(lengths, 0)[-1] < num_to_mask:
            lengths = np.concatenate([lengths, poisson(lam=self.poisson_lambda, size=(num_to_mask,))])

        # Drop zero-length spans and trim so the total span length stays close to num_to_mask.
        lengths = lengths[lengths > 0]
        idx = np.argmin(np.abs(np.cumsum(lengths, 0) - num_to_mask)) + 1
        lengths = lengths[: idx + 1]

        # Choose random span start positions among the maskable tokens.
        token_indices = np.argwhere(is_token == 1)
        span_starts = permutation(token_indices.shape[0])[: lengths.shape[0]]
        masked_indices = np.array(token_indices[span_starts])
        mask = np.full_like(labels, fill_value=False)

        # Mark the span starts...
        for mi in masked_indices:
            mask[tuple(mi)] = True
        lengths -= 1

        # ...then extend each span one position at a time until its sampled length is used up
        # or the end of the sequence is reached.
        max_index = labels.shape[1] - 1
        remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)
        while np.any(remaining):
            masked_indices[remaining, 1] += 1
            for mi in masked_indices:
                mask[tuple(mi)] = True
            lengths -= 1
            remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)

        # Never mask special tokens; positions that are not masked are ignored in the loss.
        mask[np.where(special_tokens_mask)] = False
        inputs_copy[np.where(mask == 1)] = self.tokenizer.mask_token_id
        labels[np.where(mask == 0)] = -100

        # Text infilling: collapse every masked span to a single mask token and re-pad.
        to_remove = (mask == 1) & np.roll((mask == 1), 1, 1)
        new_inputs = np.full_like(labels, fill_value=self.tokenizer.pad_token_id)

        for i, example in enumerate(np.split(inputs_copy, indices_or_sections=new_inputs.shape[0], axis=0)):
            new_example = example[0][~to_remove[i]]
            new_inputs[i, 0 : new_example.shape[0]] = new_example

        return new_inputs.tolist(), labels.tolist()

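# A minimal usage sketch (not part of the original module), assuming a BART-style
# tokenizer; pairing the infilling collator with a torch DataLoader might look like:
#
#     from torch.utils.data import DataLoader
#     from transformers import BartTokenizerFast
#
#     tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
#     collator = DataCollatorForTextInfilling(tokenizer=tokenizer)
#     encodings = [tokenizer(t) for t in ["My dog is cute.", "It loves to play in the park."]]
#     loader = DataLoader(encodings, batch_size=2, collate_fn=collator)
#     batch = next(iter(loader))  # `batch` now holds the corrupted input_ids and the infilling labels

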
@dataclass
class SentenceTokenize:
    """Tokenize documents into sentences, join them with eos/bos tokens, and split overly long documents into smaller chunks."""

    sentence_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
    bos: str = "<s>"
    eos: str = "</s>"
    max_sentences: int = 256
    sentence_stride: int = 128
    max_characters: int = 100000

    def __call__(self, examples: Dict[str, List[str]]) -> Dict[str, List[str]]:
        is_batched = isinstance(examples["text"], list)
        if not is_batched:
            examples["text"] = [examples["text"]]

        texts = []
        for doc in examples["text"]:
            sentences = self.sentence_tokenizer.tokenize(doc)
            start_index = 0

            # Slide a window of at most `max_sentences` sentences over the document,
            # advancing by `sentence_stride` sentences each step.
            while start_index < len(sentences):
                sentence_span = sentences[start_index : min(len(sentences), start_index + self.max_sentences)]
                text = f"{self.eos}{self.bos}".join([sentence for sentence in sentence_span])

                if len(text) > self.max_characters:
                    text = text[: self.max_characters]
                texts.append(text)
                start_index += self.sentence_stride

        return {"text": texts}

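# A minimal usage sketch (not part of the original module), assuming a
# `datasets.Dataset` with a "text" column; the mapping below is hypothetical:
#
#     sentence_tokenize = SentenceTokenize()
#     dataset = dataset.map(sentence_tokenize, batched=True, remove_columns=dataset.column_names)

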
@dataclass
class DataCollatorForSentencePermutation:
    tokenizer: PreTrainedTokenizerBase
    permutate_sentence_ratio: float = 1.0

    def __post_init__(self):
        self.full_stop_index = self.tokenizer.eos_token_id

    def __call__(self, example: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        source = example["input_ids"]

        full_stops = source == self.full_stop_index

        # Sentence boundaries: one position past each eos token that is not directly
        # preceded by another eos.
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2
        result = source.copy()

        num_sentences = sentence_ends.shape[0]
        num_to_permute = math.ceil((num_sentences * 2 * self.permutate_sentence_ratio) / 2.0)
        substitutions = np.random.permutation(num_sentences)[:num_to_permute]
        ordering = np.arange(0, num_sentences)
        ordering[substitutions] = substitutions[np.random.permutation(num_to_permute)]

        # Rebuild the sequence with the sentences in permuted order.
        index = 0
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 0) : sentence_ends[i]]
            result[index : index + sentence.shape[0]] = sentence
            index += sentence.shape[0]

        example["decoder_input_ids"] = example["input_ids"]
        example["input_ids"] = result

        return example

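# A minimal usage sketch (not part of the original module), assuming `tokenizer` is a
# BART-style tokenizer already in scope and `input_ids` is a 1-D numpy array in which
# every sentence is terminated by the tokenizer's eos token:
#
#     permuter = DataCollatorForSentencePermutation(tokenizer=tokenizer)
#     ids = np.array(tokenizer("My dog is cute.</s><s>It loves to play.")["input_ids"])
#     example = permuter({"input_ids": ids})

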
@dataclass
class DataCollatorForDenoisingTasks:
    """Data collator used for the denoising language modeling task in BART.

    The implementation is based on
    https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/fairseq/data/denoising_dataset.py.
    The default parameters are based on the BART paper: https://arxiv.org/abs/1910.13461.
    """

    tokenizer: PreTrainedTokenizerBase
    mask_ratio: float = 0.3
    poisson_lambda: float = 3.0
    permutate_sentence_ratio: float = 1.0
    pad_to_multiple_of: int = 16

    def __post_init__(self):
        if self.tokenizer.mask_token is None or self.tokenizer.eos_token is None:
            raise ValueError(
                "This collator requires a tokenizer with both a mask token and an eos token."
            )

    def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, np.ndarray]:
        """Batch the examples, permute sentences and add whole-word masks.

        Args:
            examples (list): list of examples, each of which contains an `input_ids` field.
        """

        batch = self.tokenizer.pad(examples, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="np")
        batch["decoder_input_ids"] = self.shift_tokens_right(batch["input_ids"])

        do_permutate = False
        if self.permutate_sentence_ratio > 0.0:
            batch["input_ids"] = self.permutate_sentences(batch["input_ids"])
            do_permutate = True

        if self.mask_ratio:
            batch["input_ids"], batch["labels"] = self.add_whole_word_mask(batch["input_ids"], do_permutate)

        return batch

    def shift_tokens_right(self, inputs):
        """Shift decoder input ids one position to the right: https://github.com/huggingface/transformers/issues/7961.

        Example:
            <s>My dog is cute.</s><s>It loves to play in the park.</s><pad><pad>
            shifts to -> </s><s>My dog is cute.</s><s>It loves to play in the park.<pad><pad>
        """

        shifted_inputs = np.roll(inputs, 1, axis=-1)

        # The decoder sequence starts with the eos token.
        shifted_inputs[:, 0] = self.tokenizer.eos_token_id

        # If the shifted sequence now ends with an eos token, replace it with padding.
        end_with_eos = np.where(shifted_inputs[:, -1] == self.tokenizer.eos_token_id)
        shifted_inputs[end_with_eos, -1] = self.tokenizer.pad_token_id

        # Also replace any eos that is immediately followed by padding with a pad token.
        last_eos_indices = np.where(
            (shifted_inputs[:, :-1] == self.tokenizer.eos_token_id)
            * (shifted_inputs[:, 1:] == self.tokenizer.pad_token_id)
        )

        shifted_inputs[last_eos_indices] = self.tokenizer.pad_token_id
        return shifted_inputs

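    # Worked example for shift_tokens_right (not part of the original code), assuming
    # BART-style ids bos=0, eos=2, pad=1 and a hypothetical id 8331 standing in for the
    # sentence tokens:
    #   inputs            = [0, 8331, 2, 1, 1]   <s> X </s> <pad> <pad>
    #   after roll        = [1, 0, 8331, 2, 1]
    #   set column 0 = 2  = [2, 0, 8331, 2, 1]
    #   eos before pad -> pad:
    #   result            = [2, 0, 8331, 1, 1]   </s> <s> X <pad> <pad>
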
    def permutate_sentences(self, inputs):
        results = inputs.copy()

        full_stops = inputs == self.tokenizer.eos_token_id

        # Sentence boundaries: one position past each eos token that is not directly
        # preceded by another eos.
        sentence_ends = np.argwhere(full_stops[:, 1:] * ~full_stops[:, :-1])
        sentence_ends[:, 1] += 2
        num_sentences = np.unique(sentence_ends[:, 0], return_counts=True)[1]
        num_to_permute = np.ceil((num_sentences * 2 * self.permutate_sentence_ratio) / 2.0).astype(int)

        # Group the boundary positions by example.
        sentence_ends = np.split(sentence_ends[:, 1], np.unique(sentence_ends[:, 0], return_index=True)[1][1:])

        for i in range(inputs.shape[0]):
            substitutions = np.random.permutation(num_sentences[i])[: num_to_permute[i]]

            ordering = np.arange(0, num_sentences[i])
            ordering[substitutions] = substitutions[np.random.permutation(num_to_permute[i])]

            # Rewrite example i with its sentences in permuted order.
            index = 0
            for j in ordering:
                sentence = inputs[i, (sentence_ends[i][j - 1] if j > 0 else 0) : sentence_ends[i][j]]
                results[i, index : index + sentence.shape[0]] = sentence
                index += sentence.shape[0]
        return results

    def add_whole_word_mask(self, inputs, do_permutate):
        labels = inputs.copy()

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        special_tokens_mask = np.array(special_tokens_mask, dtype=bool)

        # Maskable positions: real tokens that are neither padding nor special tokens.
        is_token = ~(labels == self.tokenizer.pad_token_id) & ~special_tokens_mask
        num_to_mask = int(math.ceil(is_token.astype(float).sum() * self.mask_ratio))
        if num_to_mask == 0:
            return inputs, labels

        # Sample span lengths from a Poisson distribution until their sum covers the budget.
        lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask,))
        while np.cumsum(lengths, 0)[-1] < num_to_mask:
            lengths = np.concatenate([lengths, poisson(lam=self.poisson_lambda, size=(num_to_mask,))])

        # Drop zero-length spans and trim so the total span length stays close to num_to_mask.
        lengths = lengths[lengths > 0]
        idx = np.argmin(np.abs(np.cumsum(lengths, 0) - num_to_mask)) + 1
        lengths = lengths[: idx + 1]

        # Choose random span start positions among the maskable tokens.
        token_indices = np.argwhere(is_token == 1)
        span_starts = permutation(token_indices.shape[0])[: lengths.shape[0]]
        masked_indices = np.array(token_indices[span_starts])
        mask = np.full_like(labels, fill_value=False)

        # Mark the span starts...
        for mi in masked_indices:
            mask[tuple(mi)] = True
        lengths -= 1

        # ...then extend each span one position at a time until its sampled length is used up
        # or the end of the sequence is reached.
        max_index = labels.shape[1] - 1
        remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)
        while np.any(remaining):
            masked_indices[remaining, 1] += 1
            for mi in masked_indices:
                mask[tuple(mi)] = True
            lengths -= 1
            remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)

        # Never mask special tokens.
        mask[np.where(special_tokens_mask)] = False
        inputs[np.where(mask)] = self.tokenizer.mask_token_id

        # With sentence permutation, all non-special positions remain targets; without
        # it, the masked positions are excluded from the loss.
        if not do_permutate:
            labels[np.where(mask)] = -100
        else:
            labels[np.where(special_tokens_mask)] = -100

        # Text infilling: collapse every masked span to a single mask token and re-pad.
        to_remove = (mask == 1) & np.roll((mask == 1), 1, 1)
        new_inputs = np.full_like(labels, fill_value=self.tokenizer.pad_token_id)

        for i, example in enumerate(np.split(inputs, indices_or_sections=new_inputs.shape[0], axis=0)):
            new_example = example[0][~to_remove[i]]
            new_inputs[i, 0 : new_example.shape[0]] = new_example

        return new_inputs, labels
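

# A minimal end-to-end smoke test (not part of the original module). It is only a
# sketch: it assumes network access to download the `facebook/bart-base` tokenizer
# and the nltk punkt model, and it simply prints the shapes of the collated batch.
if __name__ == "__main__":
    from transformers import BartTokenizerFast

    tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")

    sentence_tokenize = SentenceTokenize()
    chunks = sentence_tokenize({"text": ["My dog is cute. It loves to play in the park.", "A short document."]})["text"]

    collator = DataCollatorForDenoisingTasks(tokenizer=tokenizer)
    batch = collator([tokenizer(chunk) for chunk in chunks])
    print({key: np.asarray(value).shape for key, value in batch.items()})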