File size: 2,484 Bytes
bba3c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
from transformers import T5Tokenizer, AutoModelForCausalLM

@st.cache_resource
def cached_tokenizer():
    """Load the rinna Japanese GPT-2 tokenizer once and reuse it.

    The function was named "cached" but previously rebuilt the tokenizer on
    every Streamlit rerun (every button press). ``st.cache_resource`` makes
    the name true: the tokenizer is constructed once per server process.

    Returns:
        T5Tokenizer: tokenizer for "rinna/japanese-gpt2-medium".
    """
    tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
    # rinna's model card recommends lowercasing input before tokenization.
    tokenizer.do_lower_case = True
    return tokenizer

@st.cache_resource
def cached_model():
    """Load the rinna Japanese GPT-2 model once and reuse it.

    Without caching, each Streamlit rerun reloaded the full medium-size
    model from disk/network, which dominates the app's latency.
    ``st.cache_resource`` stores the model once per server process.

    Returns:
        The causal-LM model for "rinna/japanese-gpt2-medium".
    """
    model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
    return model

def main():
  """Streamlit page: generate Japanese text with GPT-2 from a user prompt.

  Renders two sliders (number of sequences, output length), a text area for
  the prompt, and a button that runs tokenization + generation while a
  progress bar reports coarse-grained stages.
  """
  st.title("GPT-2による日本語の文章生成")

  # How many sequences model.generate() should return (num_return_sequences).
  num_of_output_text = st.slider(label='出力する文章の数',
                  min_value=1,
                  max_value=2,
                  value=1,
                  )

  # Generation length cap in tokens (max_length), not exact character count.
  length_of_output_text = st.slider(label='出力する文字数',
                  min_value=30,
                  max_value=200,
                  value=100,
                  )

  PREFIX_TEXT = st.text_area(
        label='テキスト入力', 
        value='吾輩は猫である'
  )

  # Progress widgets are created up front so they render above the output.
  progress_num = 0
  status_text = st.empty()
  progress_bar = st.progress(progress_num)

  if st.button('文章生成'):

    st.text("読み込みに時間がかかります")
    progress_num = 10
    status_text.text(f'Progress: {progress_num}%')
    progress_bar.progress(progress_num)

    tokenizer = cached_tokenizer()
    progress_num = 25
    status_text.text(f'Progress: {progress_num}%')
    progress_bar.progress(progress_num)

    model = cached_model()
    progress_num = 40
    status_text.text(f'Progress: {progress_num}%')
    progress_bar.progress(progress_num)

    # Inference: encode the prompt as a tensor of token ids.
    # (Renamed from `input`, which shadowed the builtin of the same name.)
    input_ids = tokenizer.encode(PREFIX_TEXT, return_tensors="pt")
    progress_num = 60
    status_text.text(f'Progress: {progress_num}%')
    progress_bar.progress(progress_num)

    output = model.generate(
            input_ids, do_sample=True, 
            max_length=length_of_output_text,
            num_return_sequences=num_of_output_text
            )
    progress_num = 90
    status_text.text(f'Progress: {progress_num}%')
    progress_bar.progress(progress_num)

    # Strip special-token markers left in the decoded text.
    # NOTE(review): "</unk>" looks unusual — tokenizers typically emit
    # "<unk>"; confirm which literal this tokenizer actually produces.
    output_text = "".join(tokenizer.batch_decode(output)).replace("</s>", "")
    output_text = output_text.replace("</unk>", "")
    progress_num = 95
    status_text.text(f'Progress: {progress_num}%')
    progress_bar.progress(progress_num)

    st.info('生成結果')
    progress_num = 100
    status_text.text(f'Progress: {progress_num}%')
    st.write(output_text)
    progress_bar.progress(progress_num)

# Script entry point: launch the Streamlit page when run directly
# (e.g. via `streamlit run <this file>`).
if __name__ == '__main__':
  main()