#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DrugQA (ZH) — FastAPI LINE webhook backed by a drug-leaflet RAG pipeline.

Pipeline overview (see ``RagPipeline``):
- Drug lookup: drug names in the user's question are matched against names
  loaded from the CSV (plus ``DRUG_NAME_MAPPING`` aliases).
- Query analysis: LLM-based sub-query decomposition and intent detection, with
  a keyword shortcut through ``REFERENCE_MAPPING``.
- Intent-aware retrieval: LLM query expansion, FAISS semantic search and BM25,
  restricted to sentences of the matched drugs and fused 0.6 / 0.4. The fused
  score is used directly as the ranking score (no separate reranker model).
- Answer generation for every model in ``AVAILABLE_MODELS``, both with and
  without the retrieved context, to compare models and the effect of RAG.

[MODIFIED] The concise-answer prompt templates follow the reference guideline:
- No length limit, so complete information can be presented.
- Language: everyday, plain wording; avoid medical jargon or explain it simply.
- Structure and ordering: most important first (how to take it, how much, how
  often), then precautions, side effects and storage, with clear headings such
  as "dosage", "how to take" and "things to watch out for".
- Visual design: lists / bullet points (suited to LINE plain text); important
  warnings emphasised (e.g. upper case or markers).
- Situational guidance: tie advice to daily life, e.g. "take within 30 minutes
  after a meal" instead of "take after meals"; "if you forget a dose, take it
  right away but do not double up".
- Audience: assume a general adult user and simple language; multilingual or
  illustrated output may be added later.
- Interactivity: the answer may end with an FAQ suggestion, kept simple.
Core principle: complete, easy to understand, clearly ordered; no length limit
but kept concise.
"""
# ---------- Environment & cache settings (must run before model libraries are imported) ----------
import os
import pathlib

os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.getenv("XDG_CACHE_HOME")):
    pathlib.Path(d).mkdir(parents=True, exist_ok=True)

# ---------- Python standard library ----------
import re
import hmac
import base64
import hashlib
import pickle
import logging
import json
import textwrap
import time
import unicodedata
from typing import List, Dict, Any, Optional
from functools import lru_cache
from dataclasses import dataclass, field
from contextlib import asynccontextmanager
from collections import defaultdict

# ---------- Third-party libraries ----------
import numpy as np
import pandas as pd
import tenacity
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.concurrency import run_in_threadpool
import uvicorn
import jieba
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import faiss
import torch
import openai
from openai import OpenAI
import requests


# ==== CONFIG (loaded from environment variables, with defaults) ====
def _require_env(var: str) -> str:
    """Return the value of environment variable *var*.

    Raises:
        RuntimeError: if the variable is unset or empty.
    """
    v = os.getenv(var)
    if not v:
        raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
    return v


CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")  # [MODIFIED] recommended model

AVAILABLE_MODELS = [
    "azure-gpt-4",
    "azure-gpt-4.1",
    "azure-gpt-4.1-mini",
    "azure-gpt-4o",
]

# Credentials must come from the environment, never from source code.
# Any API key that was previously committed here must be treated as leaked and rotated.
LLM_API_CONFIG = {
    "base_url": os.getenv("LLM_BASE_URL", "https://litellm-ekkks8gsocw.dgx-coolify.apmic.ai/"),
    "api_key": os.getenv("LLM_API_KEY", ""),
}

LLM_MODEL_CONFIG = {
    "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
    "max_tokens": int(os.getenv("MAX_TOKENS", 1024)),
    "temperature": float(os.getenv("TEMPERATURE", 0.0)),
    "seed": int(os.getenv("LLM_SEED", 42)),
}

INTENT_CATEGORIES = [
    "操作 (Administration)",
    "保存/攜帶 (Storage & Handling)",
    "副作用/異常 (Side Effects / Issues)",
    "劑型相關 (Dosage Form Concerns)",
    "時間/併用 (Timing & Interaction)",
    "劑量調整 (Dosage Adjustment)",
    "禁忌症/適應症 (Contraindications/Indications)",
]

# Intent -> leaflet sections ("section" values in the sentence metadata) to prioritise.
INTENT_TO_SECTION = {
    "操作 (Administration)": ["用法用量", "病人使用須知"],
    "保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
    "副作用/異常 (Side Effects / Issues)": ["不良反應", "警語與注意事項"],
    "劑型相關 (Dosage Form Concerns)": ["劑型", "藥品外觀"],
    "時間/併用 (Timing & Interaction)": ["用法用量"],
    "劑量調整 (Dosage Adjustment)": ["用法用量"],
    "禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"],
}

# Canned questions -> sections; multiple sections are separated by "、".
REFERENCE_MAPPING = {
    "如何用藥?": "病人使用須知、用法用量",
    "如何保存與攜帶?": "包裝及儲存",
    "可能的副作用?": "警語與注意事項、不良反應",
    "每次劑量多少?": "用法用量",
    "用藥時間?": "用法用量",
}

# Reverse index: section -> intents that map to it.
SECTION_TO_INTENT = defaultdict(list)
for intent, sections in INTENT_TO_SECTION.items():
    for section in sections:
        SECTION_TO_INTENT[section].append(intent)

# Alias -> canonical name; an alias is attached to every drug whose normalised
# name contains the normalised canonical name.
DRUG_NAME_MAPPING = {
    "fentanyl patch": "fentanyl",
    "spiriva respimat": "spiriva",
    "augmentin for syrup": "augmentin syrup",
    "nitrostat": "nitroglycerin",
    "ozempic": "ozempic",
    "niflec": "niflec",
    "fosamax": "fosamax",
    "humira": "humira",
    "premarin": "premarin",
    "smecta": "smecta",
}

DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"

PROMPT_TEMPLATES = {
    "analyze_query": """
