Spaces:
Sleeping
Sleeping
功能优化: 添加双栏pdf识别选项到页面,并优化config文件中关于文档解析的设置
Browse files
- ChuanhuChatbot.py +6 -0
- config_example.json +5 -2
- modules/chat_func.py +8 -6
- modules/config.py +12 -2
- modules/llama_func.py +8 -6
ChuanhuChatbot.py
CHANGED
@@ -78,6 +78,10 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
78 |
value=REPLY_LANGUAGES[0],
|
79 |
)
|
80 |
index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
|
|
|
|
|
|
|
|
|
81 |
|
82 |
with gr.Tab(label="Prompt"):
|
83 |
systemPromptTxt = gr.Textbox(
|
@@ -295,6 +299,8 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
295 |
)
|
296 |
reduceTokenBtn.click(**get_usage_args)
|
297 |
|
|
|
|
|
298 |
# ChatGPT
|
299 |
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
|
300 |
keyTxt.submit(**get_usage_args)
|
|
|
78 |
value=REPLY_LANGUAGES[0],
|
79 |
)
|
80 |
index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
|
81 |
+
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
82 |
+
# TODO: 公式ocr
|
83 |
+
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
84 |
+
updateDocConfigBtn = gr.Button("更新解析文件参数")
|
85 |
|
86 |
with gr.Tab(label="Prompt"):
|
87 |
systemPromptTxt = gr.Textbox(
|
|
|
299 |
)
|
300 |
reduceTokenBtn.click(**get_usage_args)
|
301 |
|
302 |
+
updateDocConfigBtn.click(update_doc_config, [two_column], None)
|
303 |
+
|
304 |
# ChatGPT
|
305 |
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
|
306 |
keyTxt.submit(**get_usage_args)
|
config_example.json
CHANGED
@@ -2,8 +2,11 @@
|
|
2 |
"openai_api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxx",
|
3 |
"https_proxy": "http://127.0.0.1:1079",
|
4 |
"http_proxy": "http://127.0.0.1:1079",
|
5 |
-
"
|
6 |
-
"
|
|
|
|
|
|
|
7 |
},
|
8 |
"users": [
|
9 |
["root", "root"]
|
|
|
2 |
"openai_api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxx",
|
3 |
"https_proxy": "http://127.0.0.1:1079",
|
4 |
"http_proxy": "http://127.0.0.1:1079",
|
5 |
+
"advance_docs": {
|
6 |
+
"pdf": {
|
7 |
+
"two_column": true,
|
8 |
+
"formula_ocr": true
|
9 |
+
}
|
10 |
},
|
11 |
"users": [
|
12 |
["root", "root"]
|
modules/chat_func.py
CHANGED
@@ -291,12 +291,14 @@ def predict(
|
|
291 |
msg = "索引构建完成,获取回答中……"
|
292 |
logging.info(msg)
|
293 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
|
|
|
|
300 |
reference_results = [n.node.text for n in nodes]
|
301 |
reference_results = add_source_numbers(reference_results, use_source=False)
|
302 |
display_reference = add_details(reference_results)
|
|
|
291 |
msg = "索引构建完成,获取回答中……"
|
292 |
logging.info(msg)
|
293 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
294 |
+
with retrieve_proxy():
|
295 |
+
llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
|
296 |
+
prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
|
297 |
+
from llama_index import ServiceContext
|
298 |
+
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
|
299 |
+
query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
|
300 |
+
query_bundle = QueryBundle(inputs)
|
301 |
+
nodes = query_object.retrieve(query_bundle)
|
302 |
reference_results = [n.node.text for n in nodes]
|
303 |
reference_results = add_source_numbers(reference_results, use_source=False)
|
304 |
display_reference = add_details(reference_results)
|
modules/config.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from contextlib import contextmanager
|
2 |
import os
|
3 |
import logging
|
@@ -11,6 +12,8 @@ __all__ = [
|
|
11 |
"dockerflag",
|
12 |
"retrieve_proxy",
|
13 |
"log_level",
|
|
|
|
|
14 |
]
|
15 |
|
16 |
# 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
|
@@ -109,5 +112,12 @@ def retrieve_proxy(proxy=None):
|
|
109 |
os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
|
110 |
|
111 |
|
112 |
-
## 处理advance
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
from contextlib import contextmanager
|
3 |
import os
|
4 |
import logging
|
|
|
12 |
"dockerflag",
|
13 |
"retrieve_proxy",
|
14 |
"log_level",
|
15 |
+
"advance_docs",
|
16 |
+
"update_doc_config"
|
17 |
]
|
18 |
|
19 |
# 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
|
|
|
112 |
os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
|
113 |
|
114 |
|
115 |
+
## 处理advance docs
|
116 |
+
advance_docs = defaultdict(lambda: defaultdict(dict))
|
117 |
+
advance_docs.update(config.get("advance_docs", {}))
|
118 |
+
def update_doc_config(two_column_pdf):
    """Record the two-column PDF parsing option in ``advance_docs``.

    Args:
        two_column_pdf: Truthy to parse PDFs as two-column layout, falsy to
            parse them as single-column.
    """
    # Store the value unconditionally: the previous version only ever wrote
    # True, so the option could never be switched back off once enabled.
    advance_docs["pdf"]["two_column"] = bool(two_column_pdf)

    logging.info(f"更新后的文件参数为:{advance_docs}")
|
modules/llama_func.py
CHANGED
@@ -45,8 +45,9 @@ def get_documents(file_src):
|
|
45 |
logging.debug("Loading PDF...")
|
46 |
try:
|
47 |
from modules.pdf_func import parse_pdf
|
48 |
-
from modules.config import
|
49 |
-
|
|
|
50 |
except:
|
51 |
pdftext = ""
|
52 |
with open(file.name, 'rb') as pdfFileObj:
|
@@ -106,10 +107,11 @@ def construct_index(
|
|
106 |
try:
|
107 |
documents = get_documents(file_src)
|
108 |
logging.info("构建索引中……")
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
113 |
logging.debug("索引构建完成!")
|
114 |
os.makedirs("./index", exist_ok=True)
|
115 |
index.save_to_disk(f"./index/{index_name}.json")
|
|
|
45 |
logging.debug("Loading PDF...")
|
46 |
try:
|
47 |
from modules.pdf_func import parse_pdf
|
48 |
+
from modules.config import advance_docs
|
49 |
+
two_column = advance_docs["pdf"].get("two_column", False)
|
50 |
+
pdftext = parse_pdf(file.name, two_column).text
|
51 |
except:
|
52 |
pdftext = ""
|
53 |
with open(file.name, 'rb') as pdfFileObj:
|
|
|
107 |
try:
|
108 |
documents = get_documents(file_src)
|
109 |
logging.info("构建索引中……")
|
110 |
+
with retrieve_proxy():
|
111 |
+
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
|
112 |
+
index = GPTSimpleVectorIndex.from_documents(
|
113 |
+
documents, service_context=service_context
|
114 |
+
)
|
115 |
logging.debug("索引构建完成!")
|
116 |
os.makedirs("./index", exist_ok=True)
|
117 |
index.save_to_disk(f"./index/{index_name}.json")
|