Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,310 +1,127 @@
|
|
| 1 |
-
# main.py
|
| 2 |
|
| 3 |
# ===============================================================
|
| 4 |
-
# 1. IMPORT THƯ VIỆN
|
| 5 |
# ===============================================================
|
| 6 |
import os
|
| 7 |
import torch
|
| 8 |
import gc
|
| 9 |
import re
|
| 10 |
-
import io
|
| 11 |
import logging
|
| 12 |
-
from
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
-
import
|
| 16 |
-
|
| 17 |
-
# Thư viện AI & Machine Learning
|
| 18 |
-
import transformers
|
| 19 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSequenceClassification
|
| 20 |
from datasets import load_dataset
|
| 21 |
from langchain_community.vectorstores import FAISS
|
| 22 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 23 |
-
|
| 24 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 25 |
-
import numpy as np
|
| 26 |
-
|
| 27 |
-
# Thư viện API
|
| 28 |
-
from fastapi import FastAPI, HTTPException
|
| 29 |
-
from pydantic import BaseModel
|
| 30 |
-
import uvicorn
|
| 31 |
|
| 32 |
# Set up logging
|
| 33 |
logging.basicConfig(level=logging.INFO)
|
| 34 |
logger = logging.getLogger(__name__)
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# ===============================================================
|
| 40 |
-
# 2. KHỞI TẠO CÁC MODEL (PHẦN NẶNG NHẤT)
|
| 41 |
-
# ===============================================================
|
| 42 |
-
logger.info("Bắt đầu quá trình tải model...")
|
| 43 |
-
|
| 44 |
-
# Kiểm tra xem có GPU không
|
| 45 |
-
is_gpu_available = torch.cuda.is_available()
|
| 46 |
-
device = "cuda" if is_gpu_available else "cpu"
|
| 47 |
-
logger.info(f"Thiết bị được sử dụng: {device}")
|
| 48 |
-
|
| 49 |
-
# Model kiểm duyệt nội dung (chạy trên CPU để tiết kiệm VRAM)
|
| 50 |
-
logger.info("Đang tải model kiểm duyệt (moderation)...")
|
| 51 |
-
moderation_tokenizer = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
|
| 52 |
-
moderation_model = AutoModelForSequenceClassification.from_pretrained(
|
| 53 |
-
"facebook/roberta-hate-speech-dynabench-r4-target"
|
| 54 |
-
).to("cpu")
|
| 55 |
-
|
| 56 |
-
# Model Llama 2 chính
|
| 57 |
-
logger.info("Đang tải model Llama-2-7b-chat-hf...")
|
| 58 |
-
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
| 59 |
-
hf_token = os.environ.get("HF_TOKEN") # Lấy token từ Secret của Hugging Face Space
|
| 60 |
-
|
| 61 |
-
if not hf_token:
|
| 62 |
-
logger.warning("HF_TOKEN không được tìm thấy. Có thể không tải được Llama-2.")
|
| 63 |
-
|
| 64 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
|
| 65 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 66 |
-
|
| 67 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 68 |
-
model_id,
|
| 69 |
-
token=hf_token,
|
| 70 |
-
device_map="auto",
|
| 71 |
-
torch_dtype=torch.float16
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
# Pipeline phân tích cảm xúc
|
| 75 |
-
logger.info("Đang tải model phân tích cảm xúc (sentiment)...")
|
| 76 |
-
sentiment_analyzer = pipeline(
|
| 77 |
-
"sentiment-analysis",
|
| 78 |
-
model="cardiffnlp/twitter-roberta-base-sentiment-latest",
|
| 79 |
-
tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest",
|
| 80 |
-
device=0 if is_gpu_available else -1
|
| 81 |
-
)
|
| 82 |
-
|
| 83 |
-
# Pipeline phân tích cảm xúc chi tiết (emotion)
|
| 84 |
-
logger.info("Đang tải model phân tích cảm xúc chi tiết (emotion)...")
|
| 85 |
-
emotion_analyzer = pipeline(
|
| 86 |
-
"text-classification",
|
| 87 |
-
model="bhadresh-savani/distilbert-base-uncased-emotion",
|
| 88 |
-
top_k=None,
|
| 89 |
-
device=0 if is_gpu_available else -1
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
logger.info("Tất cả model đã được tải thành công!")
|
| 93 |
-
|
| 94 |
|
| 95 |
# ===============================================================
|
| 96 |
-
#
|
| 97 |
# ===============================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
# --- Các hằng số và patterns ---
|
| 100 |
-
CRISIS_PATTERNS = [
|
| 101 |
-
r"\bi (want to|need to|am going to|will) (die|kill myself|end it all)\b",
|
| 102 |
-
r"\bi'm going to (kill myself|end my life)\b",
|
| 103 |
-
r"\bplanning to end my life\b",
|
| 104 |
-
]
|
| 105 |
-
CONCERN_PATTERNS = [
|
| 106 |
-
r"\bi've been feeling (really )?(depressed|suicidal)\b",
|
| 107 |
-
r"\bi feel (hopeless|trapped|worthless)\b",
|
| 108 |
-
r"\bno one (cares|would miss me)\b",
|
| 109 |
-
]
|
| 110 |
-
MENTAL_HEALTH_RESOURCES = {
|
| 111 |
-
'crisis': ["National Suicide Prevention Lifeline (US): 988", "Crisis Text Line: Text HOME to 741741"],
|
| 112 |
-
'concern': ["SAMHSA Helpline (US): 1-800-662-HELP (4357)", "7 Cups (free online therapy): https://www.7cups.com"],
|
| 113 |
-
'general': ["Psychology Today Therapist Finder: https://www.psychologytoday.com"]
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
# --- Các hàm an toàn ---
|
| 117 |
-
def moderate_text(text):
|
| 118 |
-
inputs = moderation_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(moderation_model.device)
|
| 119 |
-
with torch.no_grad():
|
| 120 |
-
outputs = moderation_model(**inputs)
|
| 121 |
-
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 122 |
-
harmful_score = probs[0, 1].item()
|
| 123 |
-
return {'is_harmful': harmful_score > 0.7, 'score': harmful_score}
|
| 124 |
-
|
| 125 |
-
def sanitize_input(text):
|
| 126 |
-
text = re.sub(r'<[^>]+>', '', text)
|
| 127 |
-
return text.strip()[:1000]
|
| 128 |
-
|
| 129 |
-
def enhanced_crisis_detection(text):
|
| 130 |
-
text_lower = text.lower()
|
| 131 |
-
if any(re.search(pattern, text_lower) for pattern in CRISIS_PATTERNS): return "crisis"
|
| 132 |
-
if any(re.search(pattern, text_lower) for pattern in CONCERN_PATTERNS): return "concern"
|
| 133 |
-
return False
|
| 134 |
-
|
| 135 |
-
# --- Các hàm phân tích ---
|
| 136 |
-
def combined_sentiment_analysis(text):
|
| 137 |
-
urgency = enhanced_crisis_detection(text)
|
| 138 |
-
if urgency: return urgency, 1.0, [("crisis_detected", 1.0)]
|
| 139 |
-
try:
|
| 140 |
-
sentiment_result = sentiment_analyzer(text)[0]
|
| 141 |
-
sentiment = sentiment_result['label'].lower()
|
| 142 |
-
sent_score = sentiment_result['score']
|
| 143 |
-
emotion_results = emotion_analyzer(text)[0]
|
| 144 |
-
emotions = sorted(emotion_results, key=lambda x: x['score'], reverse=True)[:3]
|
| 145 |
-
emotions = [(emo['label'], emo['score']) for emo in emotions]
|
| 146 |
-
return sentiment, sent_score, emotions
|
| 147 |
-
except Exception as e:
|
| 148 |
-
logger.error(f"Lỗi phân tích cảm xúc: {e}")
|
| 149 |
-
return "neutral", 0.5, [("unknown", 0.0)]
|
| 150 |
-
|
| 151 |
-
# --- Hệ thống Retrieval-Augmented Generation (RAG) ---
|
| 152 |
-
def load_and_process_datasets(limit_per_dataset=200): # Giảm số lượng để tải nhanh hơn
|
| 153 |
-
# ... (Giữ nguyên logic hàm load_and_process_datasets của bạn) ...
|
| 154 |
-
# Vì hàm này khá dài, bạn có thể copy-paste lại từ code gốc của mình vào đây
|
| 155 |
-
# Hoặc để đơn giản hóa cho việc test, bạn có thể thay thế bằng dữ liệu giả
|
| 156 |
-
logger.info("Đang tải và xử lý datasets cho RAG...")
|
| 157 |
-
datasets = []
|
| 158 |
-
try:
|
| 159 |
-
empathetic = load_dataset("Estwld/empathetic_dialogues_llm", split=f"train[:{limit_per_dataset}]")
|
| 160 |
-
processed_empathetic = [f"Emotion: {ex['emotion']}. Situation: {ex['situation']}. Response: {ex['conversations'][0]['content']}" for ex in empathetic if ex['conversations']]
|
| 161 |
-
datasets.extend(processed_empathetic)
|
| 162 |
-
except Exception as e:
|
| 163 |
-
logger.warning(f"Không thể tải dataset Empathetic: {e}")
|
| 164 |
-
|
| 165 |
-
try:
|
| 166 |
-
mental_health = load_dataset("Amod/mental_health_counseling_conversations", split=f"train[:{limit_per_dataset}]")
|
| 167 |
-
processed_mental_health = [f"Context: {ex['Context']}. Response: {ex['Response']}" for ex in mental_health]
|
| 168 |
-
datasets.extend(processed_mental_health)
|
| 169 |
-
except Exception as e:
|
| 170 |
-
logger.warning(f"Không thể tải dataset Mental Health: {e}")
|
| 171 |
-
|
| 172 |
-
logger.info(f"Tổng số tài liệu RAG đã tải: {len(datasets)}")
|
| 173 |
-
return datasets
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
documents = load_and_process_datasets()
|
| 177 |
-
if documents:
|
| 178 |
-
vector_store = FAISS.from_texts(documents, HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2'))
|
| 179 |
-
retriever = vector_store.as_retriever(search_kwargs={'k': 2})
|
| 180 |
-
else:
|
| 181 |
-
retriever = None
|
| 182 |
-
logger.warning("Không có tài liệu nào cho RAG, retriever sẽ bị vô hiệu hóa.")
|
| 183 |
-
|
| 184 |
-
# --- Pipeline chính của Llama 2 ---
|
| 185 |
-
pipe = pipeline(
|
| 186 |
-
"text-generation",
|
| 187 |
-
model=model,
|
| 188 |
-
tokenizer=tokenizer,
|
| 189 |
-
do_sample=True,
|
| 190 |
-
temperature=0.7,
|
| 191 |
-
top_p=0.9,
|
| 192 |
-
max_new_tokens=512
|
| 193 |
-
)
|
| 194 |
-
|
| 195 |
-
# --- Quản lý hội thoại ---
|
| 196 |
-
class ConversationManager:
|
| 197 |
-
def __init__(self):
|
| 198 |
-
self.history = []
|
| 199 |
-
self.is_first_message = True
|
| 200 |
-
self.rate_limits = defaultdict(list)
|
| 201 |
-
def add_message(self, role, content): self.history.append({"role": role, "content": content})
|
| 202 |
-
def get_conversation_text(self): return "\n".join([f"{m['role']}: {m['content']}" for m in self.history])
|
| 203 |
-
def check_rate_limit(self, user_id="default"):
|
| 204 |
-
now = datetime.now()
|
| 205 |
-
recent = [req for req in self.rate_limits[user_id] if req > now - timedelta(minutes=1)]
|
| 206 |
-
if len(recent) >= 15: return False, "Rate limit exceeded"
|
| 207 |
-
self.rate_limits[user_id].append(now)
|
| 208 |
-
return True, ""
|
| 209 |
-
def reset(self):
|
| 210 |
-
self.history = []; self.is_first_message = True
|
| 211 |
-
return "Conversation reset"
|
| 212 |
-
|
| 213 |
-
conversation_manager = ConversationManager()
|
| 214 |
-
|
| 215 |
-
# --- Hàm tạo prompt và sinh response ---
|
| 216 |
-
def format_prompt_with_context(user_input, conv_history, retrieved_contexts):
|
| 217 |
-
context_text = ""
|
| 218 |
-
if retrieved_contexts:
|
| 219 |
-
context_text = "Dưới đây là một vài ví dụ để tham khảo:\n" + "\n".join(
|
| 220 |
-
[f"Ví dụ {i+1}: {ctx.page_content}" for i, ctx in enumerate(retrieved_contexts)]
|
| 221 |
-
)
|
| 222 |
-
|
| 223 |
-
prompt = f"""<s>[INST] <<SYS>>
|
| 224 |
-
Bạn là Athena, một trợ lý AI trị liệu tâm lý giàu lòng cảm thông bằng tiếng Việt. Hãy luôn duy trì một thái độ ấm áp, thấu hiểu và không phán xét. Sử dụng các kỹ thuật lắng nghe tích cực và phản hồi một cách sâu sắc.
|
| 225 |
-
{context_text}
|
| 226 |
-
Lịch sử hội thoại trước:
|
| 227 |
-
{conv_history}
|
| 228 |
-
<</SYS>>
|
| 229 |
-
Người dùng: {user_input} [/INST] Athena: """
|
| 230 |
-
return prompt
|
| 231 |
-
|
| 232 |
-
def generate_safe_response(user_input):
|
| 233 |
-
sanitized_input = sanitize_input(user_input)
|
| 234 |
-
if moderate_text(sanitized_input)['is_harmful']:
|
| 235 |
-
return "Tôi xin lỗi, tôi không thể xử lý nội dung có hại. Chúng ta hãy nói về điều gì đó tích cực hơn nhé."
|
| 236 |
-
|
| 237 |
-
urgency_level = enhanced_crisis_detection(sanitized_input)
|
| 238 |
-
retrieved_contexts = retriever.get_relevant_documents(sanitized_input) if retriever else []
|
| 239 |
-
conv_history = conversation_manager.get_conversation_text()
|
| 240 |
-
|
| 241 |
-
formatted_prompt = format_prompt_with_context(sanitized_input, conv_history, retrieved_contexts)
|
| 242 |
-
output = pipe(formatted_prompt, num_return_sequences=1)
|
| 243 |
-
response = output[0]['generated_text'].split("[/INST] Athena: ")[-1].strip()
|
| 244 |
-
|
| 245 |
-
if moderate_text(response)['is_harmful']:
|
| 246 |
-
return "Tôi xin lỗi, tôi không thể đưa ra phản hồi phù hợp lúc này. Bạn có muốn thảo luận về chủ đề khác không?"
|
| 247 |
-
|
| 248 |
-
if urgency_level:
|
| 249 |
-
resources = "\n".join(MENTAL_HEALTH_RESOURCES.get(urgency_level, []))
|
| 250 |
-
response += f"\n\n🚨 Tôi nhận thấy bạn đang gặp khó khăn. Các nguồn lực sau đây có thể giúp ích:\n{resources}"
|
| 251 |
-
|
| 252 |
-
return response
|
| 253 |
|
| 254 |
# ===============================================================
|
| 255 |
-
#
|
|
|
|
| 256 |
# ===============================================================
|
| 257 |
-
|
| 258 |
-
app = FastAPI(title="Athena AI Therapist API", description="API cho trợ lý trị liệu tâm lý Athena")
|
| 259 |
|
| 260 |
class PredictRequest(BaseModel):
|
| 261 |
-
user_id: str = "default-user"
|
| 262 |
user_input: str
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
@app.get("/", tags=["Health Check"])
|
| 265 |
def health_check():
|
| 266 |
-
"""
|
| 267 |
-
return {"status": "
|
| 268 |
|
| 269 |
@app.post("/predict", tags=["Core Logic"])
|
| 270 |
async def predict(request: PredictRequest):
|
| 271 |
"""
|
| 272 |
-
|
| 273 |
"""
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
| 286 |
|
| 287 |
-
|
| 288 |
-
conversation_manager.add_message("assistant", response_text)
|
| 289 |
|
| 290 |
-
return {
|
| 291 |
-
"response": response_text,
|
| 292 |
-
"sentiment_analysis": {
|
| 293 |
-
"sentiment": sentiment,
|
| 294 |
-
"score": round(score, 4),
|
| 295 |
-
"emotions": emo_list
|
| 296 |
-
}
|
| 297 |
-
}
|
| 298 |
except Exception as e:
|
| 299 |
logger.error(f"Lỗi tại endpoint /predict: {str(e)}")
|
| 300 |
-
raise HTTPException(status_code=500, detail="Đã xảy ra lỗi máy chủ nội bộ.")
|
| 301 |
-
|
| 302 |
-
@app.post("/reset", tags=["Utility"])
|
| 303 |
-
async def reset():
|
| 304 |
-
"""Reset lại lịch sử hội thoại."""
|
| 305 |
-
message = conversation_manager.reset()
|
| 306 |
-
return {"status": "success", "message": message}
|
| 307 |
-
|
| 308 |
-
# Dòng này để chạy local test, không cần thiết cho Hugging Face Space
|
| 309 |
-
# if __name__ == "__main__":
|
| 310 |
-
# uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
| 1 |
+
# main.py (đã sửa đổi với Lazy Loading)
|
| 2 |
|
| 3 |
# ===============================================================
|
| 4 |
+
# 1. IMPORT THƯ VIỆN & CÁC HẰNG SỐ (KHÔNG TẢI MODEL Ở ĐÂY)
|
| 5 |
# ===============================================================
|
| 6 |
import os
|
| 7 |
import torch
|
| 8 |
import gc
|
| 9 |
import re
|
|
|
|
| 10 |
import logging
|
| 11 |
+
from fastapi import FastAPI, HTTPException
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
import uvicorn
|
| 14 |
+
# ... (Thêm lại các import khác của bạn ở đây: transformers, datasets, langchain, etc.)
|
|
|
|
|
|
|
|
|
|
| 15 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSequenceClassification
|
| 16 |
from datasets import load_dataset
|
| 17 |
from langchain_community.vectorstores import FAISS
|
| 18 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 19 |
+
# ... và các import còn lại
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Set up logging
|
| 22 |
logging.basicConfig(level=logging.INFO)
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
+
# Tạo một "kho chứa" toàn cục để lưu các model sau khi được tải
|
| 26 |
+
# Ban đầu nó sẽ trống
|
| 27 |
+
model_cache = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# ===============================================================
|
| 30 |
+
# 2. TẠO MỘT HÀM ĐỂ TẢI TẤT CẢ MODEL VÀ RAG
|
| 31 |
# ===============================================================
|
| 32 |
+
def load_all_models():
|
| 33 |
+
"""
|
| 34 |
+
Hàm này sẽ tải tất cả các model và thiết lập RAG.
|
| 35 |
+
Nó chỉ thực sự chạy một lần duy nhất khi có request đầu tiên.
|
| 36 |
+
"""
|
| 37 |
+
# Kiểm tra xem model đã được tải chưa để tránh tải lại
|
| 38 |
+
if "is_loaded" in model_cache:
|
| 39 |
+
logger.info("Models đã được tải, bỏ qua.")
|
| 40 |
+
return
|
| 41 |
+
|
| 42 |
+
logger.info("Lần đầu khởi chạy, bắt đầu quá trình tải model (có thể mất vài phút)...")
|
| 43 |
+
is_gpu_available = torch.cuda.is_available()
|
| 44 |
+
device = "cuda" if is_gpu_available else "cpu"
|
| 45 |
+
logger.info(f"Thiết bị được sử dụng: {device}")
|
| 46 |
+
|
| 47 |
+
# Tải tất cả các model và lưu vào cache
|
| 48 |
+
logger.info("Đang tải model kiểm duyệt (moderation)...")
|
| 49 |
+
model_cache["moderation_tokenizer"] = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
|
| 50 |
+
model_cache["moderation_model"] = AutoModelForSequenceClassification.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target").to("cpu")
|
| 51 |
+
|
| 52 |
+
logger.info("Đang tải model Llama-2-7b-chat-hf...")
|
| 53 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 54 |
+
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
| 55 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
|
| 56 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto", torch_dtype=torch.float16)
|
| 57 |
+
model_cache["llama_pipe"] = pipeline("text-generation", model=model, tokenizer=tokenizer, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=512)
|
| 58 |
+
|
| 59 |
+
logger.info("Đang tải model phân tích cảm xúc (sentiment)...")
|
| 60 |
+
model_cache["sentiment_analyzer"] = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", device=0 if is_gpu_available else -1)
|
| 61 |
+
|
| 62 |
+
logger.info("Đang tải model phân tích cảm xúc chi tiết (emotion)...")
|
| 63 |
+
model_cache["emotion_analyzer"] = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", top_k=None, device=0 if is_gpu_available else -1)
|
| 64 |
+
|
| 65 |
+
# Tải và xử lý RAG
|
| 66 |
+
# LƯU Ý: Bạn cần copy lại hàm load_and_process_datasets và các hàm helper khác
|
| 67 |
+
# (sanitize_input, combined_sentiment_analysis, etc.) vào file này.
|
| 68 |
+
# Để ví dụ ngắn gọn, mình sẽ giả định chúng đã tồn tại.
|
| 69 |
+
# documents = load_and_process_datasets() # Hàm này của bạn
|
| 70 |
+
# if documents:
|
| 71 |
+
# vector_store = FAISS.from_texts(documents, HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2'))
|
| 72 |
+
# model_cache["retriever"] = vector_store.as_retriever(search_kwargs={'k': 2})
|
| 73 |
+
# else:
|
| 74 |
+
# model_cache["retriever"] = None
|
| 75 |
+
|
| 76 |
+
logger.info("Tất cả model đã được tải và thiết lập thành công!")
|
| 77 |
+
model_cache["is_loaded"] = True
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# ===============================================================
|
| 81 |
+
# 3. ĐỊNH NGHĨA APP VÀ ENDPOINT
|
| 82 |
+
# Server sẽ khởi động ngay lập tức vì không có gì nặng ở đây.
|
| 83 |
# ===============================================================
|
| 84 |
+
app = FastAPI(title="Athena AI Therapist API")
|
|
|
|
| 85 |
|
| 86 |
class PredictRequest(BaseModel):
|
|
|
|
| 87 |
user_input: str
|
| 88 |
|
| 89 |
+
@app.on_event("startup")
|
| 90 |
+
def startup_event():
|
| 91 |
+
"""Sự kiện này chỉ chạy 1 lần khi server bắt đầu."""
|
| 92 |
+
logger.info("Server FastAPI đã khởi động. Sẵn sàng nhận yêu cầu.")
|
| 93 |
+
logger.info("Các model sẽ được tải 'lười biếng' khi có yêu cầu /predict đầu tiên.")
|
| 94 |
+
|
| 95 |
@app.get("/", tags=["Health Check"])
|
| 96 |
def health_check():
|
| 97 |
+
"""Endpoint siêu nhẹ để Hugging Face kiểm tra sức khỏe."""
|
| 98 |
+
return {"status": "healthy", "models_loaded": model_cache.get("is_loaded", False)}
|
| 99 |
|
| 100 |
@app.post("/predict", tags=["Core Logic"])
|
| 101 |
async def predict(request: PredictRequest):
|
| 102 |
"""
|
| 103 |
+
Endpoint chính. Nó sẽ kích hoạt việc tải model nếu đây là lần chạy đầu tiên.
|
| 104 |
"""
|
| 105 |
+
# Bước quan trọng: Gọi hàm tải model.
|
| 106 |
+
# Nếu model đã được tải, nó sẽ bỏ qua ngay lập tức.
|
| 107 |
+
# Nếu chưa, nó sẽ chặn và tải ở đây.
|
| 108 |
+
load_all_models()
|
| 109 |
|
| 110 |
+
try:
|
| 111 |
+
# Bây giờ, sử dụng các model từ cache
|
| 112 |
+
# response_text = generate_safe_response(request.user_input, model_cache)
|
| 113 |
+
# Lưu ý: bạn sẽ cần sửa lại hàm generate_safe_response và các hàm khác
|
| 114 |
+
# để chúng nhận `model_cache` làm tham số thay vì dùng biến toàn cục.
|
| 115 |
|
| 116 |
+
# ---- VÍ DỤ TẠM THỜI ĐỂ TEST ----
|
| 117 |
+
prompt = f"User: {request.user_input}\nAthena:"
|
| 118 |
+
llama_pipe = model_cache["llama_pipe"]
|
| 119 |
+
result = llama_pipe(prompt)
|
| 120 |
+
response_text = result[0]['generated_text']
|
| 121 |
+
# --------------------------------
|
| 122 |
|
| 123 |
+
return {"response": response_text}
|
|
|
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
except Exception as e:
|
| 126 |
logger.error(f"Lỗi tại endpoint /predict: {str(e)}")
|
| 127 |
+
raise HTTPException(status_code=500, detail="Đã xảy ra lỗi máy chủ nội bộ.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|