請分析以下使用者問題,並完成以下兩個任務:
1. 將問題分解為1-3個核心的子問題。
2. 從清單中選擇所有相關的意圖分類。
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intents' (字串陣列) 兩個鍵。
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"]}}
意圖分類清單:
{options}
使用者問題:{query}
""",
    "expand_query": """
請根據以下意圖:{intents},擴展這個查詢,加入相關同義詞或術語。
原始查詢:{query}
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。
""",
    "final_answer_concise": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
- 嚴格依據資料:所有內容完全來自參考資料,禁止捏造。
- 資料不足:直接回覆「根據提供的資料,無法回答您的問題。」
- 語言:生活化、白話,避免專有名詞或加解釋。
- 結構:先重點 (怎麼吃、吃多少、多久一次),再注意事項、副作用、保存。使用標題分段。
- 視覺:條列式呈現,重要警語用全大寫。
- 情境化:生活場景指引,如「飯後30分鐘內吃」。
- 精簡:資訊完整、易讀。
- 結尾:加「如有不適請立即就醫。」
{additional_instruction}
---
參考資料:
{context}
---
使用者問題:{query}
請直接輸出最終的答案:
""",
    "direct_answer_concise": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,直接基於您的知識給予回覆:
- 語言:生活化、白話,避免專有名詞或加解釋。
- 結構:先重點 (怎麼吃、吃多少、多久一次),再注意事項、副作用、保存。使用標題分段。
- 視覺:條列式呈現,重要警語用全大寫。
- 情境化:生活場景指引,如「飯後30分鐘內吃」。
- 精簡:資訊完整、易讀。
- 結尾:加「如有不適請立即就醫。」
{additional_instruction}
使用者問題:{query}
請直接輸出最終的答案:
""",
}

# Separators used inside REFERENCE_MAPPING values ("、", ASCII comma, full-width comma).
_SECTION_SEPARATORS_RE = re.compile(r"[、,,]")
# Latin/digit runs or CJK runs, used for drug-name tokens.
_NAME_TOKEN_RE = re.compile(r"[a-z0-9]+|[\u4e00-\u9fff]+")
_CJK_RE = re.compile(r"[\u4e00-\u9fff]")
# LLMs frequently wrap JSON in ```json ... ``` fences.
_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)

# ---------- LINE Messaging API ----------
LINE_REPLY_URL = "https://api.line.me/v2/bot/message/reply"
LINE_MAX_MESSAGES_PER_REPLY = 5  # the reply API accepts at most 5 message objects
LINE_REQUEST_TIMEOUT = float(os.getenv("LINE_REQUEST_TIMEOUT", 10))

# ---------- Logging ----------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)


def _norm(s: str) -> str:
    """Normalise text for matching.

    Applies NFKC, drops Unicode punctuation (category P*), lower-cases and
    removes ALL whitespace, so word boundaries are lost.
    """
    s = ''.join(c for c in unicodedata.normalize("NFKC", s) if not unicodedata.category(c).startswith('P'))
    return re.sub(r"\s+", "", s.lower()).strip()


@dataclass
class FusedCandidate:
    """A retrieved sentence with its fused and component scores."""
    idx: int
    fused_score: float
    sem_score: float
    bm_score: float


@dataclass
class RerankResult:
    """A ranked sentence ready for context building."""
    idx: int
    rerank_score: float
    text: str
    meta: Dict[str, Any] = field(default_factory=dict)


# ---------- Core RAG logic ----------
class RagPipeline:
    """Drug-leaflet RAG pipeline: drug lookup, query analysis, hybrid retrieval and answering."""

    def __init__(self):
        """Create the LLM client and load the embedding model.

        Raises:
            ValueError: if the LLM API key or base URL is not configured.
        """
        if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
            raise ValueError("LLM API Key or Base URL is not configured.")
        self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
        self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
        self.drug_name_to_ids: Dict[str, List[str]] = {}
        self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
        # Ad-hoc attribute bag filled by load_data() (index, sentences, meta, bm25).
        self.state = type('state', (), {})()
        self.CLARIFICATION_PROMPT = """
