|
|
|
|
|
|
|
|
|
|
|
from openai import OpenAI |
|
import gradio as gr |
|
import csv |
|
from datetime import datetime |
|
import torch |
|
from langchain.vectorstores import FAISS |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# DeepSeek API key, read from the environment (e.g. a Hugging Face Space secret).
deepseek_key = os.getenv("DEEPSEEK_KEY")

# Only report whether the key was found; startup continues either way.
if deepseek_key:
    print("Deepseek key loaded successfully!")
else:
    print("Deepseek key not found. Please set it in HF Space Secrets.")

# DeepSeek exposes an OpenAI-compatible API, so the OpenAI client is pointed at
# DeepSeek's base URL.
# NOTE(review): when deepseek_key is None the openai client falls back to other
# configuration and may raise at construction, aborting startup — confirm.
client = OpenAI(api_key=deepseek_key, base_url="https://api.deepseek.com")
|
|
|
# Put the embedding model on the GPU when torch can see CUDA, otherwise on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Shared sentence-embedding model used when loading every FAISS index below.
# NOTE(review): presumably this must be the same model the indexes were built
# with — confirm.
embeddings = HuggingFaceEmbeddings(
    model_name="shibing624/text2vec-base-chinese",
    model_kwargs={"device": device},
)
|
|
|
|
|
|
|
|
|
|
|
# Already-loaded FAISS stores, keyed by (index_path, id(embed_obj)).
# Each value is (embed_obj, store). Holding embed_obj keeps it alive, so its
# id() cannot be reused by another object while the entry exists.
_VECTORSTORE_CACHE = {}


def load_vectorstore(index_path: str, embed_obj, use_cache: bool = True) -> FAISS:
    """Load the FAISS index stored in ``index_path`` using ``embed_obj``.

    Loaded stores are cached per (path, embedder). Repeated calls therefore do
    not re-read and re-deserialize the index from disk on every query.

    Args:
        index_path: Directory passed to ``FAISS.load_local``.
        embed_obj: Embedding object used to embed queries against the index.
        use_cache: When False, always load from disk and do not touch the cache.
            Note that changes to the index files on disk are not picked up by a
            cached entry.

    Returns:
        The loaded ``FAISS`` vector store.
    """
    key = (index_path, id(embed_obj))
    if use_cache:
        cached = _VECTORSTORE_CACHE.get(key)
        if cached is not None:
            return cached[1]

    # SECURITY: allow_dangerous_deserialization=True unpickles the stored
    # docstore. Only load index directories produced by a trusted source.
    store = FAISS.load_local(
        index_path,
        embed_obj,
        allow_dangerous_deserialization=True,
    )
    if use_cache:
        _VECTORSTORE_CACHE[key] = (embed_obj, store)
    return store
|
|
|
def _format_context(label, docs, with_link):
    """Render docs as numbered ``[<label> Doc i]`` entries, optionally with their link."""
    entries = []
    for i, doc in enumerate(docs, start=1):
        if with_link:
            link = doc.metadata.get("链接", "无链接")
            entries.append(f"[{label} Doc {i}]\n链接: {link}\n{doc.page_content}\n")
        else:
            entries.append(f"[{label} Doc {i}]\n{doc.page_content}\n")
    return "".join(entries)


def build_prompt_for_chatgpt(query, cs_docs, hc_docs, ic_docs):
    """Build the LLM prompt from the user query and the retrieved CS/HC/IC docs.

    Args:
        query: The user's question, embedded verbatim in the prompt.
        cs_docs: Customer Success docs (objects with ``page_content``). No link
            is printed for these.
        hc_docs: Help Center docs. Their ``metadata["链接"]`` is printed when present.
        ic_docs: IC docs. Their ``metadata["链接"]`` is printed when present.
        ``None`` for any doc list is treated as an empty list.

    Returns:
        The prompt string.
    """
    # Materialize once so len() and iteration see the same items; None -> [].
    cs_docs = list(cs_docs or [])
    hc_docs = list(hc_docs or [])
    ic_docs = list(ic_docs or [])

    cs_context = _format_context("CS", cs_docs, with_link=False)
    hc_context = _format_context("HC", hc_docs, with_link=True)
    ic_context = _format_context("IC", ic_docs, with_link=True)

    prompt = f"""\
请根据以下文档,回答用户的问题。如果无法从文档中找到答案,请说明你无法回答。

用户问题:{query}

======================
【Customer Success - Top {len(cs_docs)}】
{cs_context}

【Help Center - Top {len(hc_docs)}】
{hc_context}

【IC - Top {len(ic_docs)}】
{ic_context}
======================

请根据helpcenter链接查询help center内容给出准确答案

请给出简洁、准确、结构清晰且包含必要细节的回答,并尽量引用help center链接,但不需要表明来自哪个文档:
"""
    return prompt
