DucThuanTran commited on
Commit
c4baf13
·
verified ·
1 Parent(s): 65841d7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +86 -269
main.py CHANGED
@@ -1,310 +1,127 @@
1
- # main.py
2
 
3
  # ===============================================================
4
- # 1. IMPORT THƯ VIỆN
5
  # ===============================================================
6
  import os
7
  import torch
8
  import gc
9
  import re
10
- import io
11
  import logging
12
- from datetime import datetime, timedelta
13
- from collections import defaultdict
14
- from itertools import chain
15
- import random
16
-
17
- # Thư viện AI & Machine Learning
18
- import transformers
19
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSequenceClassification
20
  from datasets import load_dataset
21
  from langchain_community.vectorstores import FAISS
22
  from langchain_community.embeddings import HuggingFaceEmbeddings
23
- from sentence_transformers import SentenceTransformer
24
- from sklearn.feature_extraction.text import TfidfVectorizer
25
- import numpy as np
26
-
27
- # Thư viện API
28
- from fastapi import FastAPI, HTTPException
29
- from pydantic import BaseModel
30
- import uvicorn
31
 
32
  # Set up logging
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
35
 
36
- # Ngăn chặn warning không cần thiết từ tokenizer
37
- transformers.logging.set_verbosity_error()
38
-
39
- # ===============================================================
40
- # 2. KHỞI TẠO CÁC MODEL (PHẦN NẶNG NHẤT)
41
- # ===============================================================
42
- logger.info("Bắt đầu quá trình tải model...")
43
-
44
- # Kiểm tra xem có GPU không
45
- is_gpu_available = torch.cuda.is_available()
46
- device = "cuda" if is_gpu_available else "cpu"
47
- logger.info(f"Thiết bị được sử dụng: {device}")
48
-
49
- # Model kiểm duyệt nội dung (chạy trên CPU để tiết kiệm VRAM)
50
- logger.info("Đang tải model kiểm duyệt (moderation)...")
51
- moderation_tokenizer = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
52
- moderation_model = AutoModelForSequenceClassification.from_pretrained(
53
- "facebook/roberta-hate-speech-dynabench-r4-target"
54
- ).to("cpu")
55
-
56
- # Model Llama 2 chính
57
- logger.info("Đang tải model Llama-2-7b-chat-hf...")
58
- model_id = "meta-llama/Llama-2-7b-chat-hf"
59
- hf_token = os.environ.get("HF_TOKEN") # Lấy token từ Secret của Hugging Face Space
60
-
61
- if not hf_token:
62
- logger.warning("HF_TOKEN không được tìm thấy. Có thể không tải được Llama-2.")
63
-
64
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
65
- tokenizer.pad_token = tokenizer.eos_token
66
-
67
- model = AutoModelForCausalLM.from_pretrained(
68
- model_id,
69
- token=hf_token,
70
- device_map="auto",
71
- torch_dtype=torch.float16
72
- )
73
-
74
- # Pipeline phân tích cảm xúc
75
- logger.info("Đang tải model phân tích cảm xúc (sentiment)...")
76
- sentiment_analyzer = pipeline(
77
- "sentiment-analysis",
78
- model="cardiffnlp/twitter-roberta-base-sentiment-latest",
79
- tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest",
80
- device=0 if is_gpu_available else -1
81
- )
82
-
83
- # Pipeline phân tích cảm xúc chi tiết (emotion)
84
- logger.info("Đang tải model phân tích cảm xúc chi tiết (emotion)...")
85
- emotion_analyzer = pipeline(
86
- "text-classification",
87
- model="bhadresh-savani/distilbert-base-uncased-emotion",
88
- top_k=None,
89
- device=0 if is_gpu_available else -1
90
- )
91
-
92
- logger.info("Tất cả model đã được tải thành công!")
93
-
94
 
95
  # ===============================================================
96
- # 3. CÁC HÀM XỬ LOGIC CỐT LÕI (giữ nguyên từ code của bạn)
97
  # ===============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # --- Các hằng số và patterns ---