請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精確地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境等。
範例:
使用者問題:這個藥會怎麼樣?
澄清提問:您好,請問您指的是哪一種藥物呢?
使用者問題:請問這要吃多久?
澄清提問:請問您是想了解該藥品的建議療程長度嗎?
使用者問題:{query}
澄清提問:"""

    def _load_model(self, model_class, model_name: str, model_type: str):
        """Instantiate *model_class* on CUDA when available, falling back to CPU on failure."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 {model_type} 模型:{model_name} 至 {device}...")
        try:
            return model_class(model_name, device=device)
        except Exception as e:
            log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
            return model_class(model_name, device="cpu")

    def load_data(self):
        """Load the CSV, drug-name lookup tables, FAISS index, sentences and BM25 model.

        Raises:
            FileNotFoundError: if any configured data file is missing.
        """
        log.info("開始載入資料與模型...")
        for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
            if not pathlib.Path(path).exists():
                raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
        self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
        # Register drug names with jieba BEFORE building the lookup table, so drug
        # names are segmented the same way at build time and at query time.
        self._load_drug_name_vocabulary()
        self.drug_name_to_ids = self._build_drug_name_to_ids()
        self.state.index = faiss.read_index(FAISS_INDEX)
        if hasattr(self.state.index, "nprobe"):
            self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
        # NOTE: pickle can execute arbitrary code on load; only point these paths at trusted files.
        with open(SENTENCES_PKL, "rb") as f:
            data = pickle.load(f)
        self.state.sentences = data["sentences"]
        self.state.meta = data["meta"]
        with open(BM25_PKL, "rb") as f:
            bm25_data = pickle.load(f)
        self.state.bm25 = bm25_data["bm25"]
        log.info("所有模型與資料載入完成。")

    def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
        """Map drug-name tokens (length > 1) and aliases to sorted lists of drug_ids."""
        mapping: Dict[str, set] = {}
        alias_pairs = [(_norm(alias), _norm(canonical)) for alias, canonical in DRUG_NAME_MAPPING.items()]
        for _, row in self.df_csv.iterrows():
            drug_id = row['drug_id']
            name_norm = _norm(row['drug_name_norm'])
            # Segment the NFKC/lower-cased Chinese name, mirroring query-side tokenisation.
            zh_name = unicodedata.normalize("NFKC", row['drug_name_zh']).lower()
            zh_parts = list(jieba.cut(zh_name))
            en_parts = re.findall(r'[a-zA-Z0-9]+', row.get('drug_name_en', '').lower())
            norm_parts = _NAME_TOKEN_RE.findall(name_norm)
            for part in set(zh_parts + en_parts + norm_parts):
                part = part.strip()
                if len(part) > 1:
                    mapping.setdefault(part, set()).add(drug_id)
            for alias_key, canonical_key in alias_pairs:
                if canonical_key in name_norm:
                    mapping.setdefault(alias_key, set()).add(drug_id)
        return {key: sorted(ids) for key, ids in mapping.items()}

    def _load_drug_name_vocabulary(self):
        """Collect zh/en drug-name words and register the zh words with jieba at high frequency."""
        for _, row in self.df_csv.iterrows():
            norm_name = row['drug_name_norm']
            for word in re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', norm_name):
                self.drug_vocab["zh" if _CJK_RE.search(word) else "en"].add(word)
        for alias in DRUG_NAME_MAPPING:
            self.drug_vocab["zh" if _CJK_RE.search(alias) else "en"].add(alias)
        for word in self.drug_vocab["zh"]:
            jieba.add_word(word, freq=2_000_000)

    @tenacity.retry(wait=tenacity.wait_fixed(2), stop=tenacity.stop_after_attempt(3))
    def _llm_call(self, model_name: str, messages: List[Dict[str, str]], max_tokens: Optional[int] = None,
                  temperature: Optional[float] = None, seed: Optional[int] = None) -> str:
        """Call the chat-completions endpoint and return the message text ("" if none).

        ``None`` arguments fall back to ``LLM_MODEL_CONFIG``; explicit values
        (including 0) are passed through. Retried up to 3 times, 2 s apart.
        """
        log.info(f"{model_name} 呼叫開始.")
        start_time = time.time()
        response = self.llm_client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=max_tokens if max_tokens is not None else LLM_MODEL_CONFIG["max_tokens"],
            temperature=temperature if temperature is not None else LLM_MODEL_CONFIG["temperature"],
            seed=seed if seed is not None else LLM_MODEL_CONFIG["seed"],
        )
        content = ""
        if response.choices and response.choices[0].message:
            content = response.choices[0].message.content or ""
        log.info(f"{model_name} 呼叫完成,耗時: {time.time() - start_time:.2f} 秒。內容長度: {len(content)}.")
        return content

    def answer_question(self, q_orig: str) -> Dict[str, Dict[str, Any]]:
        """Answer *q_orig* with every model in AVAILABLE_MODELS.

        Returns ``{model: {...}}`` where the inner dict is one of
        ``{"clarification": str}`` (no drug recognised), ``{"error": str}``
        (no usable context) or ``{"with_rag": {...}, "without_rag": {...}}``.
        """
        start_time = time.time()
        drug_ids = self._find_drug_ids_from_name(q_orig)
        if not drug_ids:
            clarifications = {}
            for model in AVAILABLE_MODELS:
                clarification = self._generate_clarification_query(model, q_orig)
                clarifications[model] = {"clarification": clarification + f"\n{DISCLAIMER}"}
            return clarifications

        results = {model: self._answer_with_model(model, q_orig, drug_ids) for model in AVAILABLE_MODELS}
        log.info(f"查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒")
        return results

    def _answer_with_model(self, model: str, q_orig: str, drug_ids: List[str]) -> Dict[str, Any]:
        """Run analysis, retrieval and answer generation for a single model."""
        analysis = self._analyze_query(model, q_orig)
        intents = analysis["intents"]
        candidates = self._retrieve_candidates_for_all_queries(model, drug_ids, analysis["sub_queries"], intents)
        # The fused retrieval score doubles as the rerank score (no reranker model here).
        reranked_results = [
            RerankResult(idx=c.idx, rerank_score=c.fused_score,
                         text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
            for c in candidates[:TOP_K_SENTENCES]
        ]
        context = self._build_context(self._prioritize_context(reranked_results, intents))
        if not context:
            return {"error": f"根據提供的資料,無法回答您的問題。\n{DISCLAIMER}"}
        with_rag = self._generate_answers(model, q_orig, context, intents, "final_answer")
        without_rag = self._generate_answers(model, q_orig, "", intents, "direct_answer")
        return {"with_rag": with_rag, "without_rag": without_rag}

    def _analyze_query(self, model_name: str, query: str) -> Dict[str, Any]:
        """Return ``{"sub_queries": [str, ...], "intents": [str, ...]}`` for *query*.

        A REFERENCE_MAPPING keyword hit short-circuits the LLM call. The LLM
        result is validated so both keys always hold lists of strings and
        ``sub_queries`` is never empty.
        """
        norm_query = _norm(query)
        for ref_key, ref_value in REFERENCE_MAPPING.items():
            if _norm(ref_key) in norm_query:
                # Values list several sections separated by "、" (not ASCII commas).
                sections = [s.strip() for s in _SECTION_SEPARATORS_RE.split(ref_value) if s.strip()]
                # dict.fromkeys de-duplicates while keeping a deterministic order.
                intents = list(dict.fromkeys(
                    intent for section in sections for intent in SECTION_TO_INTENT.get(section, [])
                ))
                if intents:
                    return {"sub_queries": [query], "intents": intents}
        prompt = PROMPT_TEMPLATES["analyze_query"].format(
            options="\n".join(f"- {c}" for c in INTENT_CATEGORIES), query=query)
        response_str = self._llm_call(model_name, [{"role": "user", "content": prompt}])
        parsed = self._safe_json_parse(response_str, {})
        return self._normalize_analysis(parsed, query)

    @staticmethod
    def _normalize_analysis(parsed: Any, query: str) -> Dict[str, List[str]]:
        """Coerce a parsed LLM analysis into well-formed lists, defaulting sub_queries to [query]."""
        def _str_list(value: Any) -> List[str]:
            if not isinstance(value, list):
                return []
            return [v.strip() for v in value if isinstance(v, str) and v.strip()]

        if not isinstance(parsed, dict):
            parsed = {}
        return {
            "sub_queries": _str_list(parsed.get("sub_queries")) or [query],
            "intents": _str_list(parsed.get("intents")),
        }

    def _generate_clarification_query(self, model_name: str, query: str) -> str:
        """Ask *model_name* for a polite clarifying question about *query*."""
        prompt = self.CLARIFICATION_PROMPT.format(query=query)
        return self._llm_call(model_name, [{"role": "user", "content": prompt}]).strip()

    def _find_drug_ids_from_name(self, query: str) -> List[str]:
        """Return sorted drug_ids whose name tokens or aliases appear in *query*.

        ``_norm`` removes all whitespace, so tokenising only its output merges a
        whole sentence into a single run (e.g. "普拿疼怎麼吃"). The query is
        therefore also segmented with jieba and split into Latin words before
        whitespace removal, mirroring how the lookup table was built.
        """
        nfkc_query = unicodedata.normalize("NFKC", query).lower()
        norm_query = _norm(query)
        parts = {token.strip() for token in jieba.cut(nfkc_query)}
        parts.update(re.findall(r"[a-z0-9]+", nfkc_query))
        parts.update(_NAME_TOKEN_RE.findall(norm_query))
        # Multi-word aliases are stored without spaces (e.g. "fentanylpatch").
        for alias in DRUG_NAME_MAPPING:
            alias_key = _norm(alias)
            if alias_key and alias_key in norm_query:
                parts.add(alias_key)
        drug_ids = set()
        for part in parts:
            if part:
                drug_ids.update(self.drug_name_to_ids.get(part, ()))
        return sorted(drug_ids)

    def _retrieve_candidates_for_all_queries(self, model_name: str, drug_ids: List[str], sub_queries: List[str],
                                             intents: List[str]) -> List[FusedCandidate]:
        """Hybrid-retrieve sentences of *drug_ids* for every sub-query.

        Semantic and BM25 scores are min-max scaled per sub-query and fused
        0.6 / 0.4; each sentence keeps its best fused score across sub-queries.
        Returns candidates sorted by fused score, descending.
        """
        drug_ids_set = set(drug_ids)
        relevant_indices = {i for i, m in enumerate(self.state.meta) if m.get("drug_id", "") in drug_ids_set}
        if not relevant_indices:
            return []
        rel_idx = np.array(sorted(relevant_indices))  # loop-invariant, deterministic order
        intents_key = tuple(intents)

        all_fused_candidates: Dict[int, FusedCandidate] = {}
        for sub_q in sub_queries:
            expanded_q = self._expand_query_with_llm(model_name, sub_q, intents_key)
            candidate_scores = self._score_candidates(expanded_q, relevant_indices, rel_idx)
            if not candidate_scores:
                continue
            keys = list(candidate_scores)
            sem_scores = np.array([candidate_scores[k]["sem"] for k in keys])
            bm_scores = np.array([candidate_scores[k]["bm"] for k in keys])
            sem_n, bm_n = self._min_max_scale(sem_scores), self._min_max_scale(bm_scores)
            for pos, k in enumerate(keys):
                fused_score = sem_n[pos] * 0.6 + bm_n[pos] * 0.4
                best = all_fused_candidates.get(k)
                if best is None or fused_score > best.fused_score:
                    all_fused_candidates[k] = FusedCandidate(idx=k, fused_score=fused_score,
                                                             sem_score=sem_scores[pos], bm_score=bm_scores[pos])
        return sorted(all_fused_candidates.values(), key=lambda c: c.fused_score, reverse=True)

    def _score_candidates(self, expanded_q: str, relevant_indices: set,
                          rel_idx: np.ndarray) -> Dict[int, Dict[str, float]]:
        """Return raw {"sem", "bm"} scores for relevant sentences hit by FAISS or top BM25."""
        q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
        distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
        bm25_scores = self.state.bm25.get_scores(list(jieba.cut(expanded_q)))
        top_rel = rel_idx[np.argsort(bm25_scores[rel_idx])[::-1][:PRE_RERANK_K]]

        candidate_scores: Dict[int, Dict[str, float]] = {}
        # FAISS searches the whole corpus; keep only hits belonging to the matched drugs.
        for i, dist in zip(sim_indices[0], distances[0]):
            if int(i) in relevant_indices:
                # Distance -> similarity in (0, 1]; presumes an L2-style index (smaller is closer).
                candidate_scores[int(i)] = {"sem": 1.0 / (1.0 + float(dist)), "bm": 0.0}
        for i in top_rel:
            i = int(i)
            candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = float(bm25_scores[i])
        return candidate_scores

    @staticmethod
    def _min_max_scale(x: np.ndarray) -> np.ndarray:
        """Scale *x* to [0, 1]; an array with no spread maps to zeros."""
        rng = x.max() - x.min()
        return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x)

    def _expand_query_with_llm(self, model_name: str, query: str, intents: tuple) -> str:
        """Expand *query* with intent-related synonyms via the LLM; unchanged when no intents."""
        if not intents:
            return query
        prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query)
        expanded = self._llm_call(model_name, [{"role": "user", "content": prompt}])
        return expanded.strip() or query

    def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
        """Stable-partition *results* so sentences from intent-related sections come first."""
        if not intents:
            return results
        prioritized_sections = {section for intent in intents for section in INTENT_TO_SECTION.get(intent, [])}
        if not prioritized_sections:
            return results
        prioritized, others = [], []
        for res in results:
            (prioritized if res.meta.get("section", "") in prioritized_sections else others).append(res)
        return prioritized + others

    def _build_context(self, reranked_results: List[RerankResult]) -> str:
        """Join sentence texts with blank lines, stopping before max_context_chars is exceeded."""
        max_chars = LLM_MODEL_CONFIG["max_context_chars"]
        parts: List[str] = []
        length = 0  # length of the context including "\n\n" separators
        for res in reranked_results:
            if length + len(res.text) > max_chars:
                break
            parts.append(res.text)
            length += len(res.text) + 2
        return "\n\n".join(parts).strip()

    def _generate_answers(self, model_name: str, query: str, context: str, intents: List[str],
                          template_prefix: str) -> Dict[str, str]:
        """Generate the concise answer using the ``{template_prefix}_concise`` prompt."""
        additional_instruction = f"重點關注以下意圖: {', '.join(intents)}" if intents else ""
        max_tokens_concise = 512
        prompt_concise = PROMPT_TEMPLATES[f"{template_prefix}_concise"].format(
            additional_instruction=additional_instruction, context=context, query=query)
        concise = self._llm_call(model_name, [{"role": "user", "content": prompt_concise}],
                                 max_tokens=max_tokens_concise)
        # Post-process: drop '*' markdown markers so the text reads cleanly in LINE.
        concise = concise.replace('*', '')
        return {"concise": concise + f"\n{DISCLAIMER}"}

    def _safe_json_parse(self, s: str, default: Any = None) -> Any:
        """Parse JSON from an LLM reply, tolerating code fences and surrounding prose.

        Returns *default* when no JSON object can be decoded.
        """
        if not isinstance(s, str):
            return default
        text = s.strip()
        fenced = _JSON_FENCE_RE.search(text)
        if fenced:
            text = fenced.group(1).strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            try:
                return json.loads(text[start:end + 1])
            except json.JSONDecodeError:
                pass
        return default

    # [MODIFIED] Return a list of strings (one per model) instead of a single giant string.
    def _format_responses(self, results: Dict) -> List[str]:
        """Render answer_question() results as one display string per model."""
        formatted_responses = []
        for model, res in results.items():
            if "error" in res or "clarification" in res:
                formatted_responses.append(f"[{model}]:\n{res.get('clarification', res.get('error'))}")
                continue
            response_lines = []
            for rag_type, answers in res.items():
                rag_label = "with RAG" if rag_type == "with_rag" else "without RAG"
                response_lines.append(f"[{model} - {rag_label}]:\n{answers['concise']}")
            formatted_responses.append("\n".join(response_lines))
        return formatted_responses


