Spaces:
Running
Running
import hashlib | |
import httpx | |
import asyncio | |
from typing import Dict, Optional, Any, List, Union, Tuple | |
from tenacity import retry, stop_after_attempt, wait_fixed | |
from src.search import Search | |
from src.firestore_db import FirestoreDB | |
from src.gemini import Gemini | |
from src.cache import Cache | |
from src.firestore_lottery import FirestoreLottery | |
from src.lottery_core import LotteryCore | |
from src.lottery_analysis import LotteryAnalysis | |
from src.webhook import WebhookUpdate | |
from src.commands import ( | |
StartCommand, HelpCommand, GeminiCommand, ExitCommand, | |
TrainCommand, TrainJsonCommand, LotteryCommand, LoadLotteryCommand, | |
SaveLotteryVectorsCommand, AddLotteryCommand, MigrateLotteryDataCommand, | |
CheckDatesCommand, ScheduleCommand, LotteryPositionCommand, DefaultCommand, | |
RegisterCommand, AuthCommand, LogoutCommand | |
) | |
from src.logger import logger | |
from config.settings import Settings | |
from datetime import datetime | |
import json | |
class Bot:
    """Telegram/Gradio chat bot that routes text commands to handler objects.

    Holds the shared service objects (Firestore, cache, search, Gemini, lottery
    helpers) that command handlers receive via ``handler.handle(self, ...)``,
    gates every non-public command behind a Firestore-backed session check,
    caches command replies, and either posts replies to Telegram or records
    them for Gradio in ``_last_response``.
    """

    # Advanced commands that require authentication. Kept for callers elsewhere;
    # handle_message already requires authentication for every non-public command.
    ADVANCED_COMMANDS = [
        "/load_lottery", "/lottery_position", "/train", "/train_json",
        "/save_lottery_vectors", "/add_lottery", "/migrate_lottery_data",
        "/check_dates", "/schedule", "/lottery"
    ]
    # Commands (lower-case) that may run without an authenticated session.
    PUBLIC_COMMANDS = frozenset({"/start", "/help", "/auth", "/register", "/logout"})
    # Commands (lower-case) that change authentication state: their replies must
    # always come from a fresh handler call, never from the response cache.
    # NOTE(review): "/gemini" and "/exit" look like mode toggles and may need to be
    # listed here too -- confirm against GeminiCommand/ExitCommand.
    UNCACHED_COMMANDS = frozenset({"/register", "/auth", "/logout"})
    # Commands (lower-case) whose arguments may carry secrets; masked in logs.
    # /auth takes the bot password; /register presumably carries credentials too.
    SENSITIVE_COMMANDS = frozenset({"/auth", "/register"})
    # A session expires this many seconds after the last successful /auth.
    SESSION_TIMEOUT_SECONDS = 3600

    def __init__(self):
        """Create the service objects and the command routing table.

        NOTE(review): ``message`` and ``_last_response`` are per-instance state
        shared by every request this bot serves; concurrent updates can
        overwrite each other's values.
        """
        logger.debug("[Bot] Initializing Bot")
        self.db = FirestoreDB(credentials=Settings.FIRESTORE_CREDENTIALS)
        self.cache = Cache()
        self.search = Search(db=self.db)
        self.gemini = Gemini(db=self.db, cache=self.cache)
        self.firestore_lottery = FirestoreLottery(credentials=Settings.FIRESTORE_CREDENTIALS)
        self.lottery_core = LotteryCore(
            db=self.db,
            cache=self.cache
        )
        self.lottery_analysis = LotteryAnalysis(
            db=self.db,
            cache=self.cache
        )
        self.message = None  # last update accepted by handle_update
        # Presumably per-user Gemini chat-mode flags maintained by command handlers -- verify.
        self.gemini_mode: Dict[int, bool] = {}
        # In-memory mirror of successful authentications, keyed by the raw
        # user_id (Telegram and Gradio ids share this key space).
        self.authenticated_users: Dict[int, bool] = {}
        # Command name -> handler; matched case-insensitively by longest prefix.
        self.commands = {
            "/start": StartCommand(),
            "/help": HelpCommand(),
            "/auth": AuthCommand(),
            "/Gemini": GeminiCommand(),
            "/exit": ExitCommand(),
            "/train": TrainCommand(),
            "/train_json": TrainJsonCommand(),
            "/lottery": LotteryCommand(),
            "/load_lottery": LoadLotteryCommand(),
            "/save_lottery_vectors": SaveLotteryVectorsCommand(),
            "/add_lottery": AddLotteryCommand(),
            "/migrate_lottery_data": MigrateLotteryDataCommand(),
            "/check_dates": CheckDatesCommand(),
            "/schedule": ScheduleCommand(),
            "/lottery_position": LotteryPositionCommand(),
            "/register": RegisterCommand(),
            "/logout": LogoutCommand()
        }
        self.default_command = DefaultCommand()
        self._last_response: List[str] = []

    async def initialize(self):
        """Initialize the bot's components.

        Initializes the "default" and "lottery" search data and the lottery
        core, then clears the ``authenticated``/``bot_authenticated`` flags on
        every stored user so everyone must re-authenticate after a restart.
        Any error is logged and re-raised.
        """
        logger.debug("[Bot] Starting initialization")
        try:
            await self.search.initialize(data_type="default")
            await self.search.initialize(data_type="lottery")
            await self.lottery_core.initialize()
            # Reset the authentication state of all users.
            users = await self.db.get_all(data_type="users")
            for user in users or []:
                doc_id = user.get("id")
                if not doc_id:
                    # Without an id the write below would target doc_id=None.
                    logger.warning("[initialize] Skipping user record without id")
                    continue
                await self.db.set(
                    {"authenticated": False, "bot_authenticated": False},
                    data_type="users",
                    doc_id=doc_id,
                    merge=True
                )
                logger.info(f"[initialize] Reset authentication for doc_id={doc_id}")
            self.authenticated_users.clear()
            logger.info("[Bot] Initialization completed")
        except Exception as e:
            logger.error(f"[Bot] Initialization error: {str(e)}", exc_info=True)
            raise

    async def handle_update(self, update: WebhookUpdate) -> bool:
        """Process one Telegram webhook update.

        Returns True when the update carried a usable text message and was
        dispatched to ``handle_message``; False when it had no message, was
        missing chat/user id or text, timed out (120 s), or raised.

        NOTE(review): this 120 s limit is shorter than handle_message's own
        600 s limit, so on Telegram the inner timeout reply is never sent.
        """
        logger.info(f"[handle_update] Received update: {update.update_id}")
        try:
            async with asyncio.timeout(120):
                if not update.message:
                    logger.warning("[handle_update] No message in update")
                    return False
                chat_id = update.message.get("chat", {}).get("id")
                user_id = update.message.get("from", {}).get("id")
                text = update.message.get("text", "")
                if not text or not chat_id or not user_id:
                    logger.warning(
                        f"[handle_update] Invalid message: chat_id={chat_id}, user_id={user_id}, "
                        f"text={self._redact(text)[:50]}..."
                    )
                    return False
                self.message = update
                await self.handle_message(chat_id, user_id, text, is_gradio=False)
                return True
        except asyncio.TimeoutError:
            logger.error("[handle_update] Timeout")
            return False
        except Exception as e:
            logger.error(f"[handle_update] Error: {str(e)}", exc_info=True)
            return False

    async def handle_message(self, chat_id: int, user_id: int, text: str, is_gradio: bool = False) -> List[str]:
        """Route one text message to its command handler and deliver the reply.

        Commands are matched case-insensitively by longest registered prefix
        (so ``/cmd@botname`` resolves to ``/cmd``); unmatched text goes to the
        default handler. Unless the matched command is public, the user's
        session is verified first. Replies are cached per (user, text, handler)
        except for authentication-state commands.

        Returns the list of reply parts that was delivered (an error message
        on auth failure, handler failure, a 600 s timeout, or any exception).
        """
        logger.info(
            f"[handle_message] chat_id={chat_id}, user_id={user_id}, "
            f"text={self._redact(text)[:64]}..., is_gradio={is_gradio}"
        )
        try:
            async with asyncio.timeout(600):
                self._last_response = []  # Reset last_response
                command_key = self._match_command_key(text)
                if command_key:
                    logger.debug(f"[handle_message] Matched command: {command_key}")
                command = command_key.lower() if command_key else ""
                # The gate keys off the matched command, so a public pass always
                # runs a public handler.
                if command not in self.PUBLIC_COMMANDS:
                    is_authenticated, message = await self.check_user_auth(user_id, is_gradio)
                    logger.debug(f"[handle_message] Auth check: user_id={user_id}, is_authenticated={is_authenticated}")
                    if not is_authenticated:
                        return await self._respond(chat_id, [message], is_gradio)
                handler = self.commands[command_key] if command_key else self.default_command
                logger.debug(f"[handle_message] Selected handler: {handler.__class__.__name__}")
                if command in self.UNCACHED_COMMANDS:
                    # Auth-state commands must run every time; a cached reply would
                    # claim success without changing the stored session.
                    return await self._run_handler(handler, chat_id, user_id, text, is_gradio)
                cache_key = f"{user_id}:{text}:{handler.__class__.__name__}"
                cached_parts = self._cached_parts(await self.cache.get(cache_key, type="default"))
                if cached_parts is not None:
                    logger.info(f"[handle_message] Cache hit: key={cache_key}")
                    return await self._respond(chat_id, cached_parts, is_gradio)
                return await self._run_handler(handler, chat_id, user_id, text, is_gradio, cache_key=cache_key)
        except asyncio.TimeoutError:
            logger.error(f"[handle_message] Timeout: text={self._redact(text)[:64]}...")
            return await self._respond(chat_id, ["Lỗi: Timeout khi xử lý"], is_gradio)
        except Exception as e:
            logger.error(f"[handle_message] Error: {str(e)}", exc_info=True)
            return await self._respond(chat_id, ["Lỗi hệ thống, vui lòng thử lại sau."], is_gradio)

    def _match_command_key(self, text: Optional[str]) -> Optional[str]:
        """Return the longest key of ``self.commands`` that ``text`` starts with (case-insensitive), or None."""
        lowered = (text or "").lower()
        for name in sorted(self.commands, key=len, reverse=True):
            if lowered.startswith(name.lower()):
                return name
        return None

    def _redact(self, text: Optional[str]) -> str:
        """Return ``text`` made safe for logging: arguments of sensitive commands are masked."""
        text = text or ""
        key = self._match_command_key(text)
        if key and key.lower() in self.SENSITIVE_COMMANDS and len(text) > len(key):
            return f"{text[:len(key)]} <redacted>"
        return text

    @staticmethod
    def _cached_parts(cached: Any) -> Optional[List[str]]:
        """Return the reply parts of a well-formed cache entry, or None if missing/malformed."""
        if not isinstance(cached, dict):
            return None
        parts = cached.get("parts")
        metadata = cached.get("metadata")
        if not isinstance(parts, list) or not isinstance(metadata, dict):
            return None
        if len(parts) != metadata.get("count"):
            return None
        return parts

    async def _run_handler(self, handler: Any, chat_id: int, user_id: int, text: str,
                           is_gradio: bool, cache_key: Optional[str] = None) -> List[str]:
        """Run ``handler``, format its result, cache it when ``cache_key`` is given, and deliver it."""
        result = await handler.handle(self, chat_id, user_id, text, is_gradio)
        logger.debug(f"[handle_message] Raw result type: {type(result)}, result: {str(result)[:100]}...")
        if not result:
            logger.warning(f"[handle_message] Handler failed: text={self._redact(text)[:64]}...")
            return await self._respond(chat_id, ["Lỗi xử lý lệnh."], is_gradio)
        parts = self.format_response(result, is_gradio)
        if cache_key is not None:
            cache_data = {
                "parts": parts,
                "metadata": {"count": len(parts), "timestamp": datetime.now().isoformat()}
            }
            await self.cache.set(cache_key, cache_data, type="default", ttl=Settings.CACHE_TTL or 3600)
        return await self._respond(chat_id, parts, is_gradio)

    async def _respond(self, chat_id: int, parts: List[str], is_gradio: bool) -> List[str]:
        """Record ``parts`` as the last response, deliver them, and return them."""
        self._last_response = parts
        await self.post_message(chat_id, parts, is_gradio)
        return parts

    def format_response(self, parts: List[str], is_gradio: bool = False) -> List[str]:
        """Normalize handler output into a list of non-blank message strings.

        A bare string is treated as one part and ``None`` as no parts. ``None``
        entries are dropped; other non-string parts are converted with
        ``str()`` and stripped; empty or whitespace-only parts are dropped
        (Telegram rejects empty message text). For Gradio, newlines become
        ``<br>``. Returns ``["Không có nội dung."]`` if nothing remains.
        """
        if parts is None:
            parts = []
        elif isinstance(parts, str):
            parts = [parts]
        else:
            parts = list(parts)
        logger.debug(f"[format_response] parts_count={len(parts)}, is_gradio={is_gradio}")
        formatted_parts: List[str] = []
        for part in parts:
            if part is None:
                continue
            text = part if isinstance(part, str) else str(part).strip()
            if not text.strip():
                continue
            if is_gradio:
                text = text.replace('\n', '<br>')
            formatted_parts.append(text)
        if not formatted_parts:
            formatted_parts = ["Không có nội dung."]
        logger.debug(f"[format_response] Formatted parts: {len(formatted_parts)}")
        return formatted_parts

    async def post_message(self, chat_id: int, parts: List[str], is_gradio: bool = False) -> None:
        """Deliver ``parts`` to the user, splitting any part that is too large.

        Each chunk is kept to at most (1900 for Gradio / 4096 for Telegram) - 100
        UTF-8 bytes; byte length over-estimates Telegram's character count, so
        the limit is conservative. Splits prefer line boundaries (``<br>`` for
        Gradio, ``\\n`` for Telegram); a single over-long line is split on
        character boundaries. Blank chunks are dropped.

        For Gradio the chunks are stored in ``_last_response``. For Telegram
        each chunk is POSTed to ``sendMessage`` within a 60 s overall timeout;
        failures are logged with the bot token masked and re-raised
        (``HTTPStatusError`` is re-raised with a token-free message).
        """
        logger.debug(f"[post_message] chat_id={chat_id}, parts_count={len(parts)}, is_gradio={is_gradio}")
        if not parts:
            logger.warning("[post_message] Cannot send empty parts")
            return
        max_length = 1900 if is_gradio else 4096
        limit = max_length - 100  # safety margin below the hard limit
        separator = "<br>" if is_gradio else "\n"
        final_parts: List[str] = []
        for part in parts:
            text = part if isinstance(part, str) else str(part)
            final_parts.extend(
                chunk for chunk in self._split_part(text, limit, separator) if chunk.strip()
            )
        if not final_parts:
            logger.warning("[post_message] Nothing to send after dropping blank parts")
            return
        if is_gradio:
            self._last_response = final_parts
            logger.debug(f"[post_message] Gradio responses: {len(final_parts)} parts")
            return
        url = f"https://api.telegram.org/bot{Settings.TELEGRAM_TOKEN}/sendMessage"
        try:
            async with asyncio.timeout(60):
                async with httpx.AsyncClient() as client:
                    for chunk in final_parts:
                        resp = await client.post(
                            url,
                            json={"chat_id": chat_id, "text": chunk},
                            timeout=60.0
                        )
                        resp.raise_for_status()
                        logger.debug(f"[post_message] Message sent: chat_id={chat_id}, text={chunk[:64]}...")
        except httpx.HTTPStatusError as e:
            # httpx embeds the request URL -- which contains the bot token -- in the message.
            safe_error = self._mask_token(str(e))
            logger.error(
                f"[post_message] Failed to send message: {safe_error}; "
                f"telegram_response={e.response.text[:200]}"
            )
            raise httpx.HTTPStatusError(safe_error, request=e.request, response=e.response) from None
        except Exception as e:
            logger.error(f"[post_message] Failed to send message: {self._mask_token(str(e))}")
            raise

    @classmethod
    def _split_part(cls, part: str, limit: int, separator: str) -> List[str]:
        """Split ``part`` into chunks of at most ``limit`` UTF-8 bytes, preferring ``separator`` boundaries."""
        if len(part.encode('utf-8')) <= limit:
            return [part]
        chunks: List[str] = []
        current: List[str] = []
        current_size = 0
        for line in part.split(separator):
            line_size = len((line + separator).encode('utf-8'))
            if current_size + line_size <= limit:
                current.append(line)
                current_size += line_size
                continue
            if current:
                chunks.append(separator.join(current))
            if line_size <= limit:
                current, current_size = [line], line_size
            else:
                # This line cannot fit even on its own: cut it into pieces.
                chunks.extend(cls._hard_split(line, limit))
                current, current_size = [], 0
        if current:
            chunks.append(separator.join(current))
        return chunks

    @staticmethod
    def _hard_split(line: str, limit: int) -> List[str]:
        """Cut ``line`` into pieces of at most ``limit`` UTF-8 bytes without splitting a character."""
        pieces: List[str] = []
        current: List[str] = []
        size = 0
        for ch in line:
            ch_size = len(ch.encode('utf-8'))
            if current and size + ch_size > limit:
                pieces.append("".join(current))
                current, size = [], 0
            current.append(ch)
            size += ch_size
        if current:
            pieces.append("".join(current))
        return pieces

    @staticmethod
    def _mask_token(message: str) -> str:
        """Replace the Telegram bot token in ``message`` so it can be logged safely."""
        token = Settings.TELEGRAM_TOKEN
        return message.replace(str(token), "<telegram-token>") if token else message

    def get_last_response(self) -> List[str]:
        """Return the most recently recorded reply parts (the list itself, not a copy)."""
        return self._last_response

    @staticmethod
    def _user_doc_id(user_id: int, is_gradio: bool) -> str:
        """Firestore ``users`` document id: ``gradio_<id>`` for Gradio users, ``<id>`` for Telegram."""
        return f"gradio_{user_id}" if is_gradio else str(user_id)

    async def authenticate_user(self, user_id: int, password: str, is_gradio: bool = False) -> bool:
        """Mark ``user_id`` as bot-authenticated if ``password`` matches ``Settings.BOT_PASSWORD``.

        On success merges ``authenticated``, ``bot_authenticated``, ``last_auth``
        (local naive ISO timestamp) and ``type`` into the user's Firestore
        document and records the user in ``authenticated_users``. Returns
        False for a wrong password, when no bot password is configured, or
        when the Firestore write fails.
        """
        import hmac  # constant-time comparison for the shared secret

        logger.debug(f"[authenticate_user] user_id={user_id}, is_gradio={is_gradio}")
        expected = Settings.BOT_PASSWORD
        if not expected:
            # An unset/empty BOT_PASSWORD would otherwise accept an empty password.
            logger.error("[authenticate_user] BOT_PASSWORD is not configured; refusing authentication")
            return False
        supplied = "" if password is None else str(password)
        if not hmac.compare_digest(supplied.encode('utf-8'), str(expected).encode('utf-8')):
            logger.warning(f"[authenticate_user] Invalid password for user_id={user_id}")
            return False
        try:
            await self.db.set(
                {
                    "authenticated": True,
                    "bot_authenticated": True,
                    "last_auth": datetime.now().isoformat(),
                    "type": "gradio" if is_gradio else "telegram"
                },
                data_type="users",
                doc_id=self._user_doc_id(user_id, is_gradio),
                merge=True
            )
        except Exception as e:
            logger.error(f"[authenticate_user] Firestore error: {str(e)}")
            return False
        self.authenticated_users[user_id] = True
        logger.info(f"[authenticate_user] User {user_id} authenticated with bot")
        return True

    async def is_user_authenticated(self, user_id: int, is_gradio: bool = False) -> bool:
        """Return True if ``user_id`` currently holds a valid session.

        Delegates to ``check_user_auth`` so the answer agrees with the gate in
        ``handle_message`` (session expiry, Gradio type check) instead of a
        never-expiring in-memory flag; ``authenticated_users`` is updated to
        mirror the result. Errors are logged and yield False.
        """
        try:
            is_authenticated, _ = await self.check_user_auth(user_id, is_gradio)
        except Exception as e:
            logger.error(f"[is_user_authenticated] Firestore error: {str(e)}")
            return False
        if is_authenticated:
            self.authenticated_users[user_id] = True
        else:
            self.authenticated_users.pop(user_id, None)
        logger.debug(f"[is_user_authenticated] user_id={user_id}, is_gradio={is_gradio}, authenticated={is_authenticated}")
        return is_authenticated

    async def check_user_auth(self, user_id: int, is_gradio: bool) -> Tuple[bool, str]:
        """Validate the stored session of ``user_id``.

        Returns ``(True, "")`` when the user document exists (with
        ``type == "gradio"`` for Gradio users), both ``authenticated`` and
        ``bot_authenticated`` are set, and ``last_auth`` is no older than
        ``SESSION_TIMEOUT_SECONDS``. Otherwise returns ``(False, message)`` with
        a user-facing reason. An expired (or missing/unparseable ``last_auth``)
        session is cleared in Firestore and in ``authenticated_users``.
        Firestore read errors propagate.
        """
        logger.debug(f"[check_user_auth] user_id={user_id}, is_gradio={is_gradio}")
        doc_id = self._user_doc_id(user_id, is_gradio)
        user_doc = await self.db.get(data_type="users", doc_id=doc_id)
        if not user_doc or (is_gradio and user_doc.get("type") != "gradio"):
            logger.error(f"[check_user_auth] User not found or invalid type: doc_id={doc_id}")
            return False, "Lỗi: Vui lòng đăng ký trước."
        if not user_doc.get("authenticated", False):
            logger.error(f"[check_user_auth] User not authenticated: doc_id={doc_id}")
            return False, "Lỗi: Vui lòng đăng nhập trước."
        if not user_doc.get("bot_authenticated", False):
            logger.error(f"[check_user_auth] User not bot-authenticated: doc_id={doc_id}")
            return False, "Lỗi: Vui lòng xác thực bằng /auth <mật_khẩu_bot>."
        last_auth = self._parse_auth_time(user_doc.get("last_auth"))
        if (datetime.now() - last_auth).total_seconds() > self.SESSION_TIMEOUT_SECONDS:
            await self.db.set(
                {"authenticated": False, "bot_authenticated": False},
                data_type="users",
                doc_id=doc_id,
                merge=True
            )
            self.authenticated_users.pop(user_id, None)
            logger.error(f"[check_user_auth] Session expired: doc_id={doc_id}")
            return False, "Lỗi: Phiên đăng nhập đã hết hạn, vui lòng xác thực lại bằng /auth."
        logger.debug(f"[check_user_auth] User authenticated: doc_id={doc_id}")
        return True, ""

    @staticmethod
    def _parse_auth_time(value: Any) -> datetime:
        """Return ``value`` as a naive local datetime; the epoch (i.e. expired) if missing or unparseable."""
        epoch = datetime(1970, 1, 1)
        if isinstance(value, datetime):
            parsed = value
        elif isinstance(value, str):
            try:
                parsed = datetime.fromisoformat(value)
            except ValueError:
                return epoch
        else:
            return epoch
        if parsed.tzinfo is not None:
            # authenticate_user stores naive local time; align aware values with it.
            parsed = parsed.astimezone().replace(tzinfo=None)
        return parsed