File size: 1,594 Bytes
f0c636b 6baca04 f0c636b 6baca04 faa4aa2 6baca04 699251d 6baca04 b9e0b01 6baca04 914f0b6 6baca04 3ee3f83 6baca04 9ba0dd3 6baca04 2d278af 9ba0dd3 6baca04 9ba0dd3 6baca04 a46e61a 6baca04 9ba0dd3 6baca04 991fe21 6baca04 8068f7e 6baca04 699251d 6baca04 252e169 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import gradio as gr
import torch
import re
# Model and tokenizer are module-level singletons, populated by init()
# before the Gradio app starts; correct() reads them on every request.
model = None
tokenizer = None
def init():
    """Populate the module-level ``model`` and ``tokenizer`` singletons.

    Downloads the fine-tuned mT5 checkpoint (using HF_TOKEN from the
    environment for gated access) and the base mT5 tokenizer, then
    registers the ``<ZWJ>`` sentinel token used by correct().
    """
    import os
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast

    global model, tokenizer

    # Tokenizer first: it is small and independent of the model download.
    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    # Sentinel for zero-width joiners (see correct()).
    # NOTE(review): this grows the tokenizer vocab; assumes the fine-tuned
    # checkpoint below was trained with the same added token — confirm.
    tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})

    auth_token = os.environ.get("HF_TOKEN")
    model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=auth_token)
    model.eval()  # inference only — disable dropout etc.
def correct(text):
    """Run the spell-correction model on *text* and return the corrected string.

    Zero-width joiners (U+200D) are replaced by the ``<ZWJ>`` sentinel token
    before tokenization so the tokenizer cannot drop or normalize them, then
    restored in the decoded output.

    Args:
        text: Raw user input (presumably Sinhala — "ssc" in the model name;
            verify against the model card).

    Returns:
        The model's corrected text with ZWJ characters restored.
    """
    # Protect ZWJ: a plain literal replace, no regex needed.
    text = text.replace('\u200d', '<ZWJ>')
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='do_not_pad',
        # FIX: without truncation=True, max_length is ignored by HF
        # tokenizers, so overlong inputs were passed through untruncated.
        truncation=True,
        max_length=1024,
    )
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,       # greedy decoding
            do_sample=False,   # deterministic output
        )
    prediction = outputs[0]
    # Drop every special token EXCEPT the <ZWJ> sentinel, which must survive
    # decoding so the joiner characters can be restored below.
    zwj_id = tokenizer.convert_tokens_to_ids('<ZWJ>')
    special_ids = set(tokenizer.all_special_ids)
    kept_ids = [
        token for token in prediction.cpu().tolist()
        if token == zwj_id or token not in special_ids
    ]
    decoded = tokenizer.decode(kept_ids, skip_special_tokens=False).replace('\n', '').strip()
    # Restore ZWJ; the decoder may emit a stray space after the sentinel.
    return re.sub(r'<ZWJ>\s?', '\u200d', decoded)
# Eagerly load model + tokenizer at import time so the first request is fast.
init()
# `demo` is kept as a module-level name (Gradio Spaces looks for it by
# convention); a single text-in/text-out interface over correct().
demo = gr.Interface(fn=correct, inputs="text", outputs="text")
demo.launch()
|