Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import json
|
| 9 |
+
from typing import Tuple # generate_response の型ヒントのために追加
|
| 10 |
+
|
| 11 |
+
# ragsys03.py から RAGSystem クラスと QueryResult をインポート
|
| 12 |
+
from ragsys03 import RAGSystem, QueryResult
|
| 13 |
+
|
| 14 |
+
# ロギング設定
|
| 15 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# グローバル変数
|
| 19 |
+
rag_system: RAGSystem = None
|
| 20 |
+
llm_pipeline = None
|
| 21 |
+
|
| 22 |
+
# RAGデータ保存用のベースディレクトリ
|
| 23 |
+
RAG_BASE_DIR = Path('rag_data')
|
| 24 |
+
|
| 25 |
+
def initialize_systems():
|
| 26 |
+
"""
|
| 27 |
+
RAGシステムとLLMパイプラインを初期化します。
|
| 28 |
+
"""
|
| 29 |
+
global rag_system, llm_pipeline
|
| 30 |
+
|
| 31 |
+
if rag_system is None:
|
| 32 |
+
logger.info("RAGSystemを初期化中...")
|
| 33 |
+
try:
|
| 34 |
+
# 最新のインデックスをロードするか、新しいディレクトリを作成する
|
| 35 |
+
rag_system = RAGSystem(
|
| 36 |
+
model_name='all-MiniLM-L6-v2', # SentenceTransformer モデル
|
| 37 |
+
index_type='ivf', # または 'flat'
|
| 38 |
+
n_clusters=100, # IVFの場合のクラスタ数
|
| 39 |
+
index_base_dir=RAG_BASE_DIR,
|
| 40 |
+
load_latest=True # 最新のインデックスを自動的にロード
|
| 41 |
+
)
|
| 42 |
+
logger.info("RAGSystemの初期化が完了しました。")
|
| 43 |
+
|
| 44 |
+
# 初期化時にインデックスが存在しない場合、ここで構築を促す
|
| 45 |
+
if rag_system.index is None or rag_system.index.ntotal == 0:
|
| 46 |
+
logger.warning("RAGシステムにインデックスがありません。文書をアップロードして「インデックスを構築」ボタンを押してください。")
|
| 47 |
+
|
| 48 |
+
except Exception as e:
|
| 49 |
+
logger.critical(f"RAGSystemの初期化中にエラーが発生しました: {e}")
|
| 50 |
+
rag_system = None # 初期化失敗時はNoneに設定
|
| 51 |
+
# Gradio UIでエラーメッセージを表示するための処理を考慮するか、
|
| 52 |
+
# 各アクションでrag_systemがNoneの場合の処理を堅牢にする
|
| 53 |
+
# ここでは単にログに出力し、後続の関数でNoneチェックを行う
|
| 54 |
+
|
| 55 |
+
if llm_pipeline is None:
|
| 56 |
+
logger.info("LLMパイプラインを初期化中: rinna/japanese-gpt-neox-3.6b-instruction-sft")
|
| 57 |
+
try:
|
| 58 |
+
# 量子化を無効にした設定
|
| 59 |
+
llm_pipeline = pipeline(
|
| 60 |
+
"text-generation",
|
| 61 |
+
model="rinna/japanese-gpt-neox-3.6b-instruction-sft",
|
| 62 |
+
tokenizer=AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False),
|
| 63 |
+
torch_dtype=torch.bfloat16, # CPUでも動くようにbfloat16のまま
|
| 64 |
+
device_map="auto", # autoのままでOK (GPUがなければCPUに自動的に割り当たる)
|
| 65 |
+
# model_kwargs={"quantization_config": quantization_config} を削除!
|
| 66 |
+
)
|
| 67 |
+
logger.info("LLMパイプラインの初期化が完了しました。")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logger.critical(f"LLMパイプラインの初期化中にエラーが発生しました: {e}")
|
| 70 |
+
llm_pipeline = None # 初期化失敗時はNoneに設定
|
| 71 |
+
# 同様に、エラー発生時はllm_pipelineをNoneのままにする
|
| 72 |
+
|
| 73 |
+
def generate_response(query: str, top_k: int, similarity_threshold: float) -> Tuple[str, str]:
|
| 74 |
+
"""
|
| 75 |
+
RAGシステムとLLMを使用して応答を生成します。
|
| 76 |
+
Args:
|
| 77 |
+
query (str): ユーザーからのクエリ。
|
| 78 |
+
top_k (int): 検索する文書の数。
|
| 79 |
+
similarity_threshold (float): コサイン類似度の閾値。
|
| 80 |
+
Returns:
|
| 81 |
+
Tuple[str, str]: LLMの回答と、検索された文書のテキスト。
|
| 82 |
+
"""
|
| 83 |
+
if rag_system is None or llm_pipeline is None:
|
| 84 |
+
return "システムが初期化されていません。アプリケーションを再起動してください。", ""
|
| 85 |
+
|
| 86 |
+
logger.info(f"クエリ処理開始: {query}")
|
| 87 |
+
query_result: QueryResult = rag_system.query(query, top_k=top_k, similarity_threshold=similarity_threshold)
|
| 88 |
+
|
| 89 |
+
llm_prompt = query_result['llm_prompt']
|
| 90 |
+
retrieved_docs = query_result['retrieved_documents']
|
| 91 |
+
search_time = query_result['search_time']
|
| 92 |
+
|
| 93 |
+
if not retrieved_docs:
|
| 94 |
+
llm_answer = "関連する情報が見つかりませんでした。別の表現で質問していただくか、より具体的な内容で質問してください。"
|
| 95 |
+
return llm_answer, "関連文書なし"
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
outputs = llm_pipeline(
|
| 99 |
+
llm_prompt,
|
| 100 |
+
max_new_tokens=512,
|
| 101 |
+
do_sample=True,
|
| 102 |
+
temperature=0.7,
|
| 103 |
+
top_p=0.9,
|
| 104 |
+
repetition_penalty=1.1,
|
| 105 |
+
num_return_sequences=1,
|
| 106 |
+
pad_token_id=llm_pipeline.tokenizer.eos_token_id # pad_token_idを指定
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
llm_answer = outputs[0]['generated_text']
|
| 110 |
+
# プロンプト部分を除去して回答のみを抽出 (モデルの出力形式による)
|
| 111 |
+
if llm_answer.startswith(llm_prompt):
|
| 112 |
+
llm_answer = llm_answer[len(llm_prompt):].strip()
|
| 113 |
+
|
| 114 |
+
logger.info(f"LLM応答生成完了 (検索時間: {search_time:.4f}秒)")
|
| 115 |
+
|
| 116 |
+
retrieved_docs_text_formatted = "\n\n".join([f"**文書{i+1}:** {doc}" for i, doc in enumerate(retrieved_docs)])
|
| 117 |
+
|
| 118 |
+
return llm_answer, retrieved_docs_text_formatted
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
logger.error(f"LLMによる応答生成中にエラーが発生しました: {e}")
|
| 122 |
+
return f"応答生成中にエラーが発生しました: {str(e)}", ""
|
| 123 |
+
|
| 124 |
+
def upload_documents(file):
|
| 125 |
+
"""
|
| 126 |
+
アップロードされたJSONファイルから文書をロードし、インデックスを再構築します。
|
| 127 |
+
"""
|
| 128 |
+
global rag_system
|
| 129 |
+
|
| 130 |
+
if rag_system is None:
|
| 131 |
+
# RAGSystemが未初期化の場合、ここで初期化を試みる
|
| 132 |
+
# initialize_systems() は、Gradioの起動時に一度しか呼ばれないため、
|
| 133 |
+
# ここでrag_systemがNoneの場合、初期化に失敗している可能性がある。
|
| 134 |
+
# 代わりに、エラーメッセージを返す
|
| 135 |
+
return "エラー: RAGシステムが初期化されていません。アプリケーションを再起動してください。", ""
|
| 136 |
+
|
| 137 |
+
if file is None:
|
| 138 |
+
return "エラー: ファイルがアップロードされていません。", ""
|
| 139 |
+
|
| 140 |
+
uploaded_file_path = Path(file.name)
|
| 141 |
+
logger.info(f"アップロードされたファイル: {uploaded_file_path}")
|
| 142 |
+
|
| 143 |
+
try:
|
| 144 |
+
# ragsys03.py の load_documents メソッドを使用
|
| 145 |
+
rag_system.load_documents(uploaded_file_path)
|
| 146 |
+
|
| 147 |
+
# インデックス構築のメッセージを表示し、ユーザーにボタンを押してもらう
|
| 148 |
+
return (
|
| 149 |
+
f"'{uploaded_file_path.name}' から {len(rag_system.documents)} 件の文書を正常にロードしました。\n"
|
| 150 |
+
"続けて「インデックスを構築」ボタンを押してください。",
|
| 151 |
+
"文書がロードされました。インデックス構築が必要です。"
|
| 152 |
+
)
|
| 153 |
+
except Exception as e:
|
| 154 |
+
logger.error(f"文書のロード中にエラーが発生しました: {e}")
|
| 155 |
+
return f"文書のロード中にエラーが発生しました: {str(e)}", ""
|
| 156 |
+
|
| 157 |
+
def build_index_action():
|
| 158 |
+
"""
|
| 159 |
+
RAGインデックスを構築します。
|
| 160 |
+
"""
|
| 161 |
+
global rag_system
|
| 162 |
+
if rag_system is None:
|
| 163 |
+
return "エラー: RAGシステムが初期化されていません。", ""
|
| 164 |
+
|
| 165 |
+
if not rag_system.documents:
|
| 166 |
+
return "エラー: インデックス作成用の文書がロードされていません。", ""
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
# force_rebuild=True で常に再構築
|
| 170 |
+
rag_system.build_index(force_rebuild=True)
|
| 171 |
+
stats = rag_system.get_stats()
|
| 172 |
+
return (
|
| 173 |
+
f"インデックスが正常に構築されました。\n"
|
| 174 |
+
f"総文書数: {stats.get('documents_count', 'N/A')}\n"
|
| 175 |
+
f"インデックスタイプ: {stats.get('index_type', 'N/A')}\n"
|
| 176 |
+
f"インデックスディレクトリ: {stats.get('current_index_dir', 'N/A')}\n"
|
| 177 |
+
f"ファイルサイズ: {stats.get('index_file_size_mb', 'N/A')} MB"
|
| 178 |
+
), "インデックス構築完了"
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.error(f"インデックスの構築中にエラーが発生しました: {e}")
|
| 181 |
+
return f"エラー: インデックスの構築中に問題が発生しました - {str(e)}", ""
|
| 182 |
+
|
| 183 |
+
def get_rag_stats_action():
|
| 184 |
+
"""RAGシステムの統計情報を表示します。"""
|
| 185 |
+
global rag_system
|
| 186 |
+
if rag_system is None:
|
| 187 |
+
return "RAGシステムはまだ初期化されていません。"
|
| 188 |
+
try:
|
| 189 |
+
stats = rag_system.get_stats()
|
| 190 |
+
stats_str = "\n".join([f"{k}: {v}" for k, v in stats.items()])
|
| 191 |
+
return f"RAGシステムの状態:\n{stats_str}"
|
| 192 |
+
except Exception as e:
|
| 193 |
+
logger.error(f"RAG統計情報の取得中にエラーが発生しました: {e}")
|
| 194 |
+
return f"統計情報の取得中にエラーが発生しました: {str(e)}"
|
| 195 |
+
|
| 196 |
+
# Gradio UI の構築
|
| 197 |
+
with gr.Blocks() as demo:
|
| 198 |
+
gr.Markdown(
|
| 199 |
+
"""
|
| 200 |
+
# RAG (Retrieval Augmented Generation) デモ
|
| 201 |
+
rinna/japanese-gpt-neox-3.6b-instruction-sft と FAISS を利用したRAGシステムです。
|
| 202 |
+
JSONファイルをアップロードして独自の知識ベースを構築し、質問をすることができます。
|
| 203 |
+
|
| 204 |
+
## 使用方法
|
| 205 |
+
1. 「文書をアップロード」セクションで、`documents` キーに文字列のリストを持つJSONファイルをアップロードします。
|
| 206 |
+
2. 「インデックスを構築」ボタンをクリックして、アップロードされた文書から検索インデックスを作成します。
|
| 207 |
+
3. 「RAG質問」セクションで質問を入力し、「質問を送信」ボタンをクリックします。
|
| 208 |
+
"""
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
with gr.Tab("RAG質問"):
|
| 212 |
+
gr.Markdown("### RAG質問")
|
| 213 |
+
with gr.Row():
|
| 214 |
+
query_input = gr.Textbox(label="質問を入力してください", placeholder="RAGシステムの主な利点は何ですか?", lines=2)
|
| 215 |
+
submit_button = gr.Button("質問を送信")
|
| 216 |
+
|
| 217 |
+
with gr.Row():
|
| 218 |
+
top_k_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="取得文書数 (top_k)")
|
| 219 |
+
similarity_threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="類似度閾値 (0.0-1.0)")
|
| 220 |
+
|
| 221 |
+
llm_output = gr.Textbox(label="LLMの回答", lines=10, interactive=False)
|
| 222 |
+
retrieved_docs_output = gr.Markdown("検索された関連文書", visible=True)
|
| 223 |
+
|
| 224 |
+
with gr.Tab("文書管理"):
|
| 225 |
+
gr.Markdown("### 文書とインデックス管理")
|
| 226 |
+
with gr.Row():
|
| 227 |
+
file_upload_button = gr.File(label="文書JSONファイルをアップロード", file_types=[".json"])
|
| 228 |
+
upload_status = gr.Textbox(label="アップロードステータス", interactive=False)
|
| 229 |
+
|
| 230 |
+
with gr.Row():
|
| 231 |
+
build_index_button = gr.Button("インデックスを構築")
|
| 232 |
+
build_index_status = gr.Textbox(label="インデックス構築ステータス", interactive=False)
|
| 233 |
+
|
| 234 |
+
gr.Markdown("### RAGシステム情報")
|
| 235 |
+
get_stats_button = gr.Button("RAGシステムの状態を表示")
|
| 236 |
+
rag_stats_output = gr.Textbox(label="RAGシステム情報", interactive=False, lines=5)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# イベントハンドラの登録
|
| 240 |
+
submit_button.click(
|
| 241 |
+
fn=generate_response,
|
| 242 |
+
inputs=[query_input, top_k_slider, similarity_threshold_slider],
|
| 243 |
+
outputs=[llm_output, retrieved_docs_output]
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
file_upload_button.upload(
|
| 247 |
+
fn=upload_documents,
|
| 248 |
+
inputs=file_upload_button,
|
| 249 |
+
outputs=[upload_status, retrieved_docs_output] # アップロード結果と関連文書表示エリアを更新
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
build_index_button.click(
|
| 253 |
+
fn=build_index_action,
|
| 254 |
+
inputs=[],
|
| 255 |
+
outputs=[build_index_status, retrieved_docs_output] # 構築結果と関連文書表示エリアを更新
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
get_stats_button.click(
|
| 259 |
+
fn=get_rag_stats_action,
|
| 260 |
+
inputs=[],
|
| 261 |
+
outputs=rag_stats_output
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# システムの初期化をアプリケーション起動時に実行
|
| 265 |
+
# initialize_systems() の中でエラーが発生した場合、rag_system や llm_pipeline が None になる
|
| 266 |
+
# これにより、後続の関数呼び出しでこれらの変数がNoneであることのチェックが機能する
|
| 267 |
+
initialize_systems()
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
|
| 270 |
+
demo.launch(debug=True, share=True)
|