hajime9652 commited on
Commit
7fdb9d9
1 Parent(s): a949e08

first commit

Browse files
Files changed (2) hide show
  1. app.py +210 -0
  2. requirements.txt +67 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import T5Tokenizer, AutoModelForCausalLM, GenerationConfig
3
+
4
+ # 0. モデルとトークナイザーの定義
5
+ tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
6
+ tokenizer.do_lower_case = True # rinna/japanese-gpt2特有のハック
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ "rinna/japanese-gpt2-small",
9
+ pad_token_id=tokenizer.eos_token_id # warningを避けるために、padにEOSトークンを割りあてる
10
+ )
11
+
12
+ # 1. Gradioのコンポーネントのイベント処理用の関数の定義
13
+ def generate(text, max_length, num_beams, p):
14
+ """初回のテキスト生成
15
+
16
+ テキスト生成を行うが、デコード方法によって異なる結果になることを示すための処理を行う。
17
+ 指定されたパラメタを使って、異なる4つデコード方法を同時に出力する。
18
+
19
+ Args:
20
+ text: str
21
+ Stateから取得(続きを生成するためのプロンプト)
22
+ max_length: int
23
+ Sliderから取得(全てのデコード方法に共通のパラメタ。生成する単語数)
24
+ num_beams: int
25
+ Sliderから取得(Beam Searchのパラメタ)
26
+ p: int
27
+ Sliderから取得(Top-p Samplingのパラメタ)
28
+
29
+ Returns:
30
+ tuple(str1, str2, str3)
31
+ str1: State(生成結果を入出力の状態に反映)
32
+ str2: TextArea(全文表示用のコンポーネントで使用)
33
+ str3: TextArea(今回生成した文を表示するコンポーネントで使用)
34
+ """
35
+ # テキスト生成用のconfigクラスを使って、4パターンの設定を定義する。
36
+ generate_config_list = [
37
+ GenerationConfig(
38
+ max_new_tokens=max_length,
39
+ no_repeat_ngram_size=3,
40
+ num_beams=1, # beam幅の設定、2以上ではbeam searchになる。
41
+ do_sample=False # Samplingの設定
42
+ ),
43
+ GenerationConfig(
44
+ max_new_tokens=max_length,
45
+ no_repeat_ngram_size=3,
46
+ num_beams=1,
47
+ do_sample=True
48
+ ),
49
+ GenerationConfig(
50
+ max_new_tokens=max_length,
51
+ no_repeat_ngram_size=3,
52
+ num_beams=num_beams,
53
+ do_sample=False
54
+ ),
55
+ GenerationConfig(
56
+ max_new_tokens=max_length,
57
+ no_repeat_ngram_size=3,
58
+ do_sample=True,
59
+ top_p=p # Top-p Samplingのパラメタの設定
60
+ )
61
+ ]
62
+ generated_texts = []
63
+
64
+ inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"]
65
+ for generate_config in generate_config_list:
66
+ # テキスト生成
67
+ output = model.generate(inputs, generation_config=generate_config)
68
+ generated = tokenizer.decode(output[0], skip_special_tokens=True)
69
+ # 読みやすくさの処理を行なって、リストに追加
70
+ generated_texts.append("。\n".join(generated.replace(" ", "").split("。")))
71
+
72
+ # gradioはtupleを想定している。これと同じ処理:return generated_texts[0], generated_texts[1], generated_texts[2]
73
+ # pythonのタプルは「,」によって生成される。丸括弧は省略可能。参考:https://note.nkmk.me/python-function-return-multiple-values/
74
+ return tuple(generated_texts)
75
+
76
+ def select_out1(out1):
77
+ """out1が生成された時に、out1を後続の処理のデフォルト値に入力
78
+ """
79
+ return out1, out1, out1
80
+
81
+ def select_out(radio, out1, out2, out3, out4):
82
+ """後続の処理に使用する、初回の処理結果を選択する
83
+ """
84
+ if radio == "1.Greedy":
85
+ out = out1
86
+ elif radio == "2.Sampling":
87
+ out = out2
88
+ elif radio == "3.Beam Search":
89
+ out = out3
90
+ else:
91
+ out = out4
92
+ return out, out, out
93
+
94
+ def generate_next(now_text, radio, max_length, num_beams, p):
95
+ """続き生成
96
+
97
+ これまで出力したテキストを入力して受け取り、続きを生成する。
98
+ デコード方法を指定することができるが、そのパラメタは初回のテキスト生成と同じになる。
99
+
100
+ Args:
101
+ now_text: str
102
+ Stateから取得(続きを生成するためのプロンプト)
103
+ radio: str
104
+ Radioから取得(使用するデコード方法の名前)
105
+ max_length: int
106
+ Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
107
+ num_beams: int
108
+ Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
109
+ p: int
110
+ Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
111
+
112
+ Returns:
113
+ next_text: str
114
+ State(生成結果を入出力の状態に反映)
115
+ next_text: str
116
+ TextArea(全文表示用のコンポーネントで使用)
117
+ gen_text: str
118
+ TextArea(今回生成した文を表示するコンポーネントで使用)
119
+ """
120
+ # デコード方法の指定に合わせて、cofingを定義
121
+ if radio == "1.Greedy":
122
+ generate_config = GenerationConfig(
123
+ max_new_tokens=max_length,
124
+ no_repeat_ngram_size=3,
125
+ num_beams=1,
126
+ do_sample=False
127
+ )
128
+ elif radio == "2.Sampling":
129
+ generate_config = GenerationConfig(
130
+ max_new_tokens=max_length,
131
+ no_repeat_ngram_size=3,
132
+ num_beams=1,
133
+ do_sample=True
134
+ )
135
+ elif radio == "3.Beam Search":
136
+ generate_config = GenerationConfig(
137
+ max_new_tokens=max_length,
138
+ no_repeat_ngram_size=3,
139
+ num_beams=num_beams,
140
+ do_sample=False
141
+ )
142
+ else:
143
+ generate_config = GenerationConfig(
144
+ max_new_tokens=max_length,
145
+ no_repeat_ngram_size=3,
146
+ do_sample=True,
147
+ top_p=p
148
+ )
149
+
150
+ # テキスト生成
151
+ inputs = tokenizer(now_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
152
+ output = model.generate(inputs, generation_config=generate_config)
153
+ generated = tokenizer.decode(output[0], skip_special_tokens=True)
154
+ # 結果の整形処理
155
+ next_text = "。\n".join(generated.replace(" ", "").split("。"))
156
+ gen_text = next_text[len(now_text)+1:] # 今回生成したテキストを抽出
157
+
158
+ return next_text, next_text, gen_text
159
+
160
+ # 2. GradioによるUI/イベント処理の定義
161
+ with gr.Blocks() as demo:
162
+ # 2.1. UI
163
+ gr.Markdown('''
164
+ # テキスト生成
165
+ テキストを入力すると、4パターンのデコード方法でテキスト生成を実行します。
166
+ ## 4つのパターン(入門編)
167
+ 1. Greedy: ビームサーチもサンプリングも行いません。毎回、最も確率の高い単語を選択します。
168
+ 2. Sampling: モデルによって与えられた語彙全体の確率分布に基づいて次の単語を選択します。
169
+ 3. Beam Search: 各タイムステップで複数の仮説を保持し、最終的に仮説ごとのシーケンス全体で最も高い確率を持つ仮説を選択します。
170
+ 4. Top-p Sampling: 2の方法に関して、確率の和がpになる最小の単語にフィルタリングすることで、確率が低い単語が選ばれる可能性を無くします。
171
+ ''')
172
+
173
+ with gr.Row(): # 行に分ける。なので、このブロック内にあるコンポーネントは横に並ぶ。
174
+ with gr.Column(): # さらに列に分ける。なので、このブロック内にあるコンポーネントは縦に並ぶ。
175
+ input_text = gr.Textbox(value="福岡のご飯は美味しい。", label="プロンプト")
176
+ max_length = gr.Slider(100, 1000, step=100, value=100, label="生成するテキストの長さ")
177
+ num_beams = gr.Slider(1, 10, step=1, value=6, label="beam幅")
178
+ p = gr.Slider(0, 1, step=0.01, value=0.92, label="p")
179
+ btn1 = gr.Button("4パターンで生成")
180
+
181
+ with gr.Column():
182
+ out1 = gr.Textbox(label="Greedy")
183
+ out2 = gr.Textbox(label="Sampling")
184
+ out3 = gr.Textbox(label="Beam Search")
185
+ out4 = gr.Textbox(label="Top-p Sampling")
186
+
187
+ with gr.Row():
188
+ with gr.Column():
189
+ gr.Markdown("## どの結果の続きが気になりますか?")
190
+ radio1 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="結果の選択")
191
+ output_text = gr.Textbox(label="初回の結果")
192
+
193
+ with gr.Row():
194
+ with gr.Column():
195
+ gr.Markdown(f"## どの方法で続きを生成しますか?")
196
+ history = gr.State()
197
+ now_text = gr.TextArea(label="これまでの結果")
198
+ radio2 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="続き生成のデコード方法")
199
+ btn2 = gr.Button("続きを生成")
200
+ next_text = gr.TextArea(label="今回の生成結果")
201
+
202
+
203
+ # 2.2 イベント処理
204
+ btn1.click(fn=generate, inputs=[input_text, max_length, num_beams, p], outputs=[out1, out2, out3, out4])
205
+ out1.change(select_out1, inputs=[out1], outputs=[output_text, history, now_text])
206
+ radio1.change(select_out, inputs=[radio1, out1, out2, out3, out4], outputs=[output_text, history, now_text])
207
+ btn2.click(fn=generate_next, inputs=[history, radio2, max_length, num_beams, p], outputs=[history, now_text, next_text])
208
+
209
+ if __name__ == "__main__":
210
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==22.1.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ anyio==3.6.2
6
+ async-timeout==4.0.2
7
+ attrs==22.2.0
8
+ certifi==2022.12.7
9
+ charset-normalizer==2.1.1
10
+ click==8.1.3
11
+ contourpy==1.0.7
12
+ cycler==0.11.0
13
+ entrypoints==0.4
14
+ fastapi==0.89.1
15
+ ffmpy==0.3.0
16
+ filelock==3.9.0
17
+ fonttools==4.38.0
18
+ frozenlist==1.3.3
19
+ fsspec==2023.1.0
20
+ gradio==3.17.1
21
+ h11==0.14.0
22
+ httpcore==0.16.3
23
+ httpx==0.23.3
24
+ huggingface-hub==0.12.0
25
+ idna==3.4
26
+ Jinja2==3.1.2
27
+ jsonschema==4.17.3
28
+ kiwisolver==1.4.4
29
+ linkify-it-py==1.0.3
30
+ markdown-it-py==2.1.0
31
+ MarkupSafe==2.1.2
32
+ matplotlib==3.6.3
33
+ mdit-py-plugins==0.3.3
34
+ mdurl==0.1.2
35
+ multidict==6.0.4
36
+ numpy==1.24.2
37
+ orjson==3.8.5
38
+ packaging==23.0
39
+ pandas==1.5.3
40
+ Pillow==9.4.0
41
+ pycryptodome==3.17
42
+ pydantic==1.10.4
43
+ pydub==0.25.1
44
+ pyparsing==3.0.9
45
+ pyrsistent==0.19.3
46
+ python-dateutil==2.8.2
47
+ python-multipart==0.0.5
48
+ pytz==2022.7.1
49
+ PyYAML==6.0
50
+ regex==2022.10.31
51
+ requests==2.28.2
52
+ rfc3986==1.5.0
53
+ sentencepiece==0.1.97
54
+ six==1.16.0
55
+ sniffio==1.3.0
56
+ starlette==0.22.0
57
+ tokenizers==0.13.2
58
+ toolz==0.12.0
59
+ torch==1.13.1
60
+ tqdm==4.64.1
61
+ transformers==4.26.0
62
+ typing_extensions==4.4.0
63
+ uc-micro-py==1.0.1
64
+ urllib3==1.26.14
65
+ uvicorn==0.20.0
66
+ websockets==10.4
67
+ yarl==1.8.2