# [NEW FUNCTION] Helper to split a long message into chunks for the LINE API.
def split_text_for_line(text: str, max_length: int = 4800) -> List[str]:
    """Split *text* into chunks of at most *max_length* characters.

    Prefers to split at the last newline inside the window; otherwise splits
    hard at *max_length*. Leading whitespace of each following chunk is
    stripped, and whitespace-only chunks are dropped because LINE rejects
    empty text messages.
    """
    if len(text) <= max_length:
        return [text]

    chunks = []
    remaining = text
    while remaining:
        if len(remaining) <= max_length:
            chunks.append(remaining)
            break
        split_pos = remaining.rfind('\n', 0, max_length)
        if split_pos <= 0:
            # No usable newline (none at all, or only at index 0, which would yield an empty chunk).
            split_pos = max_length
        chunk = remaining[:split_pos]
        if chunk.strip():
            chunks.append(chunk)
        remaining = remaining[split_pos:].lstrip()
    return chunks


# ---------- FastAPI application setup ----------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG pipeline and load its data once at application startup."""
    log.info("應用程式啟動...")
    global rag_pipeline
    rag_pipeline = RagPipeline()
    rag_pipeline.load_data()
    yield
    log.info("應用程式關閉...")


app = FastAPI(lifespan=lifespan)
rag_pipeline: Optional[RagPipeline] = None


