tugaa commited on
Commit
ad7e512
·
verified ·
1 Parent(s): 2816c7b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -0
app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import torch
4
+ from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
5
+ import gradio as gr
6
+ import logging
7
+ from pathlib import Path
8
+ import json
9
+ from typing import Tuple # generate_response の型ヒントのために追加
10
+
11
+ # ragsys03.py から RAGSystem クラスと QueryResult をインポート
12
+ from ragsys03 import RAGSystem, QueryResult
13
+
14
+ # ロギング設定
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # グローバル変数
19
+ rag_system: RAGSystem = None
20
+ llm_pipeline = None
21
+
22
+ # RAGデータ保存用のベースディレクトリ
23
+ RAG_BASE_DIR = Path('rag_data')
24
+
25
+ def initialize_systems():
26
+ """
27
+ RAGシステムとLLMパイプラインを初期化します。
28
+ """
29
+ global rag_system, llm_pipeline
30
+
31
+ if rag_system is None:
32
+ logger.info("RAGSystemを初期化中...")
33
+ try:
34
+ # 最新のインデックスをロードするか、新しいディレクトリを作成する
35
+ rag_system = RAGSystem(
36
+ model_name='all-MiniLM-L6-v2', # SentenceTransformer モデル
37
+ index_type='ivf', # または 'flat'
38
+ n_clusters=100, # IVFの場合のクラスタ数
39
+ index_base_dir=RAG_BASE_DIR,
40
+ load_latest=True # 最新のインデックスを自動的にロード
41
+ )
42
+ logger.info("RAGSystemの初期化が完了しました。")
43
+
44
+ # 初期化時にインデックスが存在しない場合、ここで構築を促す
45
+ if rag_system.index is None or rag_system.index.ntotal == 0:
46
+ logger.warning("RAGシステムにインデックスがありません。文書をアップロードして「インデックスを構築」ボタンを押してください。")
47
+
48
+ except Exception as e:
49
+ logger.critical(f"RAGSystemの初期化中にエラーが発生しました: {e}")
50
+ rag_system = None # 初期化失敗時はNoneに設定
51
+ # Gradio UIでエラーメッセージを表示するための処理を考慮するか、
52
+ # 各アクションでrag_systemがNoneの場合の処理を堅牢にする
53
+ # ここでは単にログに出力し、後続の関数でNoneチェックを行う
54
+
55
+ if llm_pipeline is None:
56
+ logger.info("LLMパイプラインを初期化中: rinna/japanese-gpt-neox-3.6b-instruction-sft")
57
+ try:
58
+ # 量子化を無効にした設定
59
+ llm_pipeline = pipeline(
60
+ "text-generation",
61
+ model="rinna/japanese-gpt-neox-3.6b-instruction-sft",
62
+ tokenizer=AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False),
63
+ torch_dtype=torch.bfloat16, # CPUでも動くようにbfloat16のまま
64
+ device_map="auto", # autoのままでOK (GPUがなければCPUに自動的に割り当たる)
65
+ # model_kwargs={"quantization_config": quantization_config} を削除!
66
+ )
67
+ logger.info("LLMパイプラインの初期化が完了しました。")
68
+ except Exception as e:
69
+ logger.critical(f"LLMパイプラインの初期化中にエラーが発生しました: {e}")
70
+ llm_pipeline = None # 初期化失敗時はNoneに設定
71
+ # 同様に、エラー発生時はllm_pipelineをNoneのままにする
72
+
73
+ def generate_response(query: str, top_k: int, similarity_threshold: float) -> Tuple[str, str]:
74
+ """
75
+ RAGシステムとLLMを使用して応答を生成します。
76
+ Args:
77
+ query (str): ユーザーからのクエリ。
78
+ top_k (int): 検索する文書の数。
79
+ similarity_threshold (float): コサイン類似度の閾値。
80
+ Returns:
81
+ Tuple[str, str]: LLMの回答と、検索された文書のテキスト。
82
+ """
83
+ if rag_system is None or llm_pipeline is None:
84
+ return "システムが初期化されていません。アプリケーションを再起動してください。", ""
85
+
86
+ logger.info(f"クエリ処理開始: {query}")
87
+ query_result: QueryResult = rag_system.query(query, top_k=top_k, similarity_threshold=similarity_threshold)
88
+
89
+ llm_prompt = query_result['llm_prompt']
90
+ retrieved_docs = query_result['retrieved_documents']
91
+ search_time = query_result['search_time']
92
+
93
+ if not retrieved_docs:
94
+ llm_answer = "関連する情報が見つかりませんでした。別の表現で質問していただくか、より具体的な内容で質問してください。"
95
+ return llm_answer, "関連文書なし"
96
+
97
+ try:
98
+ outputs = llm_pipeline(
99
+ llm_prompt,
100
+ max_new_tokens=512,
101
+ do_sample=True,
102
+ temperature=0.7,
103
+ top_p=0.9,
104
+ repetition_penalty=1.1,
105
+ num_return_sequences=1,
106
+ pad_token_id=llm_pipeline.tokenizer.eos_token_id # pad_token_idを指定
107
+ )
108
+
109
+ llm_answer = outputs[0]['generated_text']
110
+ # プロンプト部分を除去して回答のみを抽出 (モデルの出力形式による)
111
+ if llm_answer.startswith(llm_prompt):
112
+ llm_answer = llm_answer[len(llm_prompt):].strip()
113
+
114
+ logger.info(f"LLM応答生成完了 (検索時間: {search_time:.4f}秒)")
115
+
116
+ retrieved_docs_text_formatted = "\n\n".join([f"**文書{i+1}:** {doc}" for i, doc in enumerate(retrieved_docs)])
117
+
118
+ return llm_answer, retrieved_docs_text_formatted
119
+
120
+ except Exception as e:
121
+ logger.error(f"LLMによる応答生成中にエラーが発生しました: {e}")
122
+ return f"応答生成中にエラーが発生しました: {str(e)}", ""
123
+
124
+ def upload_documents(file):
125
+ """
126
+ アップロードされたJSONファイルから文書をロードし、インデックスを再構築します。
127
+ """
128
+ global rag_system
129
+
130
+ if rag_system is None:
131
+ # RAGSystemが未初期化の場合、ここで初期化を試みる
132
+ # initialize_systems() は、Gradioの起動時に一度しか呼ばれないため、
133
+ # ここでrag_systemがNoneの場合、初期化に失敗している可能性がある。
134
+ # 代わりに、エラーメッセージを返す
135
+ return "エラー: RAGシステムが初期化されていません。アプリケーションを再起動してください。", ""
136
+
137
+ if file is None:
138
+ return "エラー: ファイルがアップロードされていません。", ""
139
+
140
+ uploaded_file_path = Path(file.name)
141
+ logger.info(f"アップロードされたファイル: {uploaded_file_path}")
142
+
143
+ try:
144
+ # ragsys03.py の load_documents メソッドを使用
145
+ rag_system.load_documents(uploaded_file_path)
146
+
147
+ # インデックス構築のメッセージを表示し、ユーザーにボタンを押してもらう
148
+ return (
149
+ f"'{uploaded_file_path.name}' から {len(rag_system.documents)} 件の文書を正常にロードしました。\n"
150
+ "続けて「インデックスを構築」ボタンを押してください。",
151
+ "文書がロードされました。インデックス構築が必要です。"
152
+ )
153
+ except Exception as e:
154
+ logger.error(f"文書のロード中にエラーが発生しました: {e}")
155
+ return f"文書のロード中にエラーが発生しました: {str(e)}", ""
156
+
157
+ def build_index_action():
158
+ """
159
+ RAGインデックスを構築します。
160
+ """
161
+ global rag_system
162
+ if rag_system is None:
163
+ return "エラー: RAGシステムが初期化されていません。", ""
164
+
165
+ if not rag_system.documents:
166
+ return "エラー: インデックス作成用の文書がロードされていません。", ""
167
+
168
+ try:
169
+ # force_rebuild=True で常に再構築
170
+ rag_system.build_index(force_rebuild=True)
171
+ stats = rag_system.get_stats()
172
+ return (
173
+ f"インデックスが正常に構築されました。\n"
174
+ f"総文書数: {stats.get('documents_count', 'N/A')}\n"
175
+ f"インデックスタイプ: {stats.get('index_type', 'N/A')}\n"
176
+ f"インデックスディレクトリ: {stats.get('current_index_dir', 'N/A')}\n"
177
+ f"ファイルサイズ: {stats.get('index_file_size_mb', 'N/A')} MB"
178
+ ), "インデックス構築完了"
179
+ except Exception as e:
180
+ logger.error(f"インデックスの構築中にエラーが発生しました: {e}")
181
+ return f"エラー: インデックスの構築中に問題が発生しました - {str(e)}", ""
182
+
183
+ def get_rag_stats_action():
184
+ """RAGシステムの統計情報を表示します。"""
185
+ global rag_system
186
+ if rag_system is None:
187
+ return "RAGシステムはまだ初期化されていません。"
188
+ try:
189
+ stats = rag_system.get_stats()
190
+ stats_str = "\n".join([f"{k}: {v}" for k, v in stats.items()])
191
+ return f"RAGシステムの状態:\n{stats_str}"
192
+ except Exception as e:
193
+ logger.error(f"RAG統計情報の取得中にエラーが発生しました: {e}")
194
+ return f"統計情報の取得中にエラーが発生しました: {str(e)}"
195
+
196
+ # Gradio UI の構築
197
+ with gr.Blocks() as demo:
198
+ gr.Markdown(
199
+ """
200
+ # RAG (Retrieval Augmented Generation) デモ
201
+ rinna/japanese-gpt-neox-3.6b-instruction-sft と FAISS を利用したRAGシステムです。
202
+ JSONファイルをアップロードして独自の知識ベースを構築し、質問をすることができます。
203
+
204
+ ## 使用方法
205
+ 1. 「文書をアップロード」セクションで、`documents` キーに文字列のリストを持つJSONファイルをアップロードします。
206
+ 2. 「インデックスを構築」ボタンをクリックして、アップロードされた文書から検索インデックスを作成します。
207
+ 3. 「RAG質問」セクションで質問を入力し、「質問を送信」ボタンをクリックします。
208
+ """
209
+ )
210
+
211
+ with gr.Tab("RAG質問"):
212
+ gr.Markdown("### RAG質問")
213
+ with gr.Row():
214
+ query_input = gr.Textbox(label="質問を入力してください", placeholder="RAGシステムの主な利点は何ですか?", lines=2)
215
+ submit_button = gr.Button("質問を送信")
216
+
217
+ with gr.Row():
218
+ top_k_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="取得文書数 (top_k)")
219
+ similarity_threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="類似度閾値 (0.0-1.0)")
220
+
221
+ llm_output = gr.Textbox(label="LLMの回答", lines=10, interactive=False)
222
+ retrieved_docs_output = gr.Markdown("検索された関連文書", visible=True)
223
+
224
+ with gr.Tab("文書管理"):
225
+ gr.Markdown("### 文書とインデックス管理")
226
+ with gr.Row():
227
+ file_upload_button = gr.File(label="文書JSONファイルをアップロード", file_types=[".json"])
228
+ upload_status = gr.Textbox(label="アップロードステータス", interactive=False)
229
+
230
+ with gr.Row():
231
+ build_index_button = gr.Button("インデックスを構築")
232
+ build_index_status = gr.Textbox(label="インデックス構築ステータス", interactive=False)
233
+
234
+ gr.Markdown("### RAGシステム情報")
235
+ get_stats_button = gr.Button("RAGシステムの状態を表示")
236
+ rag_stats_output = gr.Textbox(label="RAGシステム情報", interactive=False, lines=5)
237
+
238
+
239
+ # イベントハンドラの登録
240
+ submit_button.click(
241
+ fn=generate_response,
242
+ inputs=[query_input, top_k_slider, similarity_threshold_slider],
243
+ outputs=[llm_output, retrieved_docs_output]
244
+ )
245
+
246
+ file_upload_button.upload(
247
+ fn=upload_documents,
248
+ inputs=file_upload_button,
249
+ outputs=[upload_status, retrieved_docs_output] # アップロード結果と関連文書表示エリアを更新
250
+ )
251
+
252
+ build_index_button.click(
253
+ fn=build_index_action,
254
+ inputs=[],
255
+ outputs=[build_index_status, retrieved_docs_output] # 構築結果と関連文書表示エリアを更新
256
+ )
257
+
258
+ get_stats_button.click(
259
+ fn=get_rag_stats_action,
260
+ inputs=[],
261
+ outputs=rag_stats_output
262
+ )
263
+
264
+ # システムの初期化をアプリケーション起動時に実行
265
+ # initialize_systems() の中でエラーが発生した場合、rag_system や llm_pipeline が None になる
266
+ # これにより、後続の関数呼び出しでこれらの変数がNoneであることのチェックが機能する
267
+ initialize_systems()
268
+
269
+ if __name__ == "__main__":
270
+ demo.launch(debug=True, share=True)