100
- CRISIS_PATTERNS = [
101
- r"\bi (want to|need to|am going to|will) (die|kill myself|end it all)\b",
102
- r"\bi'm going to (kill myself|end my life)\b",
103
- r"\bplanning to end my life\b",
104
- ]
105
- CONCERN_PATTERNS = [
106
- r"\bi've been feeling (really )?(depressed|suicidal)\b",
107
- r"\bi feel (hopeless|trapped|worthless)\b",
108
- r"\bno one (cares|would miss me)\b",
109
- ]
110
- MENTAL_HEALTH_RESOURCES = {
111
- 'crisis': ["National Suicide Prevention Lifeline (US): 988", "Crisis Text Line: Text HOME to 741741"],
112
- 'concern': ["SAMHSA Helpline (US): 1-800-662-HELP (4357)", "7 Cups (free online therapy): https://www.7cups.com"],
113
- 'general': ["Psychology Today Therapist Finder: https://www.psychologytoday.com"]
114
- }
115
-
116
- # --- Các hàm an toàn ---
117
- def moderate_text(text):
118
- inputs = moderation_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(moderation_model.device)
119
- with torch.no_grad():
120
- outputs = moderation_model(**inputs)
121
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
122
- harmful_score = probs[0, 1].item()
123
- return {'is_harmful': harmful_score > 0.7, 'score': harmful_score}
124
-
125
- def sanitize_input(text):
126
- text = re.sub(r'<[^>]+>', '', text)
127
- return text.strip()[:1000]
128
-
129
- def enhanced_crisis_detection(text):
130
- text_lower = text.lower()
131
- if any(re.search(pattern, text_lower) for pattern in CRISIS_PATTERNS): return "crisis"
132
- if any(re.search(pattern, text_lower) for pattern in CONCERN_PATTERNS): return "concern"
133
- return False
134
-
135
- # --- Các hàm phân tích ---
136
- def combined_sentiment_analysis(text):
137
- urgency = enhanced_crisis_detection(text)
138
- if urgency: return urgency, 1.0, [("crisis_detected", 1.0)]
139
- try:
140
- sentiment_result = sentiment_analyzer(text)[0]
141
- sentiment = sentiment_result['label'].lower()
142
- sent_score = sentiment_result['score']
143
- emotion_results = emotion_analyzer(text)[0]
144
- emotions = sorted(emotion_results, key=lambda x: x['score'], reverse=True)[:3]
145
- emotions = [(emo['label'], emo['score']) for emo in emotions]
146
- return sentiment, sent_score, emotions
147
- except Exception as e:
148
- logger.error(f"Lỗi phân tích cảm xúc: {e}")
149
- return "neutral", 0.5, [("unknown", 0.0)]
150
-
151
- # --- Hệ thống Retrieval-Augmented Generation (RAG) ---
152
- def load_and_process_datasets(limit_per_dataset=200): # Giảm số lượng để tải nhanh hơn
153
- # ... (Giữ nguyên logic hàm load_and_process_datasets của bạn) ...
154
- # Vì hàm này khá dài, bạn có thể copy-paste lại từ code gốc của mình vào đây
155
- # Hoặc để đơn giản hóa cho việc test, bạn có thể thay thế bằng dữ liệu giả
156
- logger.info("Đang tải và xử lý datasets cho RAG...")
157
- datasets = []
158
- try:
159
- empathetic = load_dataset("Estwld/empathetic_dialogues_llm", split=f"train[:{limit_per_dataset}]")
160
- processed_empathetic = [f"Emotion: {ex['emotion']}. Situation: {ex['situation']}. Response: {ex['conversations'][0]['content']}" for ex in empathetic if ex['conversations']]
161
- datasets.extend(processed_empathetic)
162
- except Exception as e:
163
- logger.warning(f"Không thể tải dataset Empathetic: {e}")
164
-
165
- try:
166
- mental_health = load_dataset("Amod/mental_health_counseling_conversations", split=f"train[:{limit_per_dataset}]")
167
- processed_mental_health = [f"Context: {ex['Context']}. Response: {ex['Response']}" for ex in mental_health]
168
- datasets.extend(processed_mental_health)
169
- except Exception as e:
170
- logger.warning(f"Không thể tải dataset Mental Health: {e}")
171
-
172
- logger.info(f"Tổng số tài liệu RAG đã tải: {len(datasets)}")
173
- return datasets
174
-
175
-
176
- documents = load_and_process_datasets()
177
- if documents:
178
- vector_store = FAISS.from_texts(documents, HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2'))
179
- retriever = vector_store.as_retriever(search_kwargs={'k': 2})
180
- else:
181
- retriever = None
182
- logger.warning("Không có tài liệu nào cho RAG, retriever sẽ bị vô hiệu hóa.")
183
-
184
- # --- Pipeline chính của Llama 2 ---
185
- pipe = pipeline(
186
- "text-generation",
187
- model=model,
188
- tokenizer=tokenizer,
189
- do_sample=True,
190
- temperature=0.7,
191
- top_p=0.9,
192
- max_new_tokens=512
193
- )
194
-
195
- # --- Quản lý hội thoại ---
196
- class ConversationManager:
197
- def __init__(self):
198
- self.history = []
199
- self.is_first_message = True
200
- self.rate_limits = defaultdict(list)
201
- def add_message(self, role, content): self.history.append({"role": role, "content": content})
202
- def get_conversation_text(self): return "\n".join([f"{m['role']}: {m['content']}" for m in self.history])
203
- def check_rate_limit(self, user_id="default"):
204
- now = datetime.now()
205
- recent = [req for req in self.rate_limits[user_id] if req > now - timedelta(minutes=1)]
206
- if len(recent) >= 15: return False, "Rate limit exceeded"
207
- self.rate_limits[user_id].append(now)
208
- return True, ""
209
- def reset(self):
210
- self.history = []; self.is_first_message = True
211
- return "Conversation reset"
212
-
213
- conversation_manager = ConversationManager()
214
-
215
- # --- Hàm tạo prompt và sinh response ---
216
- def format_prompt_with_context(user_input, conv_history, retrieved_contexts):
217
- context_text = ""
218
- if retrieved_contexts:
219
- context_text = "Dưới đây là một vài ví dụ để tham khảo:\n" + "\n".join(
220
- [f"Ví dụ {i+1}: {ctx.page_content}" for i, ctx in enumerate(retrieved_contexts)]
221
- )
222
-
223
- prompt = f"""<s>[INST] <<SYS>>
224
- Bạn là Athena, một trợ lý AI trị liệu tâm lý giàu lòng cảm thông bằng tiếng Việt. Hãy luôn duy trì một thái độ ấm áp, thấu hiểu và không phán xét. Sử dụng các kỹ thuật lắng nghe tích cực và phản hồi một cách sâu sắc.
225
- {context_text}
226
- Lịch sử hội thoại trước:
227
- {conv_history}
228
- <</SYS>>
229
- Người dùng: {user_input} [/INST] Athena: """
230
- return prompt
231
-
232
- def generate_safe_response(user_input):
233
- sanitized_input = sanitize_input(user_input)
234
- if moderate_text(sanitized_input)['is_harmful']:
235
- return "Tôi xin lỗi, tôi không thể xử lý nội dung có hại. Chúng ta hãy nói về điều gì đó tích cực hơn nhé."
236
-
237
- urgency_level = enhanced_crisis_detection(sanitized_input)
238
- retrieved_contexts = retriever.get_relevant_documents(sanitized_input) if retriever else []
239
- conv_history = conversation_manager.get_conversation_text()
240
-
241
- formatted_prompt = format_prompt_with_context(sanitized_input, conv_history, retrieved_contexts)
242
- output = pipe(formatted_prompt, num_return_sequences=1)
243
- response = output[0]['generated_text'].split("[/INST] Athena: ")[-1].strip()
244
-
245
- if moderate_text(response)['is_harmful']:
246
- return "Tôi xin lỗi, tôi không thể đưa ra phản hồi phù hợp lúc này. Bạn có muốn thảo luận về chủ đề khác không?"
247
-
248
- if urgency_level:
249
- resources = "\n".join(MENTAL_HEALTH_RESOURCES.get(urgency_level, []))
250
- response += f"\n\n🚨 Tôi nhận thấy bạn đang gặp khó khăn. Các nguồn lực sau đây có thể giúp ích:\n{resources}"
251
-
252
- return response
253
 
254
  # ===============================================================
255
- # 4. ĐỊNH NGHĨA API VỚI FASTAPI
 
256
  # ===============================================================
257
-
258
- app = FastAPI(title="Athena AI Therapist API", description="API cho trợ lý trị liệu tâm lý Athena")
259
 
260
  class PredictRequest(BaseModel):
261
- user_id: str = "default-user"
262
  user_input: str
263
 
 
 
 
 
 
 
264
  @app.get("/", tags=["Health Check"])
265
  def health_check():
266
- """Kiểm tra xem API đang hoạt động không."""
267
- return {"status": "Athena is awake and listening..."}
268
 
269
  @app.post("/predict", tags=["Core Logic"])
270
  async def predict(request: PredictRequest):
271
  """
272
- Nhận input từ người dùng trả về phản hồi của Athena.
273
  """
274
- try:
275
- rate_ok, rate_msg = conversation_manager.check_rate_limit(request.user_id)
276
- if not rate_ok:
277
- raise HTTPException(status_code=429, detail=rate_msg)
278
 
279
- if not request.user_input or not request.user_input.strip():
280
- raise HTTPException(status_code=400, detail="User input không được để trống.")
 
 
 
281
 
282
- response_text = generate_safe_response(request.user_input)
283
-
284
- sentiment, score, emotions = combined_sentiment_analysis(request.user_input)
285
- emo_list = [{"label": e[0], "score": round(e[1], 4)} for e in emotions]
 
 
286
 
287
- conversation_manager.add_message("user", request.user_input)
288
- conversation_manager.add_message("assistant", response_text)
289
 
290
- return {
291
- "response": response_text,
292
- "sentiment_analysis": {
293
- "sentiment": sentiment,
294
- "score": round(score, 4),
295
- "emotions": emo_list
296
- }
297
- }
298
  except Exception as e:
299
  logger.error(f"Lỗi tại endpoint /predict: {str(e)}")
300
- raise HTTPException(status_code=500, detail="Đã xảy ra lỗi máy chủ nội bộ.")
301
-
302
- @app.post("/reset", tags=["Utility"])
303
- async def reset():
304
- """Reset lại lịch sử hội thoại."""
305
- message = conversation_manager.reset()
306
- return {"status": "success", "message": message}
307
-
308
- # Dòng này để chạy local test, không cần thiết cho Hugging Face Space
309
- # if __name__ == "__main__":
310
- # uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ # main.py (đã sửa đổi với Lazy Loading)
2
 
3
  # ===============================================================
4
+ # 1. IMPORT THƯ VIỆN & CÁC HẰNG SỐ (KHÔNG TẢI MODEL Ở ĐÂY)
5
  # ===============================================================
6
  import os
7
  import torch
8
  import gc
9
  import re
 
10
  import logging
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel
13
+ import uvicorn
14
+ # ... (Thêm lại các import khác của bạn ở đây: transformers, datasets, langchain, etc.)
 
 
 
15
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSequenceClassification
16
  from datasets import load_dataset
17
  from langchain_community.vectorstores import FAISS
18
  from langchain_community.embeddings import HuggingFaceEmbeddings
19
+ # ... và các import còn lại
 
 
 
 
 
 
 
20
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
+ # Tạo một "kho chứa" toàn cục để lưu các model sau khi được tải
26
+ # Ban đầu nó sẽ trống
27
+ model_cache = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # ===============================================================
30
+ # 2. TẠO MỘT HÀM ĐỂ TẢI TẤT CẢ MODEL RAG
31
  # ===============================================================
32
+ def load_all_models():
33
+ """
34
+ Hàm này sẽ tải tất cả các model và thiết lập RAG.
35
+ Nó chỉ thực sự chạy một lần duy nhất khi có request đầu tiên.
36
+ """
37
+ # Kiểm tra xem model đã được tải chưa để tránh tải lại
38
+ if "is_loaded" in model_cache:
39
+ logger.info("Models đã được tải, bỏ qua.")
40
+ return
41
+
42
+ logger.info("Lần đầu khởi chạy, bắt đầu quá trình tải model (có thể mất vài phút)...")
43
+ is_gpu_available = torch.cuda.is_available()
44
+ device = "cuda" if is_gpu_available else "cpu"
45
+ logger.info(f"Thiết bị được sử dụng: {device}")
46
+
47
+ # Tải tất cả các model và lưu vào cache
48
+ logger.info("Đang tải model kiểm duyệt (moderation)...")
49
+ model_cache["moderation_tokenizer"] = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
50
+ model_cache["moderation_model"] = AutoModelForSequenceClassification.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target").to("cpu")
51
+
52
+ logger.info("Đang tải model Llama-2-7b-chat-hf...")
53
+ hf_token = os.environ.get("HF_TOKEN")
54
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
55
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
56
+ model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto", torch_dtype=torch.float16)
57
+ model_cache["llama_pipe"] = pipeline("text-generation", model=model, tokenizer=tokenizer, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=512)
58
+
59
+ logger.info("Đang tải model phân tích cảm xúc (sentiment)...")
60
+ model_cache["sentiment_analyzer"] = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", device=0 if is_gpu_available else -1)
61
+
62
+ logger.info("Đang tải model phân tích cảm xúc chi tiết (emotion)...")
63
+ model_cache["emotion_analyzer"] = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", top_k=None, device=0 if is_gpu_available else -1)
64
+
65
+ # Tải và xử lý RAG
66
+ # LƯU Ý: Bạn cần copy lại hàm load_and_process_datasets và các hàm helper khác
67
+ # (sanitize_input, combined_sentiment_analysis, etc.) vào file này.
68
+ # Để ví dụ ngắn gọn, mình sẽ giả định chúng đã tồn tại.
69
+ # documents = load_and_process_datasets() # Hàm này của bạn
70
+ # if documents:
71
+ # vector_store = FAISS.from_texts(documents, HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2'))
72
+ # model_cache["retriever"] = vector_store.as_retriever(search_kwargs={'k': 2})
73
+ # else:
74
+ # model_cache["retriever"] = None
75
+
76
+ logger.info("Tất cả model đã được tải và thiết lập thành công!")
77
+ model_cache["is_loaded"] = True
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # ===============================================================
81
+ # 3. ĐỊNH NGHĨA APP ENDPOINT
82
+ # Server sẽ khởi động ngay lập tức vì không có gì nặng ở đây.
83
  # ===============================================================
84
+ app = FastAPI(title="Athena AI Therapist API")
 
85
 
86
  class PredictRequest(BaseModel):
 
87
  user_input: str
88
 
89
+ @app.on_event("startup")
90
+ def startup_event():
91
+ """Sự kiện này chỉ chạy 1 lần khi server bắt đầu."""
92
+ logger.info("Server FastAPI đã khởi động. Sẵn sàng nhận yêu cầu.")
93
+ logger.info("Các model sẽ được tải 'lười biếng' khi có yêu cầu /predict đầu tiên.")
94
+
95
  @app.get("/", tags=["Health Check"])
96
  def health_check():
97
+ """Endpoint siêu nhẹ để Hugging Face kiểm tra sức khỏe."""
98
+ return {"status": "healthy", "models_loaded": model_cache.get("is_loaded", False)}
99
 
100
  @app.post("/predict", tags=["Core Logic"])
101
  async def predict(request: PredictRequest):
102
  """
103
+ Endpoint chính. sẽ kích hoạt việc tải model nếu đây là lần chạy đầu tiên.
104
  """
105
+ # Bước quan trọng: Gọi hàm tải model.
106
+ # Nếu model đã được tải, sẽ bỏ qua ngay lập tức.
107
+ # Nếu chưa, nó sẽ chặn và tải ở đây.
108
+ load_all_models()
109
 
110
+ try:
111
+ # Bây giờ, sử dụng các model từ cache
112
+ # response_text = generate_safe_response(request.user_input, model_cache)
113
+ # Lưu ý: bạn sẽ cần sửa lại hàm generate_safe_response và các hàm khác
114
+ # để chúng nhận `model_cache` làm tham số thay vì dùng biến toàn cục.
115
 
116
+ # ---- VÍ DỤ TẠM THỜI ĐỂ TEST ----
117
+ prompt = f"User: {request.user_input}\nAthena:"
118
+ llama_pipe = model_cache["llama_pipe"]
119
+ result = llama_pipe(prompt)
120
+ response_text = result[0]['generated_text']
121
+ # --------------------------------
122
 
123
+ return {"response": response_text}
 
124
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
  logger.error(f"Lỗi tại endpoint /predict: {str(e)}")
127
+ raise HTTPException(status_code=500, detail="Đã xảy ra lỗi máy chủ nội bộ.")