import streamlit as st
# NOTE: the star import is kept ahead of the explicit imports below (as in the
# original file) so that the explicit imports win on any name collision.
from webui_pages.utils import *
from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, List, Optional, Tuple
from configs import (kbs_config,
                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models, list_online_embed_models
import json
import os
import time

# Markers that replace the boolean ``in_folder`` / ``in_db`` columns before the
# file table is handed to AgGrid (see ``knowledge_base_page``).
_YES_MARK = "✓"
_NO_MARK = "×"

# Renders a check/cross for a boolean-like cell. The cell value may be either a
# real boolean or the "✓" marker string, so both are accepted; a bare
# ``params.value==true`` is false for the marker string in JavaScript.
cell_renderer = JsCode("""function(params) {if(params.value==true || params.value=='✓'){return '✓'}else{return '×'}}""")


def _is_marked(value) -> bool:
    """Return True when *value* is a real ``True`` or the "✓" marker string."""
    return value in (True, _YES_MARK)


def _index_or_first(options: List, value) -> int:
    """Return the position of *value* in *options*, or 0 when it is absent."""
    try:
        return options.index(value)
    except ValueError:
        return 0


def config_aggrid(
        df: pd.DataFrame,
        columns: Optional[Dict[Tuple[str, str], Dict]] = None,
        selection_mode: Literal["single", "multiple", "disabled"] = "single",
        use_checkbox: bool = False,
) -> GridOptionsBuilder:
    """
    Build the AgGrid options for a dataframe.

    :param df: dataframe the grid options are derived from.
    :param columns: mapping of ``(column, header)`` to extra keyword arguments
        forwarded to ``configure_column``. ``None`` means no extra columns.
    :param selection_mode: row selection mode passed to ``configure_selection``.
    :param use_checkbox: whether rows are selected through a checkbox.
    :return: the configured ``GridOptionsBuilder`` (call ``.build()`` on it).
    """
    # A ``None`` default avoids sharing one dict object between calls.
    columns = {} if columns is None else columns

    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_column("No", width=40)
    for (col, header), kw in columns.items():
        gb.configure_column(col, header, wrapHeaderText=True, **kw)
    gb.configure_selection(
        selection_mode=selection_mode,
        use_checkbox=use_checkbox,
        pre_selected_rows=st.session_state.get("selected_rows", [0]),
    )
    gb.configure_pagination(
        enabled=True,
        paginationAutoPageSize=False,
        paginationPageSize=10
    )
    return gb


def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:
    """
    Check whether the first selected doc file exists in the local knowledge
    base folder.

    :param kb: knowledge base name.
    :param selected_rows: rows selected in the file grid; only the first row's
        ``file_name`` is inspected.
    :return: ``(file_name, file_path)`` if the file exists, else ``("", "")``.
    """
    if selected_rows:
        file_name = selected_rows[0]["file_name"]
        file_path = get_file_path(kb, file_name)
        if os.path.isfile(file_path):
            return file_name, file_path
    return "", ""


def knowledge_base_page(api: ApiRequest, is_lite: Optional[bool] = None):
    """
    Render the knowledge base management page.

    The page lets the user create a knowledge base, upload files to the
    selected one, manage individual files (download, add to / remove from the
    vector store, delete), rebuild the vector store, delete the knowledge base
    and edit the document chunks of the selected file.

    :param api: API client used for every knowledge base operation.
    :param is_lite: when truthy, only online embedding models are offered when
        creating a knowledge base.
    """
    try:
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
    except Exception:
        # Top-level UI boundary: show a hint to the user and stop the script run.
        st.error(
            "获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
        st.stop()
    kb_names = list(kb_list.keys())

    # Keep the previously selected knowledge base selected across reruns.
    if "selected_kb_name" in st.session_state and st.session_state["selected_kb_name"] in kb_names:
        selected_kb_index = kb_names.index(st.session_state["selected_kb_name"])
    else:
        selected_kb_index = 0

    if "selected_kb_info" not in st.session_state:
        st.session_state["selected_kb_info"] = ""

    def format_selected_kb(kb_name: str) -> str:
        """Label a knowledge base with its vector store type and embed model."""
        if kb := kb_list.get(kb_name):
            return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})"
        else:
            # The "create new" pseudo-entry is not in kb_list.
            return kb_name

    selected_kb = st.selectbox(
        "请选择或新建知识库:",
        kb_names + ["新建知识库"],
        format_func=format_selected_kb,
        index=selected_kb_index
    )

    if selected_kb == "新建知识库":
        with st.form("新建知识库"):

            kb_name = st.text_input(
                "新建知识库名称",
                placeholder="新知识库名称,不支持中文命名",
                key="kb_name",
            )
            kb_info = st.text_input(
                "知识库简介",
                placeholder="知识库简介,方便Agent查找",
                key="kb_info",
            )

            cols = st.columns(2)

            vs_types = list(kbs_config.keys())
            vs_type = cols[0].selectbox(
                "向量库类型",
                vs_types,
                # Fall back to the first entry if the configured default is
                # not among the configured vector store types.
                index=_index_or_first(vs_types, DEFAULT_VS_TYPE),
                key="vs_type",
            )

            if is_lite:
                embed_models = list_online_embed_models()
            else:
                embed_models = list_embed_models() + list_online_embed_models()

            embed_model = cols[1].selectbox(
                "Embedding 模型",
                embed_models,
                # The configured default may be missing from the list (e.g. a
                # local model in lite mode); ``list.index`` would raise then.
                index=_index_or_first(embed_models, EMBEDDING_MODEL),
                key="embed_model",
            )

            submit_create_kb = st.form_submit_button(
                "新建",
                use_container_width=True,
            )

        if submit_create_kb:
            if not kb_name or not kb_name.strip():
                st.error("知识库名称不能为空!")
            elif kb_name in kb_list:
                st.error(f"名为 {kb_name} 的知识库已经存在!")
            else:
                ret = api.create_knowledge_base(
                    knowledge_base_name=kb_name,
                    vector_store_type=vs_type,
                    embed_model=embed_model,
                )
                st.toast(ret.get("msg", " "))
                st.session_state["selected_kb_name"] = kb_name
                st.session_state["selected_kb_info"] = kb_info
                st.rerun()

    elif selected_kb:
        kb = selected_kb
        st.session_state["selected_kb_info"] = kb_list[kb]['kb_info']

        # Upload files; accepted extensions are every extension in LOADER_DICT.
        files = st.file_uploader("上传知识文件:",
                                 [i for ls in LOADER_DICT.values() for i in ls],
                                 accept_multiple_files=True,
                                 )
        kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None,
                               key=None, help=None, on_change=None, args=None, kwargs=None)

        # Persist the description only when the user actually changed it.
        if kb_info != st.session_state["selected_kb_info"]:
            st.session_state["selected_kb_info"] = kb_info
            api.update_kb_info(kb, kb_info)

        with st.expander(
                "文件处理配置",
                expanded=True,
        ):
            cols = st.columns(3)
            chunk_size = cols[0].number_input("单段文本最大长度:", 1, 1000, CHUNK_SIZE)
            # The overlap can never exceed the chunk size chosen above.
            chunk_overlap = cols[1].number_input("相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE)
            # Two empty writes push the checkbox down to align with the inputs.
            cols[2].write("")
            cols[2].write("")
            zh_title_enhance = cols[2].checkbox("开启中文标题加强", ZH_TITLE_ENHANCE)

        if st.button(
                "添加文件到知识库",
                disabled=not files,
        ):
            ret = api.upload_kb_docs(files,
                                     knowledge_base_name=kb,
                                     override=True,
                                     chunk_size=chunk_size,
                                     chunk_overlap=chunk_overlap,
                                     zh_title_enhance=zh_title_enhance)
            if msg := check_success_msg(ret):
                st.toast(msg, icon="✔")
            elif msg := check_error_msg(ret):
                st.toast(msg, icon="✖")

        st.divider()

        # Knowledge base details: one row per file.
        doc_details = pd.DataFrame(get_kb_file_details(kb))
        selected_rows = []
        if not len(doc_details):
            st.info(f"知识库 `{kb}` 中暂无文件")
        else:
            st.write(f"知识库 `{kb}` 中已有文件:")
            st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
            doc_details.drop(columns=["kb_name"], inplace=True)
            doc_details = doc_details[[
                "No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db",
            ]]
            # From here on the two flag columns hold marker strings, not
            # booleans; rows returned by the grid carry the same strings.
            doc_details["in_folder"] = doc_details["in_folder"].replace(True, _YES_MARK).replace(False, _NO_MARK)
            doc_details["in_db"] = doc_details["in_db"].replace(True, _YES_MARK).replace(False, _NO_MARK)
            gb = config_aggrid(
                doc_details,
                {
                    ("No", "序号"): {},
                    ("file_name", "文档名称"): {},
                    ("document_loader", "文档加载器"): {},
                    ("docs_count", "文档数量"): {},
                    ("text_splitter", "分词器"): {},
                    ("in_folder", "源文件"): {"cellRenderer": cell_renderer},
                    ("in_db", "向量库"): {"cellRenderer": cell_renderer},
                },
                "multiple",
            )

            doc_grid = AgGrid(
                doc_details,
                gb.build(),
                columns_auto_size_mode="FIT_CONTENTS",
                theme="alpine",
                custom_css={
                    "#gridToolBar": {"display": "none"},
                },
                allow_unsafe_jscode=True,
                enable_enterprise_modules=False
            )

            selected_rows = doc_grid.get("selected_rows", [])

            cols = st.columns(4)
            file_name, file_path = file_exists(kb, selected_rows)
            if file_path:
                with open(file_path, "rb") as fp:
                    cols[0].download_button(
                        "下载选中文档",
                        fp,
                        file_name=file_name,
                        use_container_width=True, )
            else:
                cols[0].download_button(
                    "下载选中文档",
                    "",
                    disabled=True,
                    use_container_width=True, )

            st.write()

            # "×" is a non-empty (truthy) string, so the marker must be
            # compared explicitly rather than tested for truthiness.
            any_selected_in_db = bool(selected_rows) and any(
                _is_marked(row["in_db"]) for row in selected_rows)

            # Split the selected files into chunks and load them into the
            # vector store.
            if cols[1].button(
                    "重新添加至向量库" if any_selected_in_db else "添加至向量库",
                    disabled=not file_name,
                    use_container_width=True,
            ):
                file_names = [row["file_name"] for row in selected_rows]
                api.update_kb_docs(kb,
                                   file_names=file_names,
                                   chunk_size=chunk_size,
                                   chunk_overlap=chunk_overlap,
                                   zh_title_enhance=zh_title_enhance)
                st.rerun()

            # Remove the selected files from the vector store while keeping
            # the source files themselves.
            if cols[2].button(
                    "从向量库删除",
                    disabled=not any_selected_in_db,
                    use_container_width=True,
            ):
                file_names = [row["file_name"] for row in selected_rows]
                api.delete_kb_docs(kb, file_names=file_names)
                st.rerun()

            # Remove the selected files from the knowledge base, content
            # included. Disabled without a selection so that no request with
            # an empty file list is sent.
            if cols[3].button(
                    "从知识库中删除",
                    type="primary",
                    disabled=not selected_rows,
                    use_container_width=True,
            ):
                file_names = [row["file_name"] for row in selected_rows]
                api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
                st.rerun()

        st.divider()

        cols = st.columns(3)

        if cols[0].button(
                "依据源文件重建向量库",
                help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
                use_container_width=True,
                type="primary",
        ):
            with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
                empty = st.empty()
                empty.progress(0.0, "")
                # recreate_vector_store yields progress dicts as it goes.
                for d in api.recreate_vector_store(kb,
                                                   chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap,
                                                   zh_title_enhance=zh_title_enhance):
                    if msg := check_error_msg(d):
                        st.toast(msg)
                    else:
                        empty.progress(d["finished"] / d["total"], d["msg"])
                st.rerun()

        if cols[2].button(
                "删除知识库",
                use_container_width=True,
        ):
            ret = api.delete_knowledge_base(kb)
            st.toast(ret.get("msg", " "))
            # Give the toast a moment to be seen before the page reruns.
            time.sleep(1)
            st.rerun()

        with st.sidebar:
            # NOTE(review): these two widgets are rendered but their values
            # are not used anywhere in this function yet.
            keyword = st.text_input("查询关键字")
            top_k = st.slider("匹配条数", 1, 100, 3)

        st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。")
        docs = []
        df = pd.DataFrame([], columns=["seq", "id", "content", "source"])
        if selected_rows:
            # Only the first selected file's documents are shown for editing.
            file_name = selected_rows[0]["file_name"]
            docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name)
            data = [
                {"seq": i + 1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
                 "type": x["type"],
                 # Metadata is carried through the grid as a JSON string and
                 # decoded again when the changes are saved.
                 "metadata": json.dumps(x["metadata"], ensure_ascii=False),
                 "to_del": "",
                 } for i, x in enumerate(docs)]
            df = pd.DataFrame(data)

            gb = GridOptionsBuilder.from_dataframe(df)
            gb.configure_columns(["id", "source", "type", "metadata"], hide=True)
            gb.configure_column("seq", "No.", width=50)
            gb.configure_column("page_content", "内容", editable=True, autoHeight=True, wrapText=True, flex=1,
                                cellEditor="agLargeTextCellEditor", cellEditorPopup=True)
            # NOTE(review): ``cellRender`` looks like a typo for AG Grid's
            # ``cellRenderer`` column property. It is left unchanged because
            # "agCheckboxCellRenderer" may not exist in the bundled AG Grid
            # version — confirm before correcting.
            gb.configure_column("to_del", "删除", editable=True, width=50, wrapHeaderText=True,
                                cellEditor="agCheckboxCellEditor", cellRender="agCheckboxCellRenderer")
            gb.configure_selection()
            edit_docs = AgGrid(df, gb.build())

            if st.button("保存更改"):
                origin_docs = {
                    x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in
                    docs}
                changed_docs = []
                # NOTE(review): only rows whose content changed AND that are
                # not marked for deletion are submitted. A row that is merely
                # marked for deletion produces no request. Whether unchanged
                # rows survive depends on how ``update_kb_docs`` treats
                # ``docs`` — verify against the API.
                for _, row in edit_docs.data.iterrows():
                    origin_doc = origin_docs[row["id"]]
                    if row["page_content"] != origin_doc["page_content"]:
                        if row["to_del"] not in ["Y", "y", 1]:
                            changed_docs.append({
                                "page_content": row["page_content"],
                                "type": row["type"],
                                "metadata": json.loads(row["metadata"]),
                            })

                if changed_docs:
                    if api.update_kb_docs(knowledge_base_name=selected_kb,
                                          file_names=[file_name],
                                          docs={file_name: changed_docs}):
                        st.toast("更新文档成功")
                    else:
                        st.toast("更新文档失败")