# NOTE(review): the original capture began with Hugging Face Spaces page residue
# ("Spaces:" / "Runtime error" x2) — kept here as a comment so the file parses.
# GPT2-korean: Streamlit demo that generates Korean text from a seed prompt
# using a Flax GPT-2 checkpoint stored in the current working directory.
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2TokenizerFast
import numpy as np
import jax
import jax.numpy as jnp


@st.cache_resource  # load the heavy checkpoint once per server process, not on every rerun/click
def _load_model_and_tokenizer():
    """Load the GPT-2 checkpoint (Flax weights converted to PyTorch) and tokenizer.

    Both are read from ".", so the app must be launched from the checkpoint
    directory. pad_token_id 50256 is GPT-2's <|endoftext|> token.
    """
    model = GPT2LMHeadModel.from_pretrained(".", pad_token_id=50256, from_flax=True)
    tokenizer = GPT2TokenizerFast.from_pretrained(
        ".", padding_side="left", pad_token="<|endoftext|>"
    )
    return model, tokenizer


st.title("GPT2-korean")
# Default seed is the Korean greeting "안녕하세요"; the scraped original held a
# mojibake-damaged literal — repaired here, confirm against the deployed app.
seed = st.text_input("Seed", "안녕하세요")
go = st.button("Generate")

if go:
    model, tokenizer = _load_model_and_tokenizer()
    encoded = tokenizer(seed, return_tensors="pt")
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        max_length=50,
        num_return_sequences=1,
        num_beams=3,
        no_repeat_ngram_size=3,
        repetition_penalty=2.0,
        do_sample=True,
        # Token ids the model must never emit; their meanings are specific to
        # this checkpoint's vocabulary. (The original listed [504] twice —
        # deduplicated, which does not change generation behavior.)
        bad_words_ids=[
            [95],
            [5470],
            [504],
            [528],
            [919],
            [65, 20374, 63],
            [655],
        ],
    )
    st.write(tokenizer.batch_decode(outputs, skip_special_tokens=True))