File size: 8,442 Bytes
c69a4d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import gradio as gr
import os
from typing import List, Tuple
from service.rag_service import RAGService
class GradioApp:
    """Gradio web UI for a retrieval-augmented question-answering service.

    The UI has a control panel (load an existing knowledge base, or upload
    documents and build a new one) and a chat area that stays hidden until a
    knowledge base has been loaded or built successfully.

    The injected ``rag_service`` is used through these calls:

    * ``load_knowledge_base()`` -> ``(success, message)``
    * ``build_knowledge_base(file_paths)`` -> iterable of status strings
    * ``get_response_full(question)`` -> ``(answer, sources)``
    * ``get_response_stream(question)`` -> ``(answer_chunks, sources)``
    * ``get_context_string(sources)`` -> text shown above the answer
    """

    # Heading of the "sources" panel. Kept in one place so every handler
    # renders and resets the panel with exactly the same text.
    SOURCES_HEADER = "### 引用来源 (Sources)\n"

    # Labels of the output-mode radio buttons.
    MODE_STREAMING = "流式输出(Streaming)"
    MODE_FULL = "一次性输出(Full)"

    def __init__(self, rag_service: RAGService):
        """Store the service and build the Blocks layout (does not launch)."""
        self.rag_service = rag_service
        self._build_ui()

    def _build_ui(self):
        """Create all components and wire their event listeners.

        The resulting ``gr.Blocks`` object is stored on ``self.demo``.
        """
        with gr.Blocks(
            theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
            title="Enterprise RAG System",
        ) as self.demo:
            gr.Markdown("# 企业级RAG智能问答系统 (Enterprise RAG System)")
            gr.Markdown("您可以**加载现有知识库**快速开始,或**上传新文档**构建一个全新的知识库。")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 控制面板 (Control Panel)")
                    self.load_kb_button = gr.Button("加载已有知识库 (Load Existing KB)")
                    gr.Markdown("<hr style='border: 1px solid #ddd; margin: 1rem 0;'>")
                    self.file_uploader = gr.File(
                        label="上传新文档以构建 (Upload New Docs to Build)",
                        file_count="multiple",
                        file_types=[".pdf", ".txt"],
                        interactive=True,
                    )
                    self.build_kb_button = gr.Button(
                        "构建新知识库 (Build New KB)", variant="primary"
                    )
                    self.status_box = gr.Textbox(
                        label="系统状态 (System Status)",
                        value="系统已初始化,等待加载或构建知识库。",
                        interactive=False,
                        lines=4,
                    )

                # Hidden at start-up; revealed once a knowledge base is ready.
                with gr.Column(scale=2, visible=False) as self.chat_area:
                    gr.Markdown("### 对话窗口 (Chat Window)")
                    self.chatbot = gr.Chatbot(
                        label="RAG Chatbot", bubble_full_width=False, height=500
                    )
                    self.mode_selector = gr.Radio(
                        [self.MODE_STREAMING, self.MODE_FULL],
                        label="输出模式:(Output Mode)",
                        value=self.MODE_STREAMING,
                    )
                    self.question_box = gr.Textbox(
                        label="您的问题",
                        placeholder="请在此处输入您的问题...",
                        show_label=False,
                    )
                    with gr.Row():
                        self.submit_btn = gr.Button("提交 (Submit)", variant="primary")
                        self.clear_btn = gr.Button("清空历史 (Clear History)")
                    gr.Markdown("---")
                    self.source_display = gr.Markdown("### 引用来源 (Sources)")

            # --- Event listeners ---
            kb_outputs = [self.status_box, self.chat_area]
            chat_inputs = [self.question_box, self.chatbot, self.mode_selector]
            chat_outputs = [self.chatbot, self.question_box, self.source_display]

            self.load_kb_button.click(
                fn=self._handle_load_kb,
                inputs=None,
                outputs=kb_outputs,
            )
            self.build_kb_button.click(
                fn=self._handle_build_kb,
                inputs=[self.file_uploader],
                outputs=kb_outputs,
            )
            # The submit button and pressing Enter in the text box do the same thing.
            self.submit_btn.click(
                fn=self._handle_chat_submission,
                inputs=chat_inputs,
                outputs=chat_outputs,
            )
            self.question_box.submit(
                fn=self._handle_chat_submission,
                inputs=chat_inputs,
                outputs=chat_outputs,
            )
            self.clear_btn.click(
                fn=self._clear_chat,
                inputs=None,
                outputs=chat_outputs,
            )

    def _kb_update(self, message: str, ready: bool) -> dict:
        """Build the component-update dict shared by the load/build handlers.

        Args:
            message: Text to show in the status box.
            ready: Whether the chat area should be visible.
        """
        return {
            self.status_box: gr.update(value=message),
            self.chat_area: gr.update(visible=ready),
        }

    def _handle_load_kb(self):
        """Load the existing knowledge base and return the component updates.

        The chat area is shown only when the service reports success; the
        service's message is displayed in the status box either way.
        """
        success, message = self.rag_service.load_knowledge_base()
        return self._kb_update(message, bool(success))

    def _handle_build_kb(self, files: List[str], progress=gr.Progress(track_tqdm=True)):
        """Build a new knowledge base from uploaded files and return the updates.

        Args:
            files: Uploaded files as delivered by the ``gr.File`` component.
                Entries may be path strings or file-like objects exposing the
                path as ``.name``; both are accepted.
            progress: Gradio progress tracker (injected by Gradio).

        Returns:
            A dict mapping the status box and chat area to their updates. Any
            exception raised while building is reported in the status box and
            the chat area is hidden.
        """
        if not files:
            return self._kb_update("错误:请至少上传一个文档。", False)

        try:
            # Accept both plain paths and uploaded-file wrappers with ``.name``.
            file_paths = [getattr(item, "name", item) for item in files]
            for status in self.rag_service.build_knowledge_base(file_paths):
                # The bar is held at a fixed 50%; only the description changes.
                progress(0.5, desc=status)
        except Exception as e:  # UI boundary: surface the failure to the user.
            return self._kb_update(f"构建失败: {e}", False)

        return self._kb_update("知识库构建完成并已就绪!√", True)

    def _handle_chat_submission(self, question: str, history: List[Tuple[str, str]], mode: str):
        """Answer a question, yielding successive UI states.

        This is a generator so Gradio can refresh the chat while the answer is
        being produced.

        Args:
            question: Text typed by the user.
            history: Current chat history as ``(question, answer)`` pairs;
                ``None`` is treated as an empty history.
            mode: Selected output mode label. A label containing ``"Full"``
                selects the one-shot path; anything else streams.

        Yields:
            ``(history, "", sources_markdown)`` tuples: the updated chat
            history, an empty string to clear the question box, and the
            markdown for the sources panel.
        """
        # Work on a copy so the caller's list is never mutated in place.
        history = list(history or [])

        if not question or not question.strip():
            yield history, "", self.SOURCES_HEADER
            return

        history.append((question, ""))
        try:
            if "Full" in (mode or ""):
                # One-shot mode: show the pending question first, then the
                # complete answer once it is available.
                yield history, "", self.SOURCES_HEADER
                answer, sources = self.rag_service.get_response_full(question)
                context_text = self.rag_service.get_context_string(sources)
                sources_panel = self._format_sources(sources)
                # Full bubble content: retrieved context followed by the answer.
                history[-1] = (
                    question,
                    f"{context_text}\n\n---\n\n**回答 (Answer):**\n{answer}",
                )
                yield history, "", sources_panel
            else:
                # Streaming mode: show sources right away, then grow the answer.
                answer_chunks, sources = self.rag_service.get_response_stream(question)
                context_text = self.rag_service.get_context_string(sources)
                sources_panel = self._format_sources(sources)
                yield history, "", sources_panel

                prefix = f"{context_text}\n\n---\n\n**回答 (Answer):**\n"
                history[-1] = (question, prefix)
                yield history, "", sources_panel

                answer_so_far = ""
                for chunk in answer_chunks:
                    answer_so_far += chunk
                    history[-1] = (question, prefix + answer_so_far)
                    yield history, "", sources_panel
        except Exception as e:  # UI boundary: show the error in the chat bubble.
            history[-1] = (question, f"处理请求时出错: {e}")
            yield history, "", self.SOURCES_HEADER

    def _format_sources(self, sources: List) -> str:
        """Render the cited sources as a markdown bullet list.

        Args:
            sources: Retrieved documents. Each must expose a ``metadata``
                mapping; its ``source`` (file path) and ``page`` entries are
                used. (Presumably LangChain-style documents - confirm against
                the service.)

        Returns:
            The sources header followed by one de-duplicated, sorted bullet
            per ``(file name, page)`` pair, or just the header when there are
            no sources.
        """
        if not sources:
            return self.SOURCES_HEADER

        entries = set()
        for doc in sources:
            # Fall back to "Unknown" for a missing *or* empty/None source.
            name = os.path.basename(doc.metadata.get("source") or "Unknown")
            page = doc.metadata.get("page", "N/A")
            entries.add(f"- **{name}** (Page: {page})")
        return self.SOURCES_HEADER + "\n".join(sorted(entries))

    def _clear_chat(self):
        """Reset the chat history, the question box and the sources panel."""
        return None, "", self.SOURCES_HEADER

    def launch(self, **launch_kwargs):
        """Enable request queuing and start the Gradio server.

        Args:
            **launch_kwargs: Optional keyword arguments forwarded unchanged to
                ``gr.Blocks.launch``. With no arguments this behaves exactly
                like the plain ``launch()`` call.
        """
        self.demo.queue().launch(**launch_kwargs)
|