#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DrugQA (ZH): optimized FastAPI LINE webhook (final version).
Integrates the RAG logic: LLM intent detection, sub-query decomposition,
intent-aware retrieval, and rerank. This version focuses on performance,
maintainability, robustness, and user experience.
"""
# ---------- Environment and cache setup (must run before other imports) ----------
import os
import pathlib

os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.getenv("XDG_CACHE_HOME")):
    pathlib.Path(d).mkdir(parents=True, exist_ok=True)
# ---------- Python standard library ----------
import re
import hmac
import base64
import hashlib
import pickle
import logging
import json
import textwrap
import time
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from contextlib import asynccontextmanager
import unicodedata
from collections import defaultdict

# ---------- Third-party libraries ----------
import numpy as np
import pandas as pd
from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
import uvicorn
import jieba
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import faiss
import torch
from openai import OpenAI
import requests
from transformers import pipeline

# [MODIFIED] Cap the number of PyTorch threads to avoid hogging CPU resources.
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
# ==== CONFIG (loaded from environment variables, with defaults) ====
# [MODIFIED] Health check for required environment variables.
def _require_env(var: str) -> str:
    v = os.getenv(var)
    if not v:
        raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
    return v

# [MODIFIED] Check the LLM-related environment variables.
def _require_llm_config():
    for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
        _require_env(k)
    # The MedGemma model name is hard-coded below; add a check here if it ever
    # becomes configurable (e.g. _require_env("MEDGEMMA_MODEL_NAME")).
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
LLM_API_CONFIG = {
    "base_url": os.getenv("LITELLM_BASE_URL"),
    "api_key": os.getenv("LITELLM_API_KEY"),
    "model": os.getenv("LM_MODEL")
}
MEDGEMMA_MODEL_NAME = "google/medgemma-4b-it"  # hard-coded MedGemma model name
LLM_MODEL_CONFIG = {
    "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
    "max_tokens": int(os.getenv("MAX_TOKENS", 1024)),
    "temperature": float(os.getenv("TEMPERATURE", 0.0)),
    "seed": int(os.getenv("LLM_SEED", 42)),  # seed for reproducibility
}
INTENT_CATEGORIES = [
    "操作 (Administration)", "保存/攜帶 (Storage & Handling)", "副作用/異常 (Side Effects / Issues)",
    "劑型相關 (Dosage Form Concerns)", "時間/併用 (Timing & Interaction)", "劑量調整 (Dosage Adjustment)",
    "禁忌症/適應症 (Contraindications/Indications)"
]
# [ADDED] Intent category → CSV section lookup table.
INTENT_TO_SECTION = {
    "操作 (Administration)": ["用法用量", "病人使用須知"],
    "保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
    "副作用/異常 (Side Effects / Issues)": ["不良反應", "警語與注意事項"],
    "劑型相關 (Dosage Form Concerns)": ["劑型", "藥品外觀"],
    "時間/併用 (Timing & Interaction)": ["用法用量"],
    "劑量調整 (Dosage Adjustment)": ["用法用量"],
    "禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"]
}
# REFERENCE_MAPPING: canned questions → the sections that answer them.
REFERENCE_MAPPING = {
    "如何用藥?": "病人使用須知、用法用量",
    "如何保存與攜帶?": "包裝及儲存",
    "可能的副作用?": "警語與注意事項、不良反應",
    "每次劑量多少?": "用法用量、藥袋上的醫囑",
    "用藥時間?": "用法用量、藥袋上的醫囑",
}
# Reverse mapping: section → intents.
SECTION_TO_INTENT = defaultdict(list)
for intent, sections in INTENT_TO_SECTION.items():
    for section in sections:
        SECTION_TO_INTENT[section].append(intent)
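# e.g. SECTION_TO_INTENT["用法用量"] collects every intent that lists that section:
# ["操作 (Administration)", "時間/併用 (Timing & Interaction)", "劑量調整 (Dosage Adjustment)"].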
DRUG_NAME_MAPPING = {
    "fentanyl patch": "fentanyl", "spiriva respimat": "spiriva", "augmentin for syrup": "augmentin syrup",
    "nitrostat": "nitroglycerin", "ozempic": "ozempic", "niflec": "niflec",
    "fosamax": "fosamax", "humira": "humira", "premarin": "premarin", "smecta": "smecta",
}
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
PROMPT_TEMPLATES = {
    "analyze_query": """
請分析以下使用者問題,並完成以下兩個任務:
1. 將問題分解為1-3個核心的子問題。
2. 從清單中選擇所有相關的意圖分類。
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intents' (字串陣列) 兩個鍵。
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"]}}
意圖分類清單:
{options}
使用者問題:{query}
""",
    "expand_query": """
請根據以下意圖:{intents},擴展這個查詢,加入相關同義詞或術語。
原始查詢:{query}
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。
""",
    "final_answer_concise": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
一、 回覆準則
嚴格依據資料: 所有回覆內容都必須完全來自提供的參考資料,禁止任何形式的捏造或引用外部資訊。
資料不足: 若參考資料無法回答使用者的問題,請直接回覆:「根據提供的資料,無法回答您的問題。」
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需極其簡潔,資訊完整,不要使用*符號,字數請嚴格控制在60字以內。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
參考資料: 普拿疼成人建議劑量為一次1至2錠,每4至6小時服用一次,每日不超過8錠。
AI回覆範例:
普拿疼成人劑量:
1-2錠/次
每4-6小時
每日≤8錠
如有不適請立即就醫。
{additional_instruction}
---
參考資料:
{context}
---
使用者問題:{query}
請直接輸出最終的答案:
""",
    "final_answer_detailed": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
一、 回覆準則
嚴格依據資料: 所有回覆內容都必須完全來自提供的參考資料,禁止任何形式的捏造或引用外部資訊。
資料不足: 若參考資料無法回答使用者的問題,請直接回覆:「根據提供的資料,無法回答您的問題。」
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需簡潔但資訊完整,不要使用*符號,字數請控制在200字以內,提供更多細節解釋。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
參考資料: 普拿疼成人建議劑量為一次1至2錠,每4至6小時服用一次,每日不超過8錠。
AI回覆範例:
普拿疼成人建議劑量為:
- 一次服用1至2錠,視疼痛程度調整。
- 每4至6小時服用一次,避免過頻。
- 每日總量不超過8錠,以防副作用。
如有不適請立即就醫。
{additional_instruction}
---
參考資料:
{context}
---
使用者問題:{query}
請直接輸出最終的答案:
""",
    "direct_answer_concise": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,直接基於您的知識給予回覆:
一、 回覆準則
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需極其簡潔,資訊完整,不要使用*符號,字數請嚴格控制在60字以內。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
AI回覆範例:
普拿疼成人劑量:
1-2錠/次
每4-6小時
每日≤8錠
如有不適請立即就醫。
{additional_instruction}
使用者問題:{query}
請直接輸出最終的答案:
""",
    "direct_answer_detailed": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,直接基於您的知識給予回覆:
一、 回覆準則
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需簡潔但資訊完整,不要使用*符號,字數請控制在200字以內,提供更多細節解釋。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
AI回覆範例:
普拿疼成人建議劑量為:
- 一次服用1至2錠,視疼痛程度調整。
- 每4至6小時服用一次,避免過頻。
- 每日總量不超過8錠,以防副作用。
如有不適請立即就醫。
{additional_instruction}
使用者問題:{query}
請直接輸出最終的答案:
"""
}
# ---------- Logging setup ----------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

# [ADDED] Unified string normalization helper.
def _norm(s: str) -> str:
    """Normalize a string: NFKC normalization, lowercasing, punctuation removal, and surrounding-whitespace trim."""
    s = unicodedata.normalize("NFKC", s)
    return re.sub(r"[^\w\s]", "", s.lower()).strip()
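# e.g. _norm("Spiriva Respimat!") -> "spiriva respimat"; NFKC also folds
# full-width forms, so _norm("Ｏｚｅｍｐｉｃ") -> "ozempic".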
@dataclass
class FusedCandidate:
    idx: int
    fused_score: float
    sem_score: float
    bm_score: float

@dataclass
class RerankResult:
    idx: int
    rerank_score: float
    text: str
    meta: Dict[str, Any] = field(default_factory=dict)
# ---------- Core RAG logic ----------
class RagPipeline:
    def __init__(self):
        if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
            raise ValueError("LLM API Key or Base URL is not configured.")
        # OpenAI client pointed at LiteLLM
        self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
        self.litellm_model_name = LLM_API_CONFIG["model"]
        # MedGemma pipeline
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 MedGemma 模型: {MEDGEMMA_MODEL_NAME} 至 {device}...")
        try:
            self.medgemma_pipe = pipeline(
                "text-generation",
                model=MEDGEMMA_MODEL_NAME,
                torch_dtype=torch.bfloat16,
                device=device,
            )
            log.info("MedGemma 模型載入成功。")
        except Exception as e:
            log.error(f"載入 MedGemma 模型失敗: {e}")
            raise RuntimeError(f"MedGemma 模型載入失敗: {MEDGEMMA_MODEL_NAME}") from e
        self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
        self.drug_name_to_ids: Dict[str, List[str]] = {}
        self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
        self.state = type('state', (), {})()  # simple attribute container for shared retrieval state
        # [ADDED] Prompt template for clarification questions.
        self.CLARIFICATION_PROMPT = """
請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精確地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境等。
範例:
使用者問題:這個藥會怎麼樣?
澄清提問:您好,請問您指的是哪一種藥物呢?
使用者問題:請問這要吃多久?
澄清提問:請問您是想了解該藥品的建議療程長度嗎?
使用者問題:{query}
澄清提問:"""
    def _load_model(self, model_class, model_name: str, model_type: str):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 {model_type} 模型:{model_name} 至 {device}...")
        try:
            return model_class(model_name, device=device)
        except Exception as e:
            log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
            try:
                return model_class(model_name, device="cpu")
            except Exception as e_cpu:
                log.error(f"切換至 CPU 仍無法載入模型: {model_name}。請確認模型路徑或網路連線。錯誤訊息: {e_cpu}")
                raise RuntimeError(f"模型載入失敗: {model_name}") from e_cpu
    def load_data(self):
        log.info("開始載入資料與模型...")
        for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
            if not pathlib.Path(path).exists():
                raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
        try:
            self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
            # drug_name_zh / drug_name_en are also read when building the vocabulary below.
            for col in ("drug_name_norm", "drug_id", "drug_name_zh", "drug_name_en"):
                if col not in self.df_csv.columns:
                    raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}")
            self.drug_name_to_ids = self._build_drug_name_to_ids()
            self._load_drug_name_vocabulary()
            log.info("載入 FAISS 索引與句子資料...")
            self.state.index = faiss.read_index(FAISS_INDEX)
            self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
            if hasattr(self.state.index, "nprobe"):
                self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
            if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
                log.info("FAISS 索引使用內積 (IP) 指標,檢索時將自動進行 L2 正規化以實現餘弦相似度。")
            with open(SENTENCES_PKL, "rb") as f:
                data = pickle.load(f)
            self.state.sentences = data["sentences"]
            self.state.meta = data["meta"]
            log.info("載入 BM25 索引...")
            with open(BM25_PKL, "rb") as f:
                bm25_data = pickle.load(f)
            self.state.bm25 = bm25_data["bm25"]
            if not isinstance(self.state.bm25, BM25Okapi):
                raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
        except (FileNotFoundError, KeyError) as e:
            log.exception(f"資料或索引檔案載入失敗: {e}")
            raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}") from e
        log.info("所有模型與資料載入完成。")
    # [ADDED] Map a drug_id back to its user-facing drug_name_norm.
    def _get_drug_name_by_id(self, drug_id: str) -> Optional[str]:
        """Look up the drug_name_norm for a given drug_id."""
        row = self.df_csv[self.df_csv['drug_id'] == drug_id]
        if not row.empty:
            return row.iloc[0]['drug_name_norm']
        return None

    def _find_drug_ids_from_name(self, query: str) -> List[str]:
        q_norm_parts = set(re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(query)))
        drug_ids = set()
        for part in q_norm_parts:
            if part in self.drug_name_to_ids:
                drug_ids.update(self.drug_name_to_ids[part])
        # [ADDED] Also match known drug names as substrings of the query, to
        # support reverse lookup from drug_name_norm for context handling.
        for drug_name, ids in self.drug_name_to_ids.items():
            if drug_name in _norm(query):
                drug_ids.update(ids)
        return sorted(drug_ids)
    def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
        mapping = {}
        for _, row in self.df_csv.iterrows():
            drug_id = row['drug_id']
            zh_parts = list(jieba.cut(row['drug_name_zh']))
            en_parts = re.findall(r'[a-zA-Z0-9]+', row['drug_name_en'].lower() if row['drug_name_en'] else '')
            norm_parts = re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(row['drug_name_norm']))
            all_parts = set(zh_parts + en_parts + norm_parts)
            for part in all_parts:
                part = part.strip()
                if part and len(part) > 1:
                    mapping.setdefault(part, []).append(drug_id)
            for alias, canonical_name in DRUG_NAME_MAPPING.items():
                if _norm(canonical_name) in _norm(row['drug_name_norm']):
                    mapping.setdefault(_norm(alias), []).append(drug_id)
        for key in mapping:
            mapping[key] = sorted(set(mapping[key]))
        return mapping

    def _load_drug_name_vocabulary(self):
        log.info("建立藥名詞庫...")
        for _, row in self.df_csv.iterrows():
            norm_name = row['drug_name_norm']
            words = re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', norm_name)
            for word in words:
                if re.search(r'[\u4e00-\u9fff]', word):
                    self.drug_vocab["zh"].add(word)
                else:
                    self.drug_vocab["en"].add(word)
        for alias in DRUG_NAME_MAPPING:
            if re.search(r'[\u4e00-\u9fff]', alias):
                self.drug_vocab["zh"].add(alias)
            else:
                self.drug_vocab["en"].add(alias)
        for word in self.drug_vocab["zh"]:
            try:
                if word not in jieba.dt.FREQ:
                    jieba.add_word(word, freq=2_000_000)
            except Exception:
                pass
    # [MODIFIED] Two separate LLM call helpers, used to compare outputs.
    def _litellm_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None, seed: Optional[int] = None) -> str:
        """Call the LiteLLM API safely, handling empty response content."""
        log.info(f"LITELLM 呼叫開始. 模型: {self.litellm_model_name}, max_tokens: {max_tokens}, temperature: {temperature}, seed: {seed}")
        start_time = time.time()
        try:
            response = self.llm_client.chat.completions.create(
                model=self.litellm_model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                seed=seed,
            )
            end_time = time.time()
            log.info(f"LITELLM 收到完整回應: {response.model_dump_json(indent=2)}")
            if not response.choices or not response.choices[0].message.content:
                log.warning("LITELLM 呼叫成功 (200 OK),但回傳內容為空。將回傳空字串。")
                return ""
            content = response.choices[0].message.content
            log.info(f"LITELLM 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。")
            return content
        except Exception as e:
            log.error(f"LITELLM API 呼叫失敗: {e}")
            raise

    def _medgemma_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None, seed: Optional[int] = None) -> str:
        """Call the local MedGemma pipeline safely, handling empty response content."""
        log.info(f"MedGemma 呼叫開始. max_tokens: {max_tokens}, temperature: {temperature}, seed: {seed}")
        start_time = time.time()
        try:
            # The MedGemma pipeline expects chat messages whose content is a list of typed parts.
            converted_messages = []
            for msg in messages:
                role = msg["role"]
                content = [{"type": "text", "text": msg["content"]}]
                converted_messages.append({"role": role, "content": content})
            output = self.medgemma_pipe(
                converted_messages,
                max_new_tokens=max_tokens or LLM_MODEL_CONFIG["max_tokens"],
                temperature=temperature if temperature is not None else LLM_MODEL_CONFIG["temperature"],
                # The pipeline may not support a seed directly, so it is not passed through here.
            )
            end_time = time.time()
            if not output or not output[0]["generated_text"] or not output[0]["generated_text"][-1]["content"]:
                log.warning("MedGemma 呼叫成功,但回傳內容為空。將回傳空字串。")
                return ""
            content = output[0]["generated_text"][-1]["content"]
            log.info(f"MedGemma 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。")
            return content
        except Exception as e:
            log.error(f"MedGemma 呼叫失敗: {e}")
            raise
    # [MODIFIED] answer_question returns four answer sets (with/without RAG, per
    # backend), each with a concise and a detailed variant; clarification and
    # error branches return plain strings, hence the Dict[str, Any] annotation.
    def answer_question(self, q_orig: str) -> Tuple[Dict[str, Any], List[str]]:
        start_time = time.time()
        log.info(f"===== 處理新查詢: '{q_orig}' =====")
        try:
            drug_ids = self._find_drug_ids_from_name(q_orig)
            if not drug_ids:
                log.info("未從查詢中找到相關藥名,透過兩種 LLM 生成澄清性問題。")
                clarification_litellm = self._generate_clarification_query_litellm(q_orig)
                clarification_medgemma = self._generate_clarification_query_medgemma(q_orig)
                log.info(f"澄清問題比較: LITELLM: '{clarification_litellm}', MedGemma: '{clarification_medgemma}'")
                return {"clarification_litellm": clarification_litellm + f"\n{DISCLAIMER}", "clarification_medgemma": clarification_medgemma + f"\n{DISCLAIMER}"}, []
            log.info(f"步驟 1/4: 找到藥品 ID: {drug_ids},耗時: {time.time() - start_time:.2f} 秒")
            step_start = time.time()
            # Run the query analysis separately for each backend.
            analysis_litellm = self._analyze_query_litellm(q_orig)
            analysis_medgemma = self._analyze_query_medgemma(q_orig)
            sub_queries_litellm, intents_litellm = analysis_litellm.get("sub_queries", [q_orig]), analysis_litellm.get("intents", [])
            sub_queries_medgemma, intents_medgemma = analysis_medgemma.get("sub_queries", [q_orig]), analysis_medgemma.get("intents", [])
            log.info(f"意圖分析比較: LITELLM: {analysis_litellm}, MedGemma: {analysis_medgemma}")
            if not intents_litellm and not intents_medgemma:
                log.info("意圖分析失敗,透過兩種 LLM 生成澄清性問題。")
                clarification_litellm = self._generate_clarification_query_litellm(q_orig)
                clarification_medgemma = self._generate_clarification_query_medgemma(q_orig)
                log.info(f"澄清問題比較: LITELLM: '{clarification_litellm}', MedGemma: '{clarification_medgemma}'")
                return {"clarification_litellm": clarification_litellm + f"\n{DISCLAIMER}", "clarification_medgemma": clarification_medgemma + f"\n{DISCLAIMER}"}, drug_ids
            log.info(f"步驟 2/4: 意圖分析完成。LITELLM 子問題: {sub_queries_litellm}, 意圖: {intents_litellm}。MedGemma 子問題: {sub_queries_medgemma}, 意圖: {intents_medgemma}。耗時: {time.time() - step_start:.2f} 秒")
            step_start = time.time()
            # Run retrieval separately on each backend's sub-queries and intents.
            all_candidates_litellm = self._retrieve_candidates_for_all_queries_litellm(drug_ids, sub_queries_litellm, intents_litellm)
            all_candidates_medgemma = self._retrieve_candidates_for_all_queries_medgemma(drug_ids, sub_queries_medgemma, intents_medgemma)
            log.info(f"步驟 3/4: 檢索完成。LITELLM 找到 {len(all_candidates_litellm)} 個, MedGemma 找到 {len(all_candidates_medgemma)} 個。耗時: {time.time() - step_start:.2f} 秒")
            step_start = time.time()
            # LiteLLM RAG: take top candidates, prioritize by intent, build context.
            final_candidates_litellm = all_candidates_litellm[:TOP_K_SENTENCES]
            reranked_results_litellm = [
                RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
                for c in final_candidates_litellm
            ]
            prioritized_results_litellm = self._prioritize_context(reranked_results_litellm, intents_litellm)
            context_litellm = self._build_context(prioritized_results_litellm)
            # MedGemma RAG: same steps on its own candidate list.
            final_candidates_medgemma = all_candidates_medgemma[:TOP_K_SENTENCES]
            reranked_results_medgemma = [
                RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
                for c in final_candidates_medgemma
            ]
            prioritized_results_medgemma = self._prioritize_context(reranked_results_medgemma, intents_medgemma)
            context_medgemma = self._build_context(prioritized_results_medgemma)
            log.info(f"步驟 4/4: 最終選出 LITELLM {len(reranked_results_litellm)} 個, MedGemma {len(reranked_results_medgemma)} 個。耗時: {time.time() - step_start:.2f} 秒")
            step_start = time.time()
            if not context_litellm and not context_medgemma:
                log.info("沒有足夠的上下文來回答問題。")
                return {"error": f"根據提供的資料,無法回答您的問題。\n{DISCLAIMER}"}, drug_ids
            # Shared generation parameters: temperature=0.0, seed=42 for reproducibility.
            temp = LLM_MODEL_CONFIG["temperature"]
            seed = LLM_MODEL_CONFIG["seed"]
            max_tokens_concise = 256  # tighter budget for the concise variant
            max_tokens_detailed = LLM_MODEL_CONFIG["max_tokens"]  # detailed variant uses the default
            # Generate answers with RAG context.
            answers_with_rag = self._generate_answers_with_rag(q_orig, context_litellm, intents_litellm, context_medgemma, intents_medgemma, temp, seed, max_tokens_concise, max_tokens_detailed)
            # Generate answers without RAG (query only).
            answers_without_rag = self._generate_answers_without_rag(q_orig, intents_litellm, intents_medgemma, temp, seed, max_tokens_concise, max_tokens_detailed)
            log.info(f"答案生成完成。耗時: {time.time() - step_start:.2f} 秒")
            log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
            # Merge the four answer sets and return.
            return {
                "LITELLM_with_RAG": answers_with_rag["LITELLM"],
                "MedGemma_with_RAG": answers_with_rag["MedGemma"],
                "LITELLM_without_RAG": answers_without_rag["LITELLM"],
                "MedGemma_without_RAG": answers_without_rag["MedGemma"]
            }, drug_ids
        except Exception as e:
            log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True)
            return {"error": f"處理您的問題時發生內部錯誤,請稍後再試。\n{DISCLAIMER}"}, []
    def _analyze_query_litellm(self, query: str) -> Dict[str, Any]:
        return self._analyze_query_generic(query, self._litellm_call)

    def _analyze_query_medgemma(self, query: str) -> Dict[str, Any]:
        return self._analyze_query_generic(query, self._medgemma_call)

    def _analyze_query_generic(self, query: str, llm_call) -> Dict[str, Any]:
        # Check REFERENCE_MAPPING first: canned questions map straight to intents.
        norm_query = _norm(query)
        for ref_key, ref_value in REFERENCE_MAPPING.items():
            if _norm(ref_key) in norm_query:
                # REFERENCE_MAPPING values are joined with the Chinese enumeration
                # comma "、", so split on that (splitting on "," never matched).
                sections = [s.strip() for s in ref_value.split('、')]
                intents = []
                for section in sections:
                    intents.extend(SECTION_TO_INTENT.get(section, []))
                intents = list(set(intents))  # deduplicate
                if intents:
                    log.info(f"匹配 REFERENCE_MAPPING: '{ref_key}' -> intents: {intents}")
                    return {"sub_queries": [query], "intents": intents}
        # Fall back to LLM intent detection only when nothing matches.
        prompt = PROMPT_TEMPLATES["analyze_query"].format(
            options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
            query=query
        )
        response_str = llm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"])
        return self._safe_json_parse(response_str, default={"sub_queries": [query], "intents": []})

    def _generate_clarification_query_litellm(self, query: str) -> str:
        prompt = self.CLARIFICATION_PROMPT.format(query=query)
        return self._litellm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]).strip()

    def _generate_clarification_query_medgemma(self, query: str) -> str:
        prompt = self.CLARIFICATION_PROMPT.format(query=query)
        return self._medgemma_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]).strip()
    def _retrieve_candidates_for_all_queries_litellm(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
        return self._retrieve_candidates_for_all_queries_generic(drug_ids, sub_queries, intents, self._expand_query_with_llm_litellm)

    def _retrieve_candidates_for_all_queries_medgemma(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
        return self._retrieve_candidates_for_all_queries_generic(drug_ids, sub_queries, intents, self._expand_query_with_llm_medgemma)

    def _retrieve_candidates_for_all_queries_generic(self, drug_ids: List[str], sub_queries: List[str], intents: List[str], expand_func) -> List[FusedCandidate]:
        # Restrict retrieval to sentences belonging to the identified drugs.
        drug_ids_set = set(map(str, drug_ids))
        if drug_ids_set:
            relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
        else:
            relevant_indices = set(range(len(self.state.meta)))
        if not relevant_indices:
            return []
        all_fused_candidates: Dict[int, FusedCandidate] = {}
        for sub_q in sub_queries:
            expanded_q = expand_func(sub_q, tuple(intents))
            q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
            if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
                faiss.normalize_L2(q_emb)
            distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
            tokenized_query = list(jieba.cut(expanded_q))
            bm25_scores = self.state.bm25.get_scores(tokenized_query)
            rel_idx = np.fromiter(relevant_indices, dtype=int)
            rel_scores = bm25_scores[rel_idx]
            top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
            doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_rel}
            candidate_scores: Dict[int, Dict[str, float]] = {}

            def to_similarity(d: float) -> float:
                # IP indexes already return similarities; convert L2 distance otherwise.
                if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
                    return float(d)
                return 1.0 / (1.0 + float(d))

            for i, dist in zip(sim_indices[0], distances[0]):
                if i in relevant_indices:
                    candidate_scores[int(i)] = {"sem": to_similarity(dist), "bm": 0.0}
            for i, score in doc_to_bm25_score.items():
                if i in relevant_indices:
                    candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
            if not candidate_scores:
                continue
            keys = list(candidate_scores.keys())
            sem_scores = np.array([candidate_scores[k]['sem'] for k in keys])
            bm_scores = np.array([candidate_scores[k]['bm'] for k in keys])

            def norm(x):
                # Min-max normalize per sub-query so the two score scales are comparable.
                rng = x.max() - x.min()
                return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x)

            sem_n, bm_n = norm(sem_scores), norm(bm_scores)
            for idx, k in enumerate(keys):
                # Weighted fusion: 60% semantic similarity, 40% BM25.
                fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
                if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
                    all_fused_candidates[k] = FusedCandidate(
                        idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
                    )
        return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)

    def _expand_query_with_llm_litellm(self, query: str, intents: tuple) -> str:
        if not intents:
            return query
        prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query)
        try:
            expanded = self._litellm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"])
            log.info(f"LITELLM 查詢擴展成功。原始: '{query}', 擴展後: '{expanded}'")
            return expanded.strip() if expanded.strip() else query
        except Exception as e:
            log.error(f"LITELLM 查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。")
            return query

    def _expand_query_with_llm_medgemma(self, query: str, intents: tuple) -> str:
        if not intents:
            return query
        prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query)
        try:
            expanded = self._medgemma_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"])
            log.info(f"MedGemma 查詢擴展成功。原始: '{query}', 擴展後: '{expanded}'")
            return expanded.strip() if expanded.strip() else query
        except Exception as e:
            log.error(f"MedGemma 查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。")
            return query
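    # [ADDED] Hedged sketches of the two answer-generation helpers that
    # answer_question calls but this file never defines. Minimal assumption:
    # each backend produces a concise and a detailed answer from the shared
    # PROMPT_TEMPLATES, and additional_instruction is left empty.
    def _generate_answers_with_rag(self, query: str, context_litellm: str, intents_litellm: List[str],
                                   context_medgemma: str, intents_medgemma: List[str], temp: float, seed: int,
                                   max_tokens_concise: int, max_tokens_detailed: int) -> Dict[str, Dict[str, str]]:
        # intents_* are accepted only for call-site compatibility in this sketch.
        def ask(call, template_key: str, context: str, max_tokens: int) -> str:
            prompt = PROMPT_TEMPLATES[template_key].format(additional_instruction="", context=context, query=query)
            return call([{"role": "user", "content": prompt}], max_tokens=max_tokens, temperature=temp, seed=seed).strip()
        return {
            "LITELLM": {
                "concise": ask(self._litellm_call, "final_answer_concise", context_litellm, max_tokens_concise),
                "detailed": ask(self._litellm_call, "final_answer_detailed", context_litellm, max_tokens_detailed),
            },
            "MedGemma": {
                "concise": ask(self._medgemma_call, "final_answer_concise", context_medgemma, max_tokens_concise),
                "detailed": ask(self._medgemma_call, "final_answer_detailed", context_medgemma, max_tokens_detailed),
            },
        }

    def _generate_answers_without_rag(self, query: str, intents_litellm: List[str], intents_medgemma: List[str],
                                      temp: float, seed: int, max_tokens_concise: int, max_tokens_detailed: int) -> Dict[str, Dict[str, str]]:
        def ask(call, template_key: str, max_tokens: int) -> str:
            prompt = PROMPT_TEMPLATES[template_key].format(additional_instruction="", query=query)
            return call([{"role": "user", "content": prompt}], max_tokens=max_tokens, temperature=temp, seed=seed).strip()
        return {
            "LITELLM": {
                "concise": ask(self._litellm_call, "direct_answer_concise", max_tokens_concise),
                "detailed": ask(self._litellm_call, "direct_answer_detailed", max_tokens_detailed),
            },
            "MedGemma": {
                "concise": ask(self._medgemma_call, "direct_answer_concise", max_tokens_concise),
                "detailed": ask(self._medgemma_call, "direct_answer_detailed", max_tokens_detailed),
            },
        }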
    def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
        if not intents:
            return results
        prioritized_sections = set()
        for intent in intents:
            prioritized_sections.update(INTENT_TO_SECTION.get(intent, []))
        if not prioritized_sections:
            return results
        log.info(f"根據意圖 '{intents}' 優先處理章節: {prioritized_sections}")
        prioritized_results = []
        other_results = []
        for res in results:
            section = res.meta.get("section", "")
            if section in prioritized_sections:
                prioritized_results.append(res)
            else:
                other_results.append(res)
        return prioritized_results + other_results

    def _build_context(self, reranked_results: List[RerankResult]) -> str:
        context = ""
        for res in reranked_results:
            if len(context) + len(res.text) > LLM_MODEL_CONFIG["max_context_chars"]:
                break
            context += res.text + "\n\n"
        return context.strip()

    def _safe_json_parse(self, s: str, default: Any = None) -> Any:
        try:
            return json.loads(s)
        except json.JSONDecodeError:
            log.warning(f"無法解析完整 JSON。嘗試從字串中提取: {s[:200]}...")
            m = re.search(r'\{.*?\}', s, re.DOTALL)
            if m:
                try:
                    return json.loads(m.group(0))
                except json.JSONDecodeError:
                    log.warning(f"提取的 JSON 仍無法解析: {m.group(0)[:100]}...")
            return default
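    # [ADDED] Hedged sketch of the background task that handle_webhook schedules;
    # the original file references process_user_query without defining it. This
    # version assumes the comparison answers are flattened into a single push.
    def process_user_query(self, source_type: str, target_id: str, user_text: str):
        answers, _drug_ids = self.answer_question(user_text)
        parts = []
        for name, ans in answers.items():
            if isinstance(ans, dict):  # {"concise": ..., "detailed": ...} per backend
                for variant, text in ans.items():
                    parts.append(f"[{name} | {variant}]\n{text}")
            else:  # clarification / error strings
                parts.append(str(ans))
        line_push_generic(source_type, target_id, "\n\n".join(parts) if parts else DISCLAIMER)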
# ---------- FastAPI lifecycle and routes ----------
class AppConfig:
    CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
    CHANNEL_SECRET = _require_env("CHANNEL_SECRET")

rag_pipeline: Optional[RagPipeline] = None
# [MODIFIED] Global per-user state cache, keeping richer conversational context.
USER_STATE_CACHE = defaultdict(lambda: {"last_query": None, "last_drug_ids": []})

@asynccontextmanager  # required so FastAPI can use this as a lifespan context manager
async def lifespan(app: FastAPI):
    _require_llm_config()
    global rag_pipeline
    rag_pipeline = RagPipeline()
    rag_pipeline.load_data()
    log.info("啟動完成,服務準備就緒。")
    yield
    log.info("服務關閉中。")
app = FastAPI(lifespan=lifespan)

# The original file registers no route for this handler; "/webhook" is an
# assumed path. Point the LINE webhook URL at it, or adjust as needed.
@app.post("/webhook")
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
    signature = request.headers.get("X-Line-Signature")
    if not signature:
        raise HTTPException(status_code=400, detail="Missing X-Line-Signature")
    if not AppConfig.CHANNEL_SECRET:
        log.error("CHANNEL_SECRET is not configured.")
        raise HTTPException(status_code=500, detail="Server configuration error")
    body = await request.body()
    try:
        digest = hmac.new(AppConfig.CHANNEL_SECRET.encode('utf-8'), body, hashlib.sha256)
        expected_signature = base64.b64encode(digest.digest()).decode('utf-8')
    except Exception as e:
        log.error(f"Failed to generate signature: {e}")
        raise HTTPException(status_code=500, detail="Signature generation error")
    if not hmac.compare_digest(expected_signature, signature):
        raise HTTPException(status_code=403, detail="Invalid signature")
    try:
        data = json.loads(body.decode('utf-8'))
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    for event in data.get("events", []):
        if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
            user_text = event.get("message", {}).get("text", "").strip()
            source = event.get("source", {})
            stype = source.get("type")
            target_id = source.get("userId") or source.get("groupId") or source.get("roomId")
            if user_text and target_id:
                background_tasks.add_task(rag_pipeline.process_user_query, stype, target_id, user_text)
    return Response(status_code=status.HTTP_200_OK)
def line_api_call(endpoint: str, data: Dict):
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
    }
    try:
        response = requests.post(f"https://api.line.me/v2/bot/message/{endpoint}", headers=headers, json=data, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Note: a Response with an error status is falsy, so test against None explicitly.
        log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {e.response.text if e.response is not None else 'N/A'}")
        raise

def line_push_generic(source_type: str, target_id: str, text: str):
    # LINE allows at most five messages per push and 5000 characters per text
    # message, so wrap conservatively at 4800.
    messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
    data = {"to": target_id, "messages": messages}
    line_api_call("push", data)
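# [ADDED] Hedged helper for local testing only: it reproduces the X-Line-Signature
# scheme verified in handle_webhook (HMAC-SHA256 over the raw request body,
# base64-encoded). The name and usage are illustrative, not part of the service.
def _sign_body_for_test(body: bytes) -> str:
    digest = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256).digest()
    return base64.b64encode(digest).decode("utf-8")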
# ---------- Entry point ----------
if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)