# flax-bart-nb-nn / sentence_permutation.py
# Uploaded by pere in commit e565538 ("fisrt commit" [sic]); 979 bytes; scanned: no virus.
import math
from dataclasses import dataclass
from typing import Dict
import jax.numpy as jnp
import nltk
import numpy as np
from jax import ops, random
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from data_collator import DataCollatorForSentencePermutation, DataCollatorForTextInfilling, SentenceTokenize
# Walk one toy paragraph through the BART-style noising pipeline, printing the
# text after each stage: sentence splitting -> sentence permutation -> text
# infilling.
# NOTE(review): SentenceTokenize and the DataCollatorFor* classes come from
# data_collator.py (not visible here); comments below describe only how they
# are called from this script.
example = {
    "text": " My dog is cute. It loves to play in the park. There are many parks in SF."
}
sent_tok = SentenceTokenize()
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
# "permuate" is misspelled, but the module-level name is kept unchanged.
permuate_sent = DataCollatorForSentencePermutation(tokenizer)

# Stage 1: run the sentence tokenizer; its result replaces `example`.
example = sent_tok(example)
print(example["text"])

# Stage 2: encode without special tokens, permute, then decode back to text.
encoded = tokenizer(example["text"], add_special_tokens=False)
out = permuate_sent(encoded)
example["text"] = tokenizer.decode(out["input_ids"])
print(example["text"])

# Stage 3: feed the permuted output straight into the infilling collator.
masking = DataCollatorForTextInfilling(tokenizer)
out = masking(out)
# Unlike stage 2, "input_ids" is indexed with [0] here -- presumably the
# infilling collator returns a batched result; confirm against data_collator.py.
example["text"] = tokenizer.decode(out["input_ids"][0])
print(example["text"])