import streamlit as st
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

st.title("GPT2-korean")

seed = st.text_input("Seed", "안녕하세요")
go = st.button("Generate")

if go:
    # Load the Flax checkpoint from the current directory and convert it to PyTorch weights.
    model = GPT2LMHeadModel.from_pretrained(".", pad_token_id=50256, from_flax=True)
    tokenizer = GPT2TokenizerFast.from_pretrained(
        ".", padding_side="left", pad_token="<|endoftext|>"
    )

    # Tokenize the user-provided seed text as PyTorch tensors.
    input_context = seed
    input_ids = tokenizer(input_context, return_tensors="pt")

    # Generate a continuation with beam search + sampling, penalizing
    # repetition and blocking a small list of unwanted token ids.
    outputs = model.generate(
        input_ids=input_ids["input_ids"],
        max_length=50,
        num_return_sequences=1,
        num_beams=3,
        no_repeat_ngram_size=3,
        repetition_penalty=2.0,
        do_sample=True,
        bad_words_ids=[
            [95],
            [5470],
            [504],
            [528],
            [919],
            [65, 20374, 63],
            [655],
        ],
    )

    st.write(tokenizer.batch_decode(outputs, skip_special_tokens=True))