|
|
|
def combined_search(query, cs_index_path, hc_index_path, ic_index_path,
                    cs_k=3, hc_k=2, ic_k=3):
    """Search the CS, HC and IC indexes for ``query`` and build the LLM prompt.

    All three indexes are loaded first (CS, HC, IC order) with the shared
    module-level ``embeddings``. Each is then queried with its own top-k.

    Returns:
        ``(cs_docs, hc_docs, ic_docs, prompt)``.
    """
    stores = [
        load_vectorstore(path, embeddings)
        for path in (cs_index_path, hc_index_path, ic_index_path)
    ]

    cs_docs, hc_docs, ic_docs = (
        store.similarity_search(query, k=top_k)
        for store, top_k in zip(stores, (cs_k, hc_k, ic_k))
    )

    prompt = build_prompt_for_chatgpt(query, cs_docs, hc_docs, ic_docs)
    return cs_docs, hc_docs, ic_docs, prompt
|
|
|
|
|
|
|
|
|
|
|
def generate_answer_with_deepseek(prompt_text):
    """Ask DeepSeek's ``deepseek-reasoner`` model to answer ``prompt_text``.

    Makes one blocking (``stream=False``) chat-completion call through the
    module-level OpenAI-compatible ``client``.

    Returns:
        The message content of the first returned choice.
    """
    completion = client.chat.completions.create(
        model="deepseek-reasoner",
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": prompt_text},
        ],
        stream=False,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
|
|
|
|
|
|
|
|
|
|
def _doc_link(doc, default="无链接"):
    """Return the doc's link from metadata, checking key "链接" first, then "link"."""
    metadata = doc.metadata
    for key in ("链接", "link"):
        if key in metadata:
            return metadata[key]
    return default


def _summarize_docs(docs):
    """One line per doc: first 60 chars of its text plus its link."""
    return "\n".join(
        f"{doc.page_content[:60]}... (Link: {_doc_link(doc)})" for doc in docs
    )


def run_search_and_answer(user_query, store_name):
    """Gradio handler: retrieve docs, build the prompt and ask DeepSeek.

    Steps:
        1) retrieve CS / HC / IC docs and build the prompt;
        2) call the DeepSeek model;
        3) return everything for the Gradio outputs.

    Args:
        user_query: The question typed by the user.
        store_name: Store name from the UI. It is currently unused here and kept
            so the signature matches the Gradio inputs.

    Returns:
        ``(doc_summary, prompt_text, deepseek_answer, links_str)``. ``links_str``
        is the HC doc links joined with "; ".
    """
    # Nothing to search for: skip retrieval and the paid LLM call.
    if not (user_query or "").strip():
        return "", "", "请输入问题。", ""

    cs_docs, hc_docs, ic_docs, prompt_text = combined_search(
        query=user_query,
        cs_index_path="CS_faiss_index",
        hc_index_path="HC_faiss_index",
        ic_index_path="IC_faiss_index",
        cs_k=5,
        hc_k=2,
        ic_k=5,
    )

    deepseek_answer = generate_answer_with_deepseek(prompt_text)

    doc_summary = (
        "=== Customer Success Docs ===\n"
        f"{_summarize_docs(cs_docs)}\n\n"
        "=== Help Center Docs ===\n"
        f"{_summarize_docs(hc_docs)}\n\n"
        "=== IC Docs ===\n"
        f"{_summarize_docs(ic_docs)}\n"
    )

    # str() guards against non-string metadata values, which would make join() raise.
    links_str = "; ".join(str(_doc_link(doc)) for doc in hc_docs)

    return doc_summary, prompt_text, deepseek_answer, links_str
