| import json |
| import os |
| import gradio as gr |
| from sentence_transformers import SentenceTransformer, util |
| from huggingface_hub import InferenceClient |
|
|
# --- Module initialization (runs once at import): load the embedding model
# --- and pre-encode every FAQ question for fast cosine retrieval.

print("Loading embedding model...")
# Small Chinese-optimized sentence-embedding model.
model = SentenceTransformer("BAAI/bge-small-zh-v1.5")

print("Loading FAQs...")
# Each FAQ entry carries "q", "a" and "category" keys (indexed below in
# chat()/llm_answer()).
with open("faqs.json", "r", encoding="utf-8") as f:
    faqs = json.load(f)

print(f"Encoding {len(faqs)} FAQ questions...")
questions = [entry["q"] for entry in faqs]
# normalize_embeddings=True makes cosine similarity a plain dot product.
faq_embeddings = model.encode(questions, normalize_embeddings=True)
print("Ready!")

# Minimum cosine similarity for the best match to be trusted; below this the
# bot refuses to answer rather than hallucinate.
THRESHOLD = 0.55
|
|
| |
# Prompt contract for the generation step: answer strictly from the retrieved
# reference material, admit ignorance otherwise, stay short and conversational.
SYSTEM_PROMPT = """你是一个友好、简洁的 AI 学习答疑助手。
规则:
1. 严格基于"参考资料"回答,不要编造
2. 资料里没有的内容,直接说"我暂时没这方面的资料"
3. 用自然、口语化的中文,避免生硬复读资料原文
4. 控制在 3 句话以内"""


# Per-request template: retrieved FAQ pairs go into {context}, the raw user
# question into {question}.
USER_PROMPT_TEMPLATE = """【参考资料】
{context}

【用户问题】
{question}

请基于资料用自然语言回答。"""


# Hosted-inference client. HF_TOKEN comes from the Space's secret store; a
# missing token is tolerated here because llm_answer() degrades gracefully.
_hf_token = os.environ.get("HF_TOKEN")
client = InferenceClient(
    model="Qwen/Qwen2.5-72B-Instruct",
    token=_hf_token,
    timeout=20,
)
|
|
|
|
def llm_answer(question, top_faqs):
    """Rewrite the retrieved FAQ answers into one conversational reply.

    Parameters
    ----------
    question : str
        The user's original question.
    top_faqs : list[dict]
        Retrieved FAQ entries (keys "q" and "a"), best match first.

    Returns
    -------
    str
        The LLM-generated reply. If no HF_TOKEN is configured, or the
        inference call fails, the top FAQ's raw answer plus an explanatory
        note is returned instead.
    """
    if not os.environ.get("HF_TOKEN"):
        # No token -> cannot reach the hosted model; surface the raw answer.
        return top_faqs[0]["a"] + "\n\n_(需要在 Space Secrets 设置 HF_TOKEN 以启用 LLM)_"

    # Stitch the retrieved Q/A pairs into the reference-material section.
    pairs = [f"Q: {item['q']}\nA: {item['a']}" for item in top_faqs]
    prompt = USER_PROMPT_TEMPLATE.format(context="\n\n".join(pairs), question=question)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = client.chat_completion(
            messages=messages,
            max_tokens=200,
            temperature=0.3,
        )
        return completion.choices[0].message.content
    except Exception as e:  # network/quota errors — degrade to the raw FAQ answer
        return top_faqs[0]["a"] + f"\n\n_(LLM 暂不可用:{e})_"
|
|
|
|
| |
def chat(query):
    """Retrieve-then-generate pipeline for one user question.

    Returns a 3-tuple of Markdown strings:
    (reply, retrieval details, related questions).
    """
    if not query or not query.strip():
        # Blank input: prompt the user instead of running retrieval.
        return "请输入您的问题", "", ""

    # Embed the query and score it against every pre-encoded FAQ question.
    query_vec = model.encode(query, normalize_embeddings=True)
    sims = util.cos_sim(query_vec, faq_embeddings)[0]

    # Indices of the three best matches, best first.
    ranked = sims.argsort(descending=True)[:3].tolist()
    best = ranked[0]
    best_score = float(sims[best])

    if best_score < THRESHOLD:
        # Nothing close enough — refuse rather than hallucinate.
        reply = "抱歉,我暂时无法理解您的问题。建议换个说法,或查看下方相关问题。"
    else:
        reply = llm_answer(query, [faqs[i] for i in ranked])

    info = (
        f"**类别**: {faqs[best]['category']} | **匹配度**: {best_score:.2f}"
        f" | **匹配的问题**: {faqs[best]['q']}"
    )
    related = "### 您可能也想问:\n" + "".join(
        f"- {faqs[i]['q']} _(相似度 {float(sims[i]):.2f})_\n" for i in ranked[1:]
    )
    return reply, info, related
|
|
|
|
# Clickable sample questions shown under the input box; the last one is
# intentionally off-topic to demonstrate the below-threshold refusal path.
examples = [
    "embedding 是什么意思?",
    "中文应该用哪个向量模型?",
    "BERT 和 GPT 有什么不一样?",
    "pipeline 是干什么用的?",
    "AI 怎么知道两句话意思一样?",
    "怎么把模型跑到 GPU 上?",
    "为什么 LLM 会胡说八道?",
    "今天天气怎么样?",
]


iface = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(label="您的问题", placeholder="例如:embedding 是什么?", lines=2),
    outputs=[
        gr.Markdown(label="答案"),
        gr.Markdown(label="检索详情"),
        gr.Markdown(label="相关问题"),
    ],
    title="🤖 AI 学习 FAQ 机器人(RAG)",
    # FIX: previously said "Qwen2.5-7B-Instruct", but the InferenceClient is
    # configured with Qwen/Qwen2.5-72B-Instruct. Also derive the FAQ count
    # from the loaded data instead of hard-coding "30".
    description=f"基于 BAAI/bge-small-zh-v1.5 检索 + Qwen2.5-72B-Instruct 生成 · {len(faqs)} 条 AI 学习 FAQ",
    examples=examples,
    flagging_mode="never",
    theme="soft",
)


if __name__ == "__main__":
    iface.launch()
|
|