"""Demo script: run the project's data collators on a single text example.

The script pushes one example through ``SentenceTokenize``,
``DataCollatorForSentencePermutation`` and ``DataCollatorForTextInfilling``
(all imported from the local ``data_collator`` module), printing the text
after each stage so the effect of every step can be inspected by eye.
"""

import math
from dataclasses import dataclass
from typing import Dict

import jax.numpy as jnp
import nltk
import numpy as np
from jax import ops, random
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from data_collator import (
    DataCollatorForSentencePermutation,
    DataCollatorForTextInfilling,
    SentenceTokenize,
)

# Single input record; the collators below read and rewrite its "text" field.
example = {"text": " My dog is cute. It loves to play in the park. There are many parks in SF."}

sent_tok = SentenceTokenize()
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
permuate_sent = DataCollatorForSentencePermutation(tokenizer)

# Stage 1: sentence-level preprocessing.
# NOTE(review): presumably this marks sentence boundaries inside the text so
# the permutation collator can find them -- confirm against data_collator.
example = sent_tok(example)
print(example["text"])

# Stage 2: encode without special tokens and hand the encoding to the
# sentence-permutation collator, then decode its "input_ids" for inspection.
out = permuate_sent(tokenizer(example["text"], add_special_tokens=False))
example["text"] = tokenizer.decode(out["input_ids"])
print(example["text"])

# Stage 3: feed the permuted encoding to the text-infilling collator.
masking = DataCollatorForTextInfilling(tokenizer)
out = masking(out)
# The infilling output is indexed with [0] before decoding.
# NOTE(review): presumably it carries a leading batch dimension -- confirm.
example["text"] = tokenizer.decode(out["input_ids"][0])
print(example["text"])