Spaces:
Runtime error
Runtime error
from transformers import PreTrainedTokenizerFast | |
from tokenizers import SentencePieceBPETokenizer | |
from transformers import BartForConditionalGeneration | |
import streamlit as st | |
import torch | |
def tokenizer():
    """Load and return the fast tokenizer stored under 'Soyoung97/gec_kr'."""
    return PreTrainedTokenizerFast.from_pretrained('Soyoung97/gec_kr')
def get_model():
    """Load the 'Soyoung97/gec_kr' BART checkpoint and return it in eval mode."""
    gec_model = BartForConditionalGeneration.from_pretrained('Soyoung97/gec_kr')
    # Module.eval() hands back the module itself, so switching to inference
    # mode and returning can be done in a single statement.
    return gec_model.eval()
# Streamlit page: read a sentence, run it through the correction model and
# show both the original and the corrected text.

# NOTE(review): this literal looks mis-encoded (possibly non-ASCII text decoded
# with the wrong charset) -- kept byte-for-byte; confirm against the original.
default_text = 'λλ μ€λ μ§λ² κ°μ¨μ'

model = get_model()
# Rebinds the name `tokenizer` from the loader function to the loaded
# instance, exactly as the original script does.
tokenizer = tokenizer()

st.title("GEC_KR Model Test")
text = st.text_area("Input corrupted sentence :", value=default_text)
st.markdown("## Original sentence:")
st.write(text)

# Blank or whitespace-only input has nothing to correct, so skip generation.
if text.strip():
    st.markdown("## Corrected output")
    with st.spinner('processing..'):
        raw_input_ids = tokenizer.encode(text)

        # The text area accepts arbitrarily long input, but the model can only
        # index positions up to config.max_position_embeddings. Reserve two
        # slots for the BOS/EOS tokens added below and truncate the rest so an
        # overlong paste degrades gracefully instead of raising.
        max_content_tokens = model.config.max_position_embeddings - 2
        if len(raw_input_ids) > max_content_tokens:
            st.warning(
                f"Input is too long; only the first {max_content_tokens} "
                "tokens are used."
            )
            raw_input_ids = raw_input_ids[:max_content_tokens]

        input_ids = [
            tokenizer.bos_token_id,
            *raw_input_ids,
            tokenizer.eos_token_id,
        ]

        # NOTE(review): eos_token_id is hard-coded to 1 while the input uses
        # tokenizer.eos_token_id -- presumably the same id; confirm.
        corrected_ids = model.generate(
            torch.tensor([input_ids]),
            max_length=256,
            eos_token_id=1,
            num_beams=4,
            early_stopping=True,
            repetition_penalty=2.0,
        )

        # A single sequence was submitted, so take row 0 of the batch
        # explicitly rather than relying on squeeze().
        corrected_text = tokenizer.decode(
            corrected_ids[0].tolist(), skip_special_tokens=True
        )
    st.write(corrected_text)