# text_generate / app.py
# Author: KJMAN678 -- commit "first commit" (bba3c8d), 2.48 kB
import streamlit as st
from transformers import T5Tokenizer, AutoModelForCausalLM
def cached_tokenizer():
    """Load and return the tokenizer for ``rinna/japanese-gpt2-medium``.

    NOTE(review): despite the name, nothing in this function caches the
    result -- every call goes through ``from_pretrained`` again.
    """
    tok = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
    # The flag is set by hand after loading; presumably it is not picked up
    # from the tokenizer config -- confirm against the model card.
    tok.do_lower_case = True
    return tok
def cached_model():
    """Load and return the ``rinna/japanese-gpt2-medium`` causal language model.

    NOTE(review): despite the name, nothing in this function caches the
    result -- every call goes through ``from_pretrained`` again.
    """
    return AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
def _report_progress(status_text, progress_bar, percent):
    """Show *percent* in both the status placeholder and the progress bar."""
    status_text.text(f'Progress: {percent}%')
    progress_bar.progress(percent)


def main():
    """Render the Streamlit page and generate Japanese text with GPT-2.

    The page offers a slider for the number of texts to generate (1-2), a
    slider for the generation length (30-200), and a text area holding the
    prompt.  When the button is pressed the tokenizer and model are loaded,
    the prompt is encoded, text is sampled from the model, and the decoded
    result is written to the page.  Progress is reported after each step.
    """
    # Title: "Japanese text generation with GPT-2"
    st.title("GPT-2による日本語の文章生成")

    # Label: "Number of texts to output"
    num_of_output_text = st.slider(
        label='出力する文章の数',
        min_value=1,
        max_value=2,
        value=1,
    )

    # Label: "Number of characters to output"
    # NOTE(review): this value is passed to ``generate`` as ``max_length``,
    # which per the transformers docs counts tokens (prompt included), not
    # characters -- confirm this is the intended meaning.
    length_of_output_text = st.slider(
        label='出力する文字数',
        min_value=30,
        max_value=200,
        value=100,
    )

    # Label: "Text input"; default prompt: "I Am a Cat" (opening line).
    prefix_text = st.text_area(
        label='テキスト入力',
        value='吾輩は猫である',
    )

    status_text = st.empty()
    progress_bar = st.progress(0)

    # Button: "Generate text"
    if not st.button('文章生成'):
        return

    # "Loading takes a while"
    st.text("読み込みに時間がかかります")
    _report_progress(status_text, progress_bar, 10)

    tokenizer = cached_tokenizer()
    _report_progress(status_text, progress_bar, 25)

    model = cached_model()
    _report_progress(status_text, progress_bar, 40)

    # Inference
    input_ids = tokenizer.encode(prefix_text, return_tensors="pt")
    _report_progress(status_text, progress_bar, 60)

    output = model.generate(
        input_ids,
        do_sample=True,
        max_length=length_of_output_text,
        num_return_sequences=num_of_output_text,
    )
    _report_progress(status_text, progress_bar, 90)

    # Let the tokenizer drop its own special tokens instead of stripping
    # hand-written literals, and keep multiple generated texts apart with a
    # blank line rather than fusing them into one string.
    generated_texts = tokenizer.batch_decode(output, skip_special_tokens=True)
    output_text = "\n\n".join(generated_texts)
    _report_progress(status_text, progress_bar, 95)

    # "Generation result"
    st.info('生成結果')
    st.write(output_text)
    _report_progress(status_text, progress_bar, 100)
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()