# Streamlit demo: Korean GPT-2 text generation.
# (HuggingFace Space by parkjaewoong, "Upload files for streamlit demo",
# commit c3fed12, 1.09 kB — page metadata preserved here as a comment so the
# file is valid Python.)
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2TokenizerFast
import numpy as np
import jax
import jax.numpy as jnp
# --- Streamlit UI -----------------------------------------------------------
# Generates Korean text from a user-supplied seed prompt using a GPT-2
# checkpoint stored in the current directory (Flax weights, loaded with
# from_flax=True and converted to a PyTorch model).
st.title("GPT2-korean")
seed = st.text_input("Seed", "안녕하세요")
go = st.button("Generate")
if go:
    # NOTE(review): the model and tokenizer are reloaded from disk on every
    # "Generate" click.  If the installed Streamlit version supports it,
    # wrapping the two loads in a cached loader (st.cache_resource) would make
    # clicks after the first much faster — confirm the Streamlit version first.
    # pad_token_id=50256 is GPT-2's <|endoftext|> id, reused here as padding.
    model = GPT2LMHeadModel.from_pretrained(".", pad_token_id=50256, from_flax=True)
    tokenizer = GPT2TokenizerFast.from_pretrained(
        ".", padding_side="left", pad_token="<|endoftext|>"
    )
    # Tokenize the seed text into PyTorch tensors ("pt").  The original bound
    # `seed` to a redundant `input_context` alias first; that is dropped.
    input_ids = tokenizer(seed, return_tensors="pt")
    outputs = model.generate(
        input_ids=input_ids["input_ids"],
        max_length=50,
        num_return_sequences=1,
        num_beams=3,
        no_repeat_ngram_size=3,
        repetition_penalty=2.0,
        do_sample=True,
        # Token-id sequences the model must never emit.  The original list
        # contained [504] twice (bug: redundant duplicate); it is listed once
        # here — semantics are unchanged.
        bad_words_ids=[
            [95],
            [5470],
            [504],
            [528],
            [919],
            [65, 20374, 63],
            [655],
        ],
    )
    # Decode all returned sequences and render them in the app.
    st.write(tokenizer.batch_decode(outputs, skip_special_tokens=True))