Spaces:
Runtime error
Runtime error
File size: 2,484 Bytes
bba3c8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import streamlit as st
from transformers import T5Tokenizer, AutoModelForCausalLM
# Streamlit re-executes this script on every widget interaction, so a plain
# function would reload the tokenizer/model weights on each button press.
# Newer Streamlit releases provide ``st.cache_resource`` for process-wide
# singletons such as ML models; older releases only have the legacy
# ``st.cache`` API, which needs ``allow_output_mutation=True`` so the returned
# objects are not hashed on every call.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def cached_tokenizer():
    """Load the rinna Japanese GPT-2 tokenizer once and reuse it.

    Returns:
        The ``T5Tokenizer`` for ``rinna/japanese-gpt2-medium`` with
        ``do_lower_case`` forced to ``True``.
    """
    tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
    # NOTE(review): presumably a workaround for the tokenizer config not
    # carrying this flag -- confirm against the model card.
    tokenizer.do_lower_case = True
    return tokenizer


@_cache_resource
def cached_model():
    """Load the rinna Japanese GPT-2 causal language model once and reuse it.

    Returns:
        The ``AutoModelForCausalLM`` instance for
        ``rinna/japanese-gpt2-medium``.
    """
    return AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
def _report_progress(status_text, progress_bar, percent):
    """Show ``percent`` in both the status placeholder and the progress bar."""
    status_text.text(f'Progress: {percent}%')
    progress_bar.progress(percent)


def main():
    """Render the Japanese GPT-2 text-generation page.

    Draws the input widgets (number of samples, output length, prompt text)
    and, once the generate button is pressed, loads the tokenizer and model,
    samples continuations of the prompt and writes them to the page while
    updating a progress indicator along the way.
    """
    st.title("GPT-2による日本語の文章生成")

    num_of_output_text = st.slider(
        label='出力する文章の数',
        min_value=1,
        max_value=2,
        value=1,
    )
    # The value is forwarded as ``max_length`` below, which transformers
    # counts in tokens and which includes the prompt itself.
    length_of_output_text = st.slider(
        label='出力する文字数',
        min_value=30,
        max_value=200,
        value=100,
    )
    prefix_text = st.text_area(
        label='テキスト入力',
        value='吾輩は猫である',
    )

    status_text = st.empty()
    progress_bar = st.progress(0)

    # Nothing more to do until the user asks for generation.
    if not st.button('文章生成'):
        return

    st.text("読み込みに時間がかかります")
    _report_progress(status_text, progress_bar, 10)

    tokenizer = cached_tokenizer()
    _report_progress(status_text, progress_bar, 25)

    model = cached_model()
    _report_progress(status_text, progress_bar, 40)

    # Inference
    input_ids = tokenizer.encode(prefix_text, return_tensors="pt")
    _report_progress(status_text, progress_bar, 60)

    # ``max_length`` covers prompt + continuation; a prompt that already
    # fills it leaves no room to generate (and recent transformers releases
    # raise in that case), so stop with a readable message instead.
    if input_ids.shape[-1] >= length_of_output_text:
        st.warning('入力テキストが長すぎます。入力を短くするか、出力する文字数を増やしてください。')
        return

    generated = model.generate(
        input_ids,
        do_sample=True,
        max_length=length_of_output_text,
        num_return_sequences=num_of_output_text,
    )
    _report_progress(status_text, progress_bar, 90)

    # ``skip_special_tokens`` drops </s>, <unk>, etc. during decoding, which
    # replaces the manual string stripping.  Samples are separated by a blank
    # line so that multiple sequences do not run into each other.
    samples = tokenizer.batch_decode(generated, skip_special_tokens=True)
    output_text = "\n\n".join(samples)
    _report_progress(status_text, progress_bar, 95)

    st.info('生成結果')
    st.write(output_text)
    _report_progress(status_text, progress_bar, 100)
# Build the page only when this file is executed as a script
# (e.g. via ``streamlit run``), not when it is imported.
if __name__ == '__main__':
    main()