# mechanical-bird / app.py
# Author: mansiksohn — "Update app.py" (commit b1de362)
# Streamlit demo that generates Korean text with a fine-tuned KoGPT2 model.
import os
import types

import torch
import string
import streamlit as st
from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast

# Bug fix: `args` was referenced below but never defined anywhere in this
# file, so the app crashed with NameError on startup. Define the two fields
# the script reads, overridable via environment variables for deployment.
args = types.SimpleNamespace(
    # Base Korean GPT-2 checkpoint the fine-tuned weights derive from.
    # NOTE(review): "skt/kogpt2-base-v2" is the usual KoGPT2 base — confirm
    # it matches the checkpoint this Space was trained from.
    pretrained_model_name=os.environ.get(
        "PRETRAINED_MODEL_NAME", "skt/kogpt2-base-v2"
    ),
    # Path to the fine-tuned checkpoint file (must contain a 'state_dict').
    downstream_model_checkpoint_fpath=os.environ.get(
        "DOWNSTREAM_MODEL_CHECKPOINT_FPATH", "model.ckpt"
    ),
)


@st.cache_resource  # requires streamlit >= 1.18; avoids reloading weights on every rerun
def _load_model_and_tokenizer():
    """Build the GPT-2 model, load the fine-tuned weights, and load the tokenizer.

    Returns:
        (GPT2LMHeadModel, PreTrainedTokenizerFast): model in eval mode and its
        tokenizer. Cached so Streamlit's rerun-on-interaction does not reload
        the checkpoint from disk each time.
    """
    pretrained_model_config = GPT2Config.from_pretrained(
        args.pretrained_model_name,
    )
    gpt2 = GPT2LMHeadModel(pretrained_model_config)
    fine_tuned_model_ckpt = torch.load(
        args.downstream_model_checkpoint_fpath,
        map_location=torch.device("cpu"),  # CPU-only Space; no GPU assumed
    )
    # The training framework prefixed weight keys with "model."; strip it so
    # the keys match the bare GPT2LMHeadModel.
    gpt2.load_state_dict(
        {k.replace("model.", ""): v for k, v in fine_tuned_model_ckpt["state_dict"].items()}
    )
    gpt2.eval()  # inference only: disable dropout etc.
    tok = PreTrainedTokenizerFast.from_pretrained(
        args.pretrained_model_name,
        eos_token="</s>",
    )
    return gpt2, tok


model, tokenizer = _load_model_and_tokenizer()
# --- Interactive demo UI ----------------------------------------------------
# Read a prompt from the user, generate a continuation with the fine-tuned
# KoGPT2 model, trim the output at the last sentence-ending punctuation mark,
# and render the result.
default_text = "μ•ˆλ…•ν•˜μ„Έμš”?"
text = st.text_area("Input Text:", value=default_text)
st.write(text)
punct = ('!', '?', '.')  # sentence-ending marks used to trim trailing fragments
if text:
    st.markdown("## Predict")
    with st.spinner('processing..'):
        print(f'input > {text}')
        input_ids = tokenizer(text)['input_ids']
        # repetition_penalty > 1.0 discourages the model from looping on
        # the same phrase; greedy decode capped at 128 tokens.
        gen_ids = model.generate(torch.tensor([input_ids]),
                                 max_length=128,
                                 repetition_penalty=2.0)
        generated = tokenizer.decode(gen_ids[0, :].tolist()).strip()
        # If generation stopped mid-sentence, cut after the LAST punctuation
        # mark. Bug fix: the original reverse index scan fell through to
        # i == 0 when NO punctuation was present, truncating the whole
        # output to a single character; rfind returns -1 in that case, so
        # we now leave the text untouched instead.
        if generated and generated[-1] not in punct:
            cut = max(generated.rfind(p) for p in punct)
            if cut != -1:
                generated = generated[:cut + 1]
        print(f'KoGPT > {generated}')
        st.write(generated)