# ---------- LINE webhook handling ----------
def _send_line_reply(reply_token: str, messages: List[Dict[str, str]]) -> None:
    """POST *messages* to the LINE reply endpoint; raises on HTTP errors or timeout."""
    channel_access_token = _require_env("CHANNEL_ACCESS_TOKEN")
    headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {channel_access_token}'}
    data = {'replyToken': reply_token, 'messages': messages}
    response = requests.post(LINE_REPLY_URL, headers=headers, json=data, timeout=LINE_REQUEST_TIMEOUT)
    response.raise_for_status()


def _handle_text_event(user_text: str, reply_token: str) -> None:
    """Answer one text message and reply on LINE (blocking; run it off the event loop).

    Any failure is logged and answered with a generic error message (best effort).
    """
    try:
        results = rag_pipeline.answer_question(user_text)
        # [MODIFIED] Handle the LINE message length limit by splitting responses.
        model_responses = rag_pipeline._format_responses(results)
        final_messages = [
            {'type': 'text', 'text': chunk}
            for res_text in model_responses
            for chunk in split_text_for_line(res_text)
        ]
        if not final_messages:
            # Fallback message if no response is generated.
            messages_to_send = [{'type': 'text', 'text': '抱歉,目前無法處理您的請求。'}]
        else:
            if len(final_messages) > LINE_MAX_MESSAGES_PER_REPLY:
                log.warning("Reply has %d messages; LINE accepts at most %d, the rest are dropped.",
                            len(final_messages), LINE_MAX_MESSAGES_PER_REPLY)
            messages_to_send = final_messages[:LINE_MAX_MESSAGES_PER_REPLY]
        _send_line_reply(reply_token, messages_to_send)
    except Exception as e:
        log.error(f"Error processing request: {e}", exc_info=True)
        # Send an error message back to the user.
        try:
            error_message = "抱歉,系統發生錯誤,請稍後再試。"
            _send_line_reply(reply_token, [{'type': 'text', 'text': error_message}])
        except Exception as line_e:
            log.error(f"Failed to send error message to LINE: {line_e}")


