"""Streamlit demo: load a fine-tuned KoGPT2 checkpoint and generate a
continuation for user-entered Korean text, trimming any trailing
unfinished sentence fragment."""
import torch
import string
import streamlit as st
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    PreTrainedTokenizerFast,
)

# NOTE(review): `args` is not defined in this chunk — presumably an
# argparse/namespace object supplied by the surrounding context; confirm.
pretrained_model_config = GPT2Config.from_pretrained(
    args.pretrained_model_name,
)
model = GPT2LMHeadModel(pretrained_model_config)

# Load the fine-tuned weights on CPU. The checkpoint's state-dict keys carry
# a "model." prefix (e.g. from a Lightning wrapper), so strip it before
# loading into the bare GPT2LMHeadModel.
fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location=torch.device("cpu"),
)
model.load_state_dict({
    k.replace("model.", ""): v
    for k, v in fine_tuned_model_ckpt["state_dict"].items()
})
model.eval()

# NOTE(review): the original had eos_token="" — almost certainly "</s>"
# (KoGPT2's EOS token) with the angle-bracketed text stripped by markup
# processing; an empty EOS token is not meaningful. Confirm against the
# pretrained tokenizer's config.
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    args.pretrained_model_name,
    eos_token="</s>",
)

default_text = "안녕하세요?"
text = st.text_area("Input Text:", value=default_text)
st.write(text)

# Sentence-ending punctuation used to trim a trailing unfinished fragment.
punct = ('!', '?', '.')

if text:
    st.markdown("## Predict")
    with st.spinner('processing..'):
        print(f'input > {text}')
        input_ids = tokenizer(text)['input_ids']
        # Inference only — disable autograd (model.eval() alone does not).
        with torch.no_grad():
            gen_ids = model.generate(
                torch.tensor([input_ids]),
                max_length=128,
                repetition_penalty=2.0,
            )
        generated = tokenizer.decode(gen_ids[0, :].tolist()).strip()
        # If the text does not already end on punctuation, cut it back to the
        # last sentence-ending mark. BUG FIX: the original reversed-index loop
        # truncated the output to a single character when the text contained
        # no punctuation at all; now the full text is kept in that case.
        if generated and generated[-1] not in punct:
            cut = max(generated.rfind(p) for p in punct)
            if cut != -1:
                generated = generated[:cut + 1]
        print(f'KoGPT > {generated}')
        st.write(generated)