File size: 979 Bytes
e565538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import math
from dataclasses import dataclass
from typing import Dict

import jax.numpy as jnp
import nltk
import numpy as np
from jax import ops, random
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from data_collator import DataCollatorForSentencePermutation, DataCollatorForTextInfilling, SentenceTokenize

example = {"text": " My dog is cute. It loves to play in the park. There are many parks in SF."}
sent_tok = SentenceTokenize()
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
permuate_sent = DataCollatorForSentencePermutation(tokenizer)
example = sent_tok(example)
print(example["text"])
out = permuate_sent(tokenizer(example["text"], add_special_tokens=False))
example["text"] = tokenizer.decode(out["input_ids"])
print(example["text"])
masking = DataCollatorForTextInfilling(tokenizer)
out = masking(out)
example["text"] = tokenizer.decode(out["input_ids"][0])
print(example["text"])