KJMAN678 commited on
Commit
bba3c8d
1 Parent(s): ac31d19

first commit

Browse files
Files changed (5) hide show
  1. .gitignore +0 -0
  2. README.md +1 -1
  3. app.py +82 -0
  4. command.txt +10 -0
  5. requirements.txt +4 -0
.gitignore ADDED
File without changes
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Megane_otoko
3
  emoji: 💩
4
  colorFrom: yellow
5
  colorTo: purple
1
  ---
2
+ title: Text Generation by GPT-3
3
  emoji: 💩
4
  colorFrom: yellow
5
  colorTo: purple
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import T5Tokenizer, AutoModelForCausalLM
3
+
4
+ def cached_tokenizer():
5
+ tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
6
+ tokenizer.do_lower_case = True
7
+ return tokenizer
8
+
9
+ def cached_model():
10
+ model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
11
+ return model
12
+
13
+ def main():
14
+ st.title("GPT-2による日本語の文章生成")
15
+
16
+ num_of_output_text = st.slider(label='出力する文章の数',
17
+ min_value=1,
18
+ max_value=2,
19
+ value=1,
20
+ )
21
+
22
+ length_of_output_text = st.slider(label='出力する文字数',
23
+ min_value=30,
24
+ max_value=200,
25
+ value=100,
26
+ )
27
+
28
+ PREFIX_TEXT = st.text_area(
29
+ label='テキスト入力',
30
+ value='吾輩は猫である'
31
+ )
32
+
33
+ progress_num = 0
34
+ status_text = st.empty()
35
+ progress_bar = st.progress(progress_num)
36
+
37
+ if st.button('文章生成'):
38
+
39
+ st.text("読み込みに時間がかかります")
40
+ progress_num = 10
41
+ status_text.text(f'Progress: {progress_num}%')
42
+ progress_bar.progress(progress_num)
43
+
44
+ tokenizer = cached_tokenizer()
45
+ progress_num = 25
46
+ status_text.text(f'Progress: {progress_num}%')
47
+ progress_bar.progress(progress_num)
48
+
49
+ model = cached_model()
50
+ progress_num = 40
51
+ status_text.text(f'Progress: {progress_num}%')
52
+ progress_bar.progress(progress_num)
53
+
54
+ # 推論
55
+ input = tokenizer.encode(PREFIX_TEXT, return_tensors="pt")
56
+ progress_num = 60
57
+ status_text.text(f'Progress: {progress_num}%')
58
+ progress_bar.progress(progress_num)
59
+
60
+ output = model.generate(
61
+ input, do_sample=True,
62
+ max_length=length_of_output_text,
63
+ num_return_sequences=num_of_output_text
64
+ )
65
+ progress_num = 90
66
+ status_text.text(f'Progress: {progress_num}%')
67
+ progress_bar.progress(progress_num)
68
+
69
+ output_text = "".join(tokenizer.batch_decode(output)).replace("</s>", "")
70
+ output_text = output_text.replace("</unk>", "")
71
+ progress_num = 95
72
+ status_text.text(f'Progress: {progress_num}%')
73
+ progress_bar.progress(progress_num)
74
+
75
+ st.info('生成結果')
76
+ progress_num = 100
77
+ status_text.text(f'Progress: {progress_num}%')
78
+ st.write(output_text)
79
+ progress_bar.progress(progress_num)
80
+
81
+ if __name__ == '__main__':
82
+ main()
command.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
1
+ ## ローカル実行用のコマンド
2
+
3
+ ## pip のアップグレード
4
+ python -m pip install --upgrade pip
5
+
6
+ ## requirements.txt からパッケージをインストール
7
+ pip install -r requirements.txt
8
+
9
+ ## ローカルサーバーの立上げ
10
+ streamlit run app.py
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ sentencepiece==0.1.96
2
+ transformers==4.12.2
3
+ streamlit==1.1.0
4
+ torch==1.10.0