|
|
|
from huggingface_hub import HfApi |
|
|
|
def record_feedback(user_query, final_answer, feedback_choice, improved_answer, store_name, links_str):
    """Append one feedback row to ``feedback.csv`` and push it to the HF dataset repo.

    Row layout: time, store_name, user_query, feedback_choice, answer, links_str.
    The answer is ``final_answer`` when the rating is "好", otherwise ``improved_answer``.

    Args:
        user_query: The question the answer was given for.
        final_answer: The model's answer as shown in the UI.
        feedback_choice: "好" or "不好". Anything else (e.g. no selection) is rejected.
        improved_answer: The user's corrected answer, used when the rating is "不好".
        store_name: Optional store name from the UI.
        links_str: HC links for the answered query.

    Returns:
        A status message for the Gradio Markdown output.
    """
    # Without a rating the row would be meaningless (rating None, usually blank
    # answer), so nothing is written or pushed.
    if feedback_choice not in ("好", "不好"):
        return "请先选择回答质量(好/不好)再提交反馈。"

    # NOTE(review): the timestamp is truncated to the hour (minutes and seconds
    # are written as 00) — confirm this bucketing is intentional.
    current_time = datetime.now().strftime("%Y-%m-%d %H:00:00")
    recorded_answer = final_answer if feedback_choice == "好" else improved_answer

    with open("feedback.csv", "a", encoding="utf-8", newline="") as f:
        csv.writer(f).writerow([
            current_time,
            store_name,
            user_query,
            feedback_choice,
            recorded_answer,
            links_str,
        ])

    hf_token = os.getenv("HF_TOKEN", None)
    if not hf_token:
        return "缺少 HF_TOKEN,无法推送到 Hugging Face。"

    # NOTE(review): this replaces feedback.csv in the dataset repo with the local
    # copy. If the local file was ever reset (e.g. ephemeral Space storage), rows
    # previously pushed would be lost — confirm.
    api = HfApi()
    try:
        api.upload_file(
            path_or_fileobj="feedback.csv",
            path_in_repo="feedback.csv",
            repo_id="PebllaRyan/Feedbacks",
            repo_type="dataset",
            token=hf_token,
            commit_message="Update feedback logs",
        )
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising into Gradio.
        return f"本地记录成功,但推送到仓库失败: {e}"
    return "已记录到本地 CSV,并成功推送到 PebllaRyan/Feedbacks!"
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Peblla 智能知识库助手 - 整合DeepSeek示例")

    # --- Question input ---
    store_name_box = gr.Textbox(label="店铺名称", placeholder="可选:你在哪个店铺遇到了问题?")
    user_query_box = gr.Textbox(label="问题", lines=2)
    btn_search = gr.Button("提交问题")

    # --- Search / answer outputs ---
    doc_result_box = gr.Textbox(label="检索到的文档(简要)", interactive=False, lines=6)
    # The generated prompt is filled in but hidden from users.
    prompt_box = gr.Textbox(label="生成的 Prompt", interactive=False, lines=6, visible=False)
    deepseek_answer_box = gr.Textbox(label="DeepSeek 回答", interactive=False, lines=6)
    # Hidden holder for the HC links string, forwarded to record_feedback.
    hc_links_box = gr.Textbox(visible=False)

    gr.Markdown("---")

    # --- Feedback ---
    feedback_choice = gr.Radio(
        choices=["好", "不好"],
        label="回答质量如何?",
        value=None,
    )
    improved_answer_box = gr.Textbox(
        label="如果选择“不好”,请在这里输入改进后的答案",
        lines=5,
    )
    btn_feedback = gr.Button("提交反馈")
    feedback_result = gr.Markdown()

    # Output order must match run_search_and_answer's return tuple.
    btn_search.click(
        fn=run_search_and_answer,
        inputs=[user_query_box, store_name_box],
        outputs=[doc_result_box, prompt_box, deepseek_answer_box, hc_links_box],
    )

    # Input order must match record_feedback's parameter order.
    btn_feedback.click(
        fn=record_feedback,
        inputs=[
            user_query_box,
            deepseek_answer_box,
            feedback_choice,
            improved_answer_box,
            store_name_box,
            hc_links_box,
        ],
        outputs=[feedback_result],
    )


# Start the server only when run as a script, not as an import side effect.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )
|
|