# Streamlit demo: sentence generation with a fine-tuned KoGPT2 model.
import torch
import string
import streamlit as st
from transformers import GPT2Config, GPT2LMHeadModel

# NOTE(review): `args` is not defined in this chunk — presumably an argparse
# namespace built earlier in the file; confirm before running standalone.

# Build the model skeleton from the pretrained config, then overwrite its
# weights with the fine-tuned checkpoint, forcing tensors onto the CPU.
pretrained_model_config = GPT2Config.from_pretrained(
    args.pretrained_model_name,
)
model = GPT2LMHeadModel(pretrained_model_config)
fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location=torch.device("cpu"),
)
# The checkpoint's state-dict keys carry a "model." prefix (Lightning-style
# wrapper); strip it so the names match the bare GPT2LMHeadModel state dict.
model.load_state_dict(
    {k.replace("model.", ""): v
     for k, v in fine_tuned_model_ckpt["state_dict"].items()}
)
model.eval()  # inference only: disable dropout / train-mode behavior
from transformers import PreTrainedTokenizerFast

# Tokenizer matching the pretrained model; "</s>" marks end-of-sequence so
# generation can terminate on it.
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    args.pretrained_model_name,
    eos_token="</s>",
)
# Default prompt for the text box.
# NOTE(review): the source literal was mojibake ("μλ νμΈμ?" — Korean text
# decoded with the wrong codec); restored to "안녕하세요?" ("Hello?") — confirm.
default_text = "안녕하세요?"
text = st.text_area("Input Text:", value=default_text)
st.write(text)
# Sentence-ending punctuation used below to trim incomplete trailing output.
punct = ('!', '?', '.')
if text:
    st.markdown("## Predict")
    with st.spinner('processing..'):
        print(f'input > {text}')
        input_ids = tokenizer(text)['input_ids']
        gen_ids = model.generate(
            torch.tensor([input_ids]),
            max_length=128,
            repetition_penalty=2.0,  # discourage repeated phrases
        )
        generated = tokenizer.decode(gen_ids[0, :].tolist()).strip()
        # If generation stopped mid-sentence, cut after the last
        # sentence-ending punctuation mark.
        # Bug fix: the original reversed-index loop, when NO punctuation was
        # present, fell through with i == 0 and truncated the whole output to
        # its first character; punctuation-free output is now kept intact.
        if generated and generated[-1] not in punct:
            cut = max(generated.rfind(p) for p in punct)
            if cut != -1:
                generated = generated[:cut + 1]
        print(f'KoGPT > {generated}')
        st.write(generated)