File size: 9,685 Bytes
ddbadbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
from transformers import T5Tokenizer, AutoModelForCausalLM, GenerationConfig

# 0. モデルとトークナイザーの定義
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
tokenizer.do_lower_case = True  # rinna/japanese-gpt2特有のハック
model = AutoModelForCausalLM.from_pretrained(
    "rinna/japanese-gpt2-small",
    pad_token_id=tokenizer.eos_token_id  # warningを避けるために、padにEOSトークンを割りあてる
    )

# 1. Gradioのコンポーネントのイベント処理用の関数の定義
def generate(text, max_length, num_beams, p):
    """初回のテキスト生成

    テキスト生成を行うが、デコード方法によって異なる結果になることを示すための処理を行う。
    指定されたパラメタを使って、異なる4つデコード方法を同時に出力する。

    Args:
        text: str
            Stateから取得(続きを生成するためのプロンプト)
        max_length: int
            Sliderから取得(全てのデコード方法に共通のパラメタ。生成する単語数)
        num_beams: int
            Sliderから取得(Beam Searchのパラメタ)
        p: int
            Sliderから取得(Top-p Samplingのパラメタ)
    
    Returns:
        tuple(str1, str2, str3)
            str1: State(生成結果を入出力の状態に反映)
            str2: TextArea(全文表示用のコンポーネントで使用)
            str3: TextArea(今回生成した文を表示するコンポーネントで使用)
    """
    # テキスト生成用のconfigクラスを使って、4パターンの設定を定義する。
    generate_config_list = [
        GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            num_beams=1,  # beam幅の設定、2以上ではbeam searchになる。
            do_sample=False  # Samplingの設定
        ),
        GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            num_beams=1,
            do_sample=True
        ),
        GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            num_beams=num_beams,
            do_sample=False
        ),
        GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_p=p  # Top-p Samplingのパラメタの設定
        )
    ]
    generated_texts = []

    inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"]
    for generate_config in generate_config_list:
        # テキスト生成
        output = model.generate(inputs, generation_config=generate_config)
        generated = tokenizer.decode(output[0], skip_special_tokens=True)
        # 読みやすくさの処理を行なって、リストに追加
        generated_texts.append("。\n".join(generated.replace(" ", "").split("。")))

    # gradioはtupleを想定している。これと同じ処理:return generated_texts[0], generated_texts[1], generated_texts[2]
    # pythonのタプルは「,」によって生成される。丸括弧は省略可能。参考:https://note.nkmk.me/python-function-return-multiple-values/
    return tuple(generated_texts)

def select_out1(out1):
    """out1が生成された時に、out1を後続の処理のデフォルト値に入力
    """
    return out1, out1, out1

def select_out(radio, out1, out2, out3, out4):
    """後続の処理に使用する、初回の処理結果を選択する
    """
    if radio == "1.Greedy":
        out = out1
    elif radio == "2.Sampling":
        out = out2
    elif radio == "3.Beam Search":
        out = out3
    else:
        out = out4
    return out, out, out

def generate_next(now_text, radio, max_length, num_beams, p):
    """続き生成

    これまで出力したテキストを入力して受け取り、続きを生成する。
    デコード方法を指定することができるが、そのパラメタは初回のテキスト生成と同じになる。

    Args:
        now_text: str
            Stateから取得(続きを生成するためのプロンプト)
        radio: str
            Radioから取得(使用するデコード方法の名前)
        max_length: int
            Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
        num_beams: int
            Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
        p: int
            Sliderから取得(初回のテキスト生成で使用した値をここでも使用)
    
    Returns:
        next_text: str
            State(生成結果を入出力の状態に反映)
        next_text: str
            TextArea(全文表示用のコンポーネントで使用)
        gen_text: str
            TextArea(今回生成した文を表示するコンポーネントで使用)
    """
    # デコード方法の指定に合わせて、cofingを定義
    if radio == "1.Greedy":
        generate_config = GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            num_beams=1,
            do_sample=False
        )
    elif radio == "2.Sampling":
        generate_config = GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            num_beams=1,
            do_sample=True
        )
    elif radio == "3.Beam Search":
        generate_config = GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            num_beams=num_beams,
            do_sample=False
        )
    else:
        generate_config = GenerationConfig(
            max_new_tokens=max_length,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_p=p
        )

    # テキスト生成
    inputs = tokenizer(now_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
    output = model.generate(inputs, generation_config=generate_config)
    generated = tokenizer.decode(output[0], skip_special_tokens=True)
    # 結果の整形処理
    next_text = "。\n".join(generated.replace(" ", "").split("。"))
    gen_text = next_text[len(now_text)+1:]  # 今回生成したテキストを抽出

    return next_text, next_text, gen_text

# 2. GradioによるUI/イベント処理の定義
with gr.Blocks() as demo:
    # 2.1. UI
    gr.Markdown('''
    # テキスト生成
    テキストを入力すると、4パターンのデコード方法でテキスト生成を実行します。
    ## 4つのパターン(入門編)
    1. Greedy: ビームサーチもサンプリングも行いません。毎回、最も確率の高い単語を選択します。
    2. Sampling: モデルによって与えられた語彙全体の確率分布に基づいて次の単語を選択します。
    3. Beam Search: 各タイムステップで複数の仮説を保持し、最終的に仮説ごとのシーケンス全体で最も高い確率を持つ仮説を選択します。
    4. Top-p Sampling: 2の方法に関して、確率の和がpになる最小の単語にフィルタリングすることで、確率が低い単語が選ばれる可能性を無くします。
    ''')

    with gr.Row():  # 行に分ける。なので、このブロック内にあるコンポーネントは横に並ぶ。
        with gr.Column():  # さらに列に分ける。なので、このブロック内にあるコンポーネントは縦に並ぶ。
            input_text = gr.Textbox(value="福岡のご飯は美味しい。", label="プロンプト")
            max_length = gr.Slider(100, 1000, step=100, value=100, label="生成するテキストの長さ")
            num_beams = gr.Slider(1, 10, step=1, value=6, label="beam幅")
            p = gr.Slider(0, 1, step=0.01, value=0.92, label="p")
            btn1 = gr.Button("4パターンで生成")
        
        with gr.Column():
            out1 = gr.Textbox(label="Greedy")
            out2 = gr.Textbox(label="Sampling")
            out3 = gr.Textbox(label="Beam Search")
            out4 = gr.Textbox(label="Top-p Sampling")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## どの結果の続きが気になりますか?")
            radio1 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="結果の選択")
            output_text = gr.Textbox(label="初回の結果")
            
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"## どの方法で続きを生成しますか?")
            history = gr.State()
            now_text = gr.TextArea(label="これまでの結果")
            radio2 = gr.Radio(choices=["1.Greedy", "2.Sampling", "3.Beam Search", "4.Top-p Sampling"], value="1.Greedy", label="続き生成のデコード方法")
            btn2 = gr.Button("続きを生成")
            next_text = gr.TextArea(label="今回の生成結果")
            

    # 2.2 イベント処理
    btn1.click(fn=generate, inputs=[input_text, max_length, num_beams, p], outputs=[out1, out2, out3, out4])
    out1.change(select_out1, inputs=[out1], outputs=[output_text, history, now_text])
    radio1.change(select_out, inputs=[radio1, out1, out2, out3, out4], outputs=[output_text, history, now_text])
    btn2.click(fn=generate_next, inputs=[history, radio2, max_length, num_beams, p], outputs=[history, now_text, next_text])

if __name__ == "__main__":
    demo.launch()