File size: 8,442 Bytes
c69a4d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import gradio as gr
import os
from typing import List, Tuple

from service.rag_service import RAGService


class GradioApp:
    """Gradio front end for a RAG question-answering service.

    The page has a control panel (load an existing knowledge base, or build a
    new one from uploaded .pdf/.txt files, plus a read-only status box) and a
    chat area that stays hidden until a knowledge base has been loaded or
    built. All retrieval/generation work is delegated to the injected
    ``RAGService``.
    """

    # Markdown header of the sources panel. Every handler uses this constant so
    # the panel text is identical no matter which code path produced it.
    _SOURCES_HEADER = "### 引用来源 (Sources)\n"

    def __init__(self, rag_service: RAGService):
        """Store the backend service and build the UI (available as ``self.demo``).

        Args:
            rag_service: Backend used to load/build the knowledge base and to
                answer questions.
        """
        self.rag_service = rag_service
        self._build_ui()

    def _build_ui(self):
        """Create all components inside ``self.demo`` and wire their event handlers."""
        with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
                       title="Enterprise RAG System") as self.demo:
            gr.Markdown("# 企业级RAG智能问答系统 (Enterprise RAG System)")
            gr.Markdown("您可以**加载现有知识库**快速开始,或**上传新文档**构建一个全新的知识库。")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 控制面板 (Control Panel)")

                    self.load_kb_button = gr.Button("加载已有知识库 (Load Existing KB)")

                    gr.Markdown("<hr style='border: 1px solid #ddd; margin: 1rem 0;'>")

                    self.file_uploader = gr.File(
                        label="上传新文档以构建 (Upload New Docs to Build)",
                        file_count="multiple",
                        file_types=[".pdf", ".txt"],
                        interactive=True
                    )
                    self.build_kb_button = gr.Button("构建新知识库 (Build New KB)", variant="primary")

                    self.status_box = gr.Textbox(
                        label="系统状态 (System Status)",
                        value="系统已初始化,等待加载或构建知识库。",
                        interactive=False,
                        lines=4
                    )

                # Hidden at start; shown once a knowledge base is ready.
                with gr.Column(scale=2, visible=False) as self.chat_area:
                    gr.Markdown("### 对话窗口 (Chat Window)")
                    self.chatbot = gr.Chatbot(label="RAG Chatbot", bubble_full_width=False, height=500)
                    self.mode_selector = gr.Radio(
                        ["流式输出(Streaming)", "一次性输出(Full)"],
                        label="输出模式:(Output Mode)",
                        value="流式输出(Streaming)"
                    )
                    self.question_box = gr.Textbox(label="您的问题", placeholder="请在此处输入您的问题...",
                                                   show_label=False)
                    with gr.Row():
                        self.submit_btn = gr.Button("提交 (Submit)", variant="primary")
                        self.clear_btn = gr.Button("清空历史 (Clear History)")

                    gr.Markdown("---")
                    self.source_display = gr.Markdown(self._SOURCES_HEADER)

            # --- Event Listeners ---
            self.load_kb_button.click(
                fn=self._handle_load_kb,
                inputs=None,
                outputs=[self.status_box, self.chat_area]
            )

            self.build_kb_button.click(
                fn=self._handle_build_kb,
                inputs=[self.file_uploader],
                outputs=[self.status_box, self.chat_area]
            )

            # Clicking "Submit" and pressing Enter in the question box do the same thing.
            for trigger in (self.submit_btn.click, self.question_box.submit):
                trigger(
                    fn=self._handle_chat_submission,
                    inputs=[self.question_box, self.chatbot, self.mode_selector],
                    outputs=[self.chatbot, self.question_box, self.source_display]
                )

            self.clear_btn.click(
                fn=self._clear_chat,
                inputs=None,
                outputs=[self.chatbot, self.question_box, self.source_display]
            )

    def _kb_result(self, message: str, visible: bool) -> dict:
        """Build the ``{component: update}`` dict shared by the load/build handlers."""
        return {
            self.status_box: gr.update(value=message),
            self.chat_area: gr.update(visible=visible),
        }

    def _handle_load_kb(self):
        """Load the existing knowledge base via the service.

        Returns:
            A ``{component: gr.update(...)}`` dict. The status box shows the
            service's message, and the chat area is visible only when the
            service reported success.
        """
        success, message = self.rag_service.load_knowledge_base()
        return self._kb_result(message, visible=bool(success))

    def _handle_build_kb(self, files: List[str], progress=gr.Progress(track_tqdm=True)):
        """Build a new knowledge base from the uploaded files.

        Args:
            files: Value of the ``gr.File`` uploader. Items may be plain path
                strings or file-like wrappers exposing the path as ``.name``
                (depends on the Gradio version); both forms are accepted.
            progress: Progress tracker injected by Gradio.

        Returns:
            A ``{component: gr.update(...)}`` dict for the status box and chat
            area; the chat area is made visible only when the build succeeds.
        """
        if not files:
            return self._kb_result("错误:请至少上传一个文档。", visible=False)

        # Wrapper objects carry the path in ``.name``; bare str paths are used as-is.
        file_paths = [getattr(f, "name", f) for f in files]

        try:
            # build_knowledge_base is iterated for status strings, which are
            # surfaced through the progress indicator.
            for status in self.rag_service.build_knowledge_base(file_paths):
                progress(0.5, desc=status)
        except Exception as e:  # UI boundary: report the failure instead of crashing the app
            return self._kb_result(f"构建失败: {e}", visible=False)

        return self._kb_result("知识库构建完成并已就绪!√", visible=True)

    def _compose_answer(self, sources, answer: str) -> str:
        """Render a chat reply: retrieved context, a separator, then the answer text."""
        context = self.rag_service.get_context_string(sources)
        return f"{context}\n\n---\n\n**回答 (Answer):**\n{answer}"

    def _handle_chat_submission(self, question: str, history: List[Tuple[str, str]], mode: str):
        """Answer ``question`` and yield incremental UI updates.

        Every yield is ``(history, question_box_value, sources_markdown)``; the
        question box is always cleared. When ``mode`` contains "Full" the answer
        is fetched with one call; otherwise chunks from the service's answer
        generator are appended as they arrive. On any error the assistant
        message is replaced with the error text.
        """
        # The chatbot value may be None (e.g. right after clearing); work on a
        # fresh list so Gradio's input object is never mutated in place.
        history = list(history) if history else []

        if not question or not question.strip():
            yield history, "", self._SOURCES_HEADER
            return

        history.append((question, ""))
        # Echo the question right away, before the potentially slow retrieval.
        yield history, "", self._SOURCES_HEADER

        try:
            if "Full" in (mode or ""):
                answer, sources = self.rag_service.get_response_full(question)
                reply = self._compose_answer(sources, answer)
                source_panel = self._format_sources(sources)
                history[-1] = (question, reply)
                yield history, "", source_panel
            else:
                answer_generator, sources = self.rag_service.get_response_stream(question)
                prefix = self._compose_answer(sources, "")
                source_panel = self._format_sources(sources)
                history[-1] = (question, prefix)
                yield history, "", source_panel

                answer_so_far = ""
                for text_chunk in answer_generator:
                    answer_so_far += text_chunk
                    history[-1] = (question, prefix + answer_so_far)
                    yield history, "", source_panel

        except Exception as e:  # UI boundary: show the error in the chat instead of crashing
            history[-1] = (question, f"处理请求时出错: {e}")
            yield history, "", self._SOURCES_HEADER

    def _format_sources(self, sources: List) -> str:
        """Render a de-duplicated, sorted Markdown list of cited sources.

        For each item, ``metadata['source']`` is reduced to its basename and
        ``metadata['page']`` is shown as-is; missing keys fall back to
        ``'Unknown'`` / ``'N/A'``. Assumes items are document objects carrying a
        ``metadata`` dict — TODO confirm against RAGService.
        """
        if not sources:
            return self._SOURCES_HEADER

        entries = {
            f"- **{os.path.basename(doc.metadata.get('source', 'Unknown'))}** "
            f"(Page: {doc.metadata.get('page', 'N/A')})"
            for doc in sources
        }
        return self._SOURCES_HEADER + "\n".join(sorted(entries))

    def _clear_chat(self):
        """Reset the chat history, the question box, and the sources panel."""
        return [], "", self._SOURCES_HEADER

    def launch(self, **launch_kwargs):
        """Enable the request queue (required for generator handlers) and start the server.

        Args:
            **launch_kwargs: Optional keyword arguments forwarded unchanged to
                ``Blocks.launch`` (e.g. ``server_name``, ``server_port``, ``share``).
        """
        self.demo.queue().launch(**launch_kwargs)