@app.post("/webhook")
async def line_webhook(request: Request):
    """Verify the LINE signature and answer every text-message event.

    Raises:
        HTTPException(400): on an invalid signature or a non-JSON body.
    """
    signature = request.headers.get('X-Line-Signature', '')
    body = await request.body()
    channel_secret = _require_env("CHANNEL_SECRET")
    hash_ = hmac.new(channel_secret.encode('utf-8'), body, hashlib.sha256).digest()
    expected_signature = base64.b64encode(hash_).decode()
    if not hmac.compare_digest(signature, expected_signature):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid signature")

    try:
        payload = json.loads(body.decode('utf-8'))
    except (UnicodeDecodeError, json.JSONDecodeError):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid payload")

    for event in payload.get('events', []):
        message = event.get('message') or {}
        if event.get('type') == 'message' and message.get('type') == 'text':
            # The pipeline makes many blocking LLM/HTTP calls; run it in a worker
            # thread so the event loop keeps serving other requests.
            await run_in_threadpool(_handle_text_event, message.get('text', ''), event.get('replyToken'))
    return {"status": "ok"}


if __name__ == "__main__":
    _require_env("CHANNEL_SECRET")
    _require_env("CHANNEL_ACCESS_TOKEN")
    uvicorn.run(app, host="0.0.0.0", port=7860)