nekoniii3 commited on
Commit
d80203b
1 Parent(s): 1f8bb86

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +503 -0
app.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import os
4
+ import datetime
5
+ # from zoneinfo import ZoneInfo
6
+ from openai import OpenAI
7
+
8
+ from openai.types.beta.threads.runs import (
9
+ ToolCallsStepDetails,
10
+ )
11
+
12
+ # GPT用設定
13
+ SYS_PROMPT_DEFAULT = "あなたは優秀なアシスタントです。質問をされた場合は、質問に答えるコードを作成して実行します。回答は日本語でお願いします。"
14
+ DUMMY = "********************"
15
+ file_format = {".txt", ".csv", ".pdf"}
16
+
17
+ # 各種出力フォルダ
18
+ IMG_FOLDER = "sample_data"
19
+ ANT_FOLDER = "sample_data"
20
+
21
+ # 各種メッセージ
22
+ IMG_MSG = "(画像ファイルを追加しました。送信ボタンの下に表示されています。)"
23
+ ANT_MSG = "(下部の[出力ファイル]にファイルを追加しました。)"
24
+
25
+ # 環境変数情報
26
+ # os.environ["OPENAI_API_KEY"] = ""
27
+ os.environ["ASSIST_ID"] = "asst_KHpzJRBEgONhDf6cIpxr1Avt"
28
+
29
+ # 各種設定値
30
+ MAX_TRIAL = 15 # メッセージ取得最大試行数
31
+ INTER_SEC = 3 # 試行間隔(秒)
32
+
33
+ # サンプル用情報
34
+ examples = ["sample_data/東京都年別人口.csv", "sample_data/練馬区年齢別人口.csv"]
35
+ example_toid = {"東京都年別人口.csv" : "file-GOEk4X4WpU5gBJAuHCMtiJrn"
36
+ , "練馬区年齢別人口.csv" : "file-YAFPMMqG3Zl5DRx5hTLjCfFa"}
37
+
38
+ # file_id = "file-0Ly64DA2jzE9mOFYayOKJJK0"
39
+ # file_id = "file-aVnVcpEVpsy77xQ8SlTp1WoX" # ライ麦
40
+
41
+ # file_id = "file-HFCaJbf3k7j0fhBqh1Rwf2VV" # 練馬区
42
+
43
+ # コード出力用
44
+ code_mode = {'ON': True, 'OFF': False}
45
+
46
+ def set_state(openai_key, sys_prompt, code_output, state):
47
+ """ 設定タブの情報をセッションに保存する関数 """
48
+
49
+ state["openai_key"] = openai_key
50
+ state["system_prompt"] = sys_prompt
51
+ state["code_mode"] = code_mode[code_output]
52
+
53
+ return state
54
+
55
+
56
+ def init(state, text, file):
57
+ """ 入力チェックを行う関数
58
+ ※ここで例外を起こすと入力できなくなるので次の関数でエラーにする """
59
+
60
+ err_msg = ""
61
+ file_id = None
62
+
63
+ # if state["openai_key"] == "" or state["openai_key"] is None:
64
+
65
+ # # OpenAI API Key未入力
66
+ # err_msg = "OpenAI API Keyを入力してください。(設定タブ)"
67
+
68
+ if not text:
69
+
70
+ # テキスト未入力
71
+ err_msg = "テキストを入力して下さい。"
72
+
73
+ return state, text, file, file_id, err_msg
74
+
75
+ elif file:
76
+
77
+ # 入力画像のファイル形式チェック
78
+ root, ext = os.path.splitext(file)
79
+
80
+ if ext not in file_format:
81
+
82
+ # ファイル形式チェック
83
+ err_msg = "指定した形式のファイルをアップしてください。(注意事項タブに記載)"
84
+
85
+ return state, text, gr.Image(value=None,type="filepath", interactive=False), file_id, err_msg
86
+
87
+ if state["thread_id"] is None:
88
+
89
+ # 初めてなら初期処理をする
90
+ client = OpenAI()
91
+
92
+ # assistant = client.beta.assistants.create(
93
+ # name="codeinter_test",
94
+ # instructions=state["system_prompt"],
95
+ # # model="gpt-4-1106-preview",
96
+ # model="gpt-3.5-turbo-1106",
97
+ # tools=[{"type": "code_interpreter"}]
98
+ # )
99
+
100
+ # print(assistant.id)
101
+
102
+ thread = client.beta.threads.create()
103
+
104
+ state["client"] = client
105
+ # state["assistant_id"] = assistant.id
106
+ state["assistant_id"] = os.environ["ASSIST_ID"]
107
+ state["thread_id"] = thread.id
108
+
109
+ if file:
110
+
111
+ # ファイル名取得
112
+ basename = os.path.basename(file)
113
+
114
+ if example_toid.get(basename):
115
+
116
+ # サンプルの場合は用意したIDをセット
117
+ file_id = example_toid.get(basename)
118
+
119
+ else:
120
+
121
+ # ファイルのアップ
122
+ # file_response = client.files.create(
123
+ # purpose="assistants",
124
+ # file=open(file,"rb"),
125
+ # )
126
+
127
+ # if file_response.status != 'processed':
128
+
129
+ # # 失敗時
130
+ # err_msg = "ファイルのアップロードに失敗しました"
131
+
132
+ # else
133
+ # # ファイルのIDをセット
134
+ # file_id = file_response.id
135
+
136
+ # file_id = "file-0Ly64DA2jzE9mOFYayOKJJK0"
137
+ # file_id = "file-aVnVcpEVpsy77xQ8SlTp1WoX" # ライ麦
138
+
139
+ # file_id = "file-HFCaJbf3k7j0fhBqh1Rwf2VV" # 練馬区
140
+ file_id = ""
141
+
142
+ # print(file_id)
143
+
144
+ return state, text, file, file_id, err_msg
145
+
146
+ def raise_exception(err_msg):
147
+ """ エラーの場合例外を起こす関数 """
148
+
149
+ if err_msg != "":
150
+ raise Exception()
151
+
152
+ return
153
+
154
+
155
+ def add_history(history, text, file_id):
156
+ """ Chat履歴"history"に追加を行う関数 """
157
+
158
+ # print("前:")
159
+ # print(history)
160
+
161
+ err_msg = ""
162
+ new_row_flg = False
163
+
164
+ # 新しい行を追加するか判定
165
+ # if len(history) == 0:
166
+
167
+ # new_row_flg = True
168
+
169
+ # elif history[-1][0] is not None:
170
+
171
+ # # 前回がアシスタントでない場合も追加
172
+ # new_row_flg = True
173
+
174
+ new_row_flg = True
175
+
176
+ if file_id is None or file_id == "":
177
+
178
+ if new_row_flg:
179
+
180
+ # テキストだけの場合そのまま追加
181
+ history = history + [(text, None)]
182
+ else:
183
+ history[-1][0] = text
184
+
185
+ elif file_id is not None:
186
+
187
+ if new_row_flg:
188
+
189
+ # ファイルがあればファイルIDとテキストを追加
190
+ history = history + [("file:" + file_id, DUMMY)]
191
+ history = history + [(text, None)]
192
+
193
+ else:
194
+ history[-1][0] = "file:" + file_id
195
+ history = history + [(text, None)]
196
+
197
+ print(history)
198
+
199
+ # テキストだけ初期化
200
+ new_text = gr.Textbox(value="", interactive=True)
201
+
202
+ return history, new_text, err_msg
203
+
204
+
205
+ def bot(state, history, file_id):
206
+
207
+ err_msg = ""
208
+ image_file = None
209
+ ant_file = None
210
+ # new_row_flg = False
211
+
212
+ # セッション情報取得
213
+ system_prompt = state["system_prompt"]
214
+ client = state["client"]
215
+ assistant_id = state["assistant_id"]
216
+ thread_id = state["thread_id"]
217
+ msg_id = state["last_msg_id"]
218
+ code_mode = state["code_mode"]
219
+
220
+ print("system_prompt")
221
+
222
+ if file_id is None or file_id == "":
223
+
224
+ # ファイルがない場合
225
+ message = client.beta.threads.messages.create(
226
+ thread_id=thread_id,
227
+ role="user",
228
+ content=history[-1][0],
229
+ )
230
+ else:
231
+
232
+ # ファイルがあるときはIDをセット
233
+ message = client.beta.threads.messages.create(
234
+ thread_id=thread_id,
235
+ role="user",
236
+ content=history[-1][0],
237
+ file_ids=[file_id]
238
+ )
239
+
240
+ # スレッド実行
241
+ run = client.beta.threads.runs.create(
242
+ thread_id=thread_id,
243
+ assistant_id=assistant_id,
244
+ instructions=system_prompt
245
+ )
246
+
247
+ # "completed"となるまで繰り返す(指定秒おき)
248
+ for i in range(0, MAX_TRIAL, 1):
249
+
250
+ if i > 0:
251
+ time.sleep(INTER_SEC)
252
+
253
+ # メッセージ受け取り
254
+ run = client.beta.threads.runs.retrieve(
255
+ thread_id=thread_id,
256
+ run_id=run.id
257
+ )
258
+
259
+ # 前回のメッセージより後を昇順で取り出す
260
+ messages = client.beta.threads.messages.list(
261
+ thread_id=thread_id,
262
+ after=msg_id,
263
+ order="asc"
264
+ )
265
+
266
+ print(msg_id)
267
+ print(messages.data)
268
+
269
+ # messageを取り出す
270
+ for msg in messages:
271
+
272
+ msg_id = msg.id
273
+
274
+ if msg.role == "assistant":
275
+
276
+ for content in msg.content:
277
+
278
+ res_text = ""
279
+ file_id = ""
280
+ ant_file = None
281
+
282
+ cont_dict = content.model_dump() # 辞書型に変換
283
+
284
+ ct_image_file = cont_dict.get("image_file")
285
+
286
+ if ct_image_file:
287
+
288
+ # imageファイルがあるならIDセット
289
+ res_file_id = ct_image_file.get("file_id")
290
+
291
+ # ファイルをダウンロード
292
+ image_file = file_download(client, res_file_id, IMG_FOLDER , ".png")
293
+
294
+ if image_file is None:
295
+
296
+ err_msg = "ファイルのダウンロードに失敗しました。"
297
+
298
+ else:
299
+
300
+ print("画像ファイル追加")
301
+
302
+ res_text = IMG_MSG
303
+
304
+ history = history + [[None, res_text]]
305
+
306
+ else:
307
+
308
+ # 画像がないならテキスト取得
309
+ res_text = cont_dict["text"].get("value")
310
+
311
+ # 注釈(参照ファイル)ががある場合取得
312
+ if len(cont_dict.get("text").get("annotations")) > 0:
313
+
314
+ ct_ant = cont_dict.get("text").get("annotations")
315
+
316
+ if ct_ant[0].get("file_path") is not None:
317
+
318
+ ant_file_id = ct_ant[0].get("file_path").get("file_id")
319
+
320
+
321
+ if ct_ant[0].get("text") is not None:
322
+
323
+ # 拡張子取得
324
+ ext = "." + ct_ant[0].get("text")[ct_ant[0].get("text").rfind('.') + 1:]
325
+
326
+ # ファイルダウンロード
327
+ ant_file = file_download(client, ant_file_id, ANT_FOLDER, ext)
328
+
329
+ if ant_file is None:
330
+
331
+ err_msg = "参照ファイルのダウンロードに失敗しました。"
332
+
333
+ else:
334
+
335
+ # 参照ファイルがある旨のメッセージを追加
336
+ res_text = res_text + "\n\n" + ANT_MSG
337
+
338
+ print(res_text)
339
+
340
+ if res_text != "":
341
+
342
+ # Chat画面更新
343
+ if history[-1][1] is not None:
344
+
345
+ # 新しい行を追加
346
+ history = history + [[None, res_text]]
347
+ else:
348
+
349
+ history[-1][1] = res_text
350
+
351
+ yield history, image_file, ant_file, err_msg
352
+
353
+ print(run.status)
354
+
355
+ state["last_msg_id"] = msg_id
356
+
357
+ # 完了なら終了
358
+ if run.status == "completed":
359
+
360
+ if not code_mode:
361
+ break
362
+ else:
363
+
364
+ # コードモードがONでコードがあれば取得
365
+ run_steps = client.beta.threads.runs.steps.list(
366
+ thread_id=thread_id, run_id=run.id
367
+ )
368
+
369
+ input_code = get_code(run_steps)
370
+
371
+ if input_code != "":
372
+
373
+ input_code = "[input_code]\n\n" + input_code
374
+
375
+ print(input_code)
376
+
377
+ # コードを追加
378
+ history = history + [[None, input_code]]
379
+
380
+ yield history, image_file, ant_file, err_msg
381
+
382
+ break
383
+
384
+ if run.status == "failed":
385
+
386
+ # エラーとして終了
387
+ err_msg = "※メッセージ取得に失敗しました。"
388
+ return history, image_file, ant_file, err_msg
389
+
390
+ if i == MAX_TRIAL:
391
+
392
+ # エラーとして終了
393
+ err_msg = "※メッセージ取得の際にタイムアウトしました。"
394
+ return history, image_file, ant_file, err_msg
395
+
396
+ def get_code(run_steps):
397
+
398
+ input_code = ""
399
+ print("run_steps")
400
+ print(run_steps)
401
+
402
+ for data in run_steps.data:
403
+
404
+ if isinstance(data.step_details, ToolCallsStepDetails):
405
+
406
+ # print(data.step_details)
407
+ input_code = data.step_details.tool_calls[0].code_interpreter.input
408
+
409
+ return input_code
410
+
411
+ def file_download(client, file_id, folder, ext):
412
+ """ OpenAIからファイルをダウンロードしてパスを返す """
413
+ api_response = client.files.with_raw_response.retrieve_content(file_id)
414
+
415
+ if api_response.status_code == 200:
416
+
417
+ content = api_response.content
418
+
419
+ file_path = folder + "/" + file_id + ext
420
+
421
+ with open(file_path, 'wb') as f:
422
+ f.write(content)
423
+
424
+ return file_path
425
+
426
+ else:
427
+ return None
428
+
429
+
430
+ def finally_proc():
431
+
432
+ os.environ["OPENAI_API_KEY"] = ""
433
+
434
+ # new_text = gr.Textbox(value="", interactive=True)
435
+ new_up_file = gr.File(value=None, interactive = True)
436
+ new_file_id = gr.Textbox(value="")
437
+
438
+ return new_up_file, new_file_id
439
+
440
+
441
+ with gr.Blocks() as demo:
442
+
443
+ gr.Markdown("<h2> </h2>")
444
+
445
+ # セッションの宣言
446
+ state = gr.State({
447
+ "system_prompt": SYS_PROMPT_DEFAULT,
448
+ "openai_key" : None,
449
+ "code_mode" : True, # テスト中
450
+ "client" : None,
451
+ "assistant_id" : None,
452
+ "thread_id" : None,
453
+ "last_msg_id" : ""
454
+ })
455
+
456
+ with gr.Tab("GPT-4V 画像入力対応チャット") as chat:
457
+
458
+ # 各コンポーネント定義
459
+ chatbot = gr.Chatbot(label="チャット画面")
460
+ text_msg = gr.Textbox(label="テキスト")
461
+ # image=gr.Image(type="filepath")
462
+ up_file = gr.File(label="ファイルアップロード", type="filepath",interactive = True)
463
+ gr.Examples(label="サンプルデータ", examples=examples, inputs=[up_file])
464
+ with gr.Row():
465
+ btn = gr.Button(value="送信")
466
+ # btn_download = gr.Button(value="画像のダウンロード") # 保留中
467
+ btn_clear = gr.ClearButton(value="リセット", components=[chatbot, text_msg, up_file, state])
468
+ sys_msg = gr.Textbox(label="システムメッセージ")
469
+ result_image = gr.Image(label="出力画像", type="filepath", interactive = False)
470
+ result_file = gr.File(label="出力ファイル", type="filepath",interactive = False)
471
+
472
+ # GPT用
473
+ file_id = gr.Textbox(visible=False)
474
+
475
+ # 送信ボタンクリック時の処理
476
+ bc = btn.click(init, [state, text_msg, up_file], [state, text_msg, up_file, file_id, sys_msg], queue=False).success(
477
+ raise_exception, sys_msg, None).success(
478
+ add_history, [chatbot, text_msg, file_id], [chatbot, text_msg, sys_msg], queue=False).success(
479
+ bot, [state, chatbot, file_id],[chatbot, result_image, result_file, sys_msg]).then(
480
+ finally_proc, None, [up_file, file_id], queue=False
481
+ )
482
+
483
+ # クリア時でもOpenAIKeyとプロンプトは残す
484
+ btn_clear.click(lambda:chat.select(set_state, [openai_key, system_prompt, state], state), None, None)
485
+
486
+ # up_file.upload(lambda:up_file.value, None, sys_msg)
487
+
488
+ # テキスト入力Enter時の処理
489
+ # txt_msg = text_msg.submit(respond, inputs=[text_msg, image, chatbot], outputs=[text_msg, image, chatbot])
490
+
491
+ with gr.Tab("設定") as set:
492
+ openai_key = gr.Textbox(label="OpenAI API Key")
493
+ # language = gr.Dropdown(choices=["Japanese", "English"], value = "Japanese", label="Language", interactive = True)
494
+ system_prompt = gr.Textbox(value = SYS_PROMPT_DEFAULT,lines = 5, label="Custom instructions", interactive = True)
495
+ # Enter不使用
496
+ code_output = gr.Dropdown(label="コード出力", choices=["OFF", "ON"], value = "ON", interactive = True)
497
+
498
+ # 設定タブからChatタブに戻った時の処理
499
+ chat.select(set_state, [openai_key, system_prompt, code_output, state], state)
500
+
501
+ demo.queue()
502
+
503
+ demo.launch(debug=True)