Spaces:
Runtime error
Runtime error
File size: 1,091 Bytes
c3fed12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2TokenizerFast
import numpy as np
import jax
import jax.numpy as jnp
st.title("GPT2-korean")

# Token-id sequences the generator is forbidden to emit. What each id decodes
# to in this checkpoint's vocabulary is not visible here — TODO confirm.
# (The original list contained [504] twice; the duplicate was redundant.)
BAD_WORDS_IDS = [
    [95],
    [5470],
    [504],
    [528],
    [919],
    [65, 20374, 63],
    [655],
]

# Streamlit re-executes this whole script on every widget interaction, so an
# uncached loader would reload (and re-convert from Flax) the model on every
# click. ``st.cache_resource`` only exists in newer Streamlit releases; on
# older versions fall back to an identity decorator (no caching).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def load_model_and_tokenizer():
    """Load the GPT-2 LM model and fast tokenizer from the working directory.

    The weights are loaded from a Flax checkpoint (``from_flax=True``),
    presumably the reason jax is imported at the top of the file.
    ``pad_token_id=50256`` is assumed to be the id of ``<|endoftext|>`` in
    this vocabulary — NOTE(review): confirm against the tokenizer files.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = GPT2LMHeadModel.from_pretrained(".", pad_token_id=50256, from_flax=True)
    tokenizer = GPT2TokenizerFast.from_pretrained(
        ".", padding_side="left", pad_token="<|endoftext|>"
    )
    return model, tokenizer


# Default prompt is Korean for "hello".
seed = st.text_input("Seed", "안녕하세요")
go = st.button("Generate")
if go:
    if not seed:
        # An empty prompt tokenizes to an empty input_ids tensor, which
        # generate() cannot continue from; ask the user for input instead.
        st.warning("Please enter a seed text.")
    else:
        model, tokenizer = load_model_and_tokenizer()
        encoded = tokenizer(seed, return_tensors="pt")
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            # Pass the mask explicitly: the pad token is <|endoftext|>, so
            # generate() cannot reliably infer which positions are padding.
            attention_mask=encoded["attention_mask"],
            max_length=50,
            num_return_sequences=1,
            num_beams=3,
            no_repeat_ngram_size=3,
            repetition_penalty=2.0,
            do_sample=True,
            bad_words_ids=BAD_WORDS_IDS,
        )
        st.write(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|