# cotienbot / src/bot.py
# Last commit 2652625 (verified) by Anothervin1: "Update src/bot.py"
import hashlib
import httpx
import asyncio
from typing import Dict, Optional, Any, List, Union, Tuple
from tenacity import retry, stop_after_attempt, wait_fixed
from src.search import Search
from src.firestore_db import FirestoreDB
from src.gemini import Gemini
from src.cache import Cache
from src.firestore_lottery import FirestoreLottery
from src.lottery_core import LotteryCore
from src.lottery_analysis import LotteryAnalysis
from src.webhook import WebhookUpdate
from src.commands import (
StartCommand, HelpCommand, GeminiCommand, ExitCommand,
TrainCommand, TrainJsonCommand, LotteryCommand, LoadLotteryCommand,
SaveLotteryVectorsCommand, AddLotteryCommand, MigrateLotteryDataCommand,
CheckDatesCommand, ScheduleCommand, LotteryPositionCommand, DefaultCommand,
RegisterCommand, AuthCommand, LogoutCommand
)
from src.logger import logger
from config.settings import Settings
from datetime import datetime
import json
class Bot:
    """Telegram/Gradio chat bot that routes text messages to command handlers.

    Incoming text is dispatched to the handler whose command name is the
    longest case-insensitive prefix of the text, falling back to
    ``DefaultCommand``. Every command outside ``PUBLIC_COMMANDS`` requires a
    valid session as checked by :meth:`check_user_auth`. Handler output is
    cached per user / platform / text / handler, except for the commands in
    ``NO_CACHE_COMMANDS``, which must run every time.
    """

    # List of advanced commands that require authentication. Kept for callers
    # that read it; handle_message already enforces authentication on every
    # non-public command, which is a superset of this list.
    ADVANCED_COMMANDS = [
        "/load_lottery", "/lottery_position", "/train", "/train_json",
        "/save_lottery_vectors", "/add_lottery", "/migrate_lottery_data",
        "/check_dates", "/schedule", "/lottery"
    ]

    # Commands that may be run without an authenticated session.
    PUBLIC_COMMANDS = frozenset({"/start", "/help", "/auth", "/register", "/logout"})

    # Commands that are never answered from the cache: they change session
    # state (a cached reply would skip the actual login/logout) or carry
    # credentials that must not end up inside a cache key.
    NO_CACHE_COMMANDS = frozenset({"/register", "/auth", "/logout"})

    # Commands whose arguments are masked before logging. /auth takes the bot
    # password; /register presumably carries credentials too (TODO confirm).
    SENSITIVE_COMMANDS = frozenset({"/auth", "/register"})

    # A session expires this many seconds after the user's `last_auth`.
    SESSION_TIMEOUT_SECONDS = 3600

    # Message size limits. post_message measures each part in UTF-8 bytes
    # against (limit - LENGTH_MARGIN).
    TELEGRAM_MAX_LENGTH = 4096
    GRADIO_MAX_LENGTH = 1900
    LENGTH_MARGIN = 100

    def __init__(self):
        logger.debug("[Bot] Initializing Bot")
        self.db = FirestoreDB(credentials=Settings.FIRESTORE_CREDENTIALS)
        self.cache = Cache()
        self.search = Search(db=self.db)
        self.gemini = Gemini(db=self.db, cache=self.cache)
        self.firestore_lottery = FirestoreLottery(credentials=Settings.FIRESTORE_CREDENTIALS)
        self.lottery_core = LotteryCore(
            db=self.db,
            cache=self.cache
        )
        self.lottery_analysis = LotteryAnalysis(
            db=self.db,
            cache=self.cache
        )
        # Most recent webhook update handed to handle_message.
        self.message = None
        self.gemini_mode: Dict[int, bool] = {}
        # In-memory cache of users known to be authenticated, keyed by user_id.
        self.authenticated_users: Dict[int, bool] = {}
        self.commands = {
            "/start": StartCommand(),
            "/help": HelpCommand(),
            "/auth": AuthCommand(),
            "/Gemini": GeminiCommand(),
            "/exit": ExitCommand(),
            "/train": TrainCommand(),
            "/train_json": TrainJsonCommand(),
            "/lottery": LotteryCommand(),
            "/load_lottery": LoadLotteryCommand(),
            "/save_lottery_vectors": SaveLotteryVectorsCommand(),
            "/add_lottery": AddLotteryCommand(),
            "/migrate_lottery_data": MigrateLotteryDataCommand(),
            "/check_dates": CheckDatesCommand(),
            "/schedule": ScheduleCommand(),
            "/lottery_position": LotteryPositionCommand(),
            "/register": RegisterCommand(),
            "/logout": LogoutCommand()
        }
        self.default_command = DefaultCommand()
        self._last_response: List[str] = []

    async def initialize(self):
        """Initialize the bot's components and reset every user's auth flags.

        Raises:
            Exception: any error from search, lottery core or Firestore is
                logged and re-raised.
        """
        logger.debug("[Bot] Starting initialization")
        try:
            await self.search.initialize(data_type="default")
            await self.search.initialize(data_type="lottery")
            await self.lottery_core.initialize()
            # Reset the authentication state of all users so everyone has to
            # authenticate again after a (re)start.
            self.authenticated_users.clear()
            users = await self.db.get_all(data_type="users")
            for user in users or []:
                doc_id = user.get("id")
                if not doc_id:
                    # Skip records without an id rather than calling set()
                    # with doc_id=None.
                    logger.warning("[initialize] Skipping user record without id")
                    continue
                await self.db.set(
                    {"authenticated": False, "bot_authenticated": False},
                    data_type="users",
                    doc_id=doc_id,
                    merge=True
                )
                logger.info(f"[initialize] Reset authentication for doc_id={doc_id}")
            logger.info("[Bot] Initialization completed")
        except Exception as e:
            logger.error(f"[Bot] Initialization error: {str(e)}", exc_info=True)
            raise

    async def handle_update(self, update: WebhookUpdate) -> bool:
        """Process one Telegram webhook update.

        Returns:
            True if the update carried a usable text message and was passed to
            handle_message; False if it was skipped, timed out or raised.
        """
        logger.info(f"[handle_update] Received update: {update.update_id}")
        try:
            # NOTE(review): this 120 s budget is shorter than the 600 s one
            # inside handle_message, so on this path it is the effective limit.
            async with asyncio.timeout(120):
                if not update.message:
                    logger.warning("[handle_update] No message in update")
                    return False
                chat_id = update.message.get("chat", {}).get("id")
                user_id = update.message.get("from", {}).get("id")
                text = update.message.get("text", "")
                if not text or not chat_id or not user_id:
                    logger.warning(f"[handle_update] Invalid message: chat_id={chat_id}, user_id={user_id}, text={self._redact(text)[:50]}...")
                    return False
                self.message = update
                await self.handle_message(chat_id, user_id, text, is_gradio=False)
                return True
        except asyncio.TimeoutError:
            logger.error("[handle_update] Timeout")
            return False
        except Exception as e:
            logger.error(f"[handle_update] Error: {str(e)}", exc_info=True)
            return False

    async def handle_message(self, chat_id: int, user_id: int, text: str, is_gradio: bool = False) -> List[str]:
        """Authenticate, dispatch and answer one text message.

        Args:
            chat_id: chat to reply to.
            user_id: sender; used for the auth check and the cache key.
            text: raw message text (a command such as "/lottery ..." or free text).
            is_gradio: True when called from the Gradio UI instead of Telegram.

        Returns:
            The response parts that were delivered (also stored in
            ``_last_response``). Timeouts and errors are reported to the user
            and returned as a one-element list instead of being raised.
        """
        safe_text = self._redact(text)
        logger.info(f"[handle_message] chat_id={chat_id}, user_id={user_id}, text={safe_text[:64]}..., is_gradio={is_gradio}")
        try:
            async with asyncio.timeout(600):
                self._last_response = []  # Reset last_response
                command = self._command_name(text)

                # Every command outside PUBLIC_COMMANDS (including free text
                # handled by DefaultCommand) requires a valid session. This
                # covers ADVANCED_COMMANDS, so cached replies need no re-check.
                if command not in self.PUBLIC_COMMANDS:
                    is_authenticated, message = await self.check_user_auth(user_id, is_gradio)
                    logger.debug(f"[handle_message] Auth check: user_id={user_id}, is_authenticated={is_authenticated}")
                    if not is_authenticated:
                        return await self._respond(chat_id, [message], is_gradio)

                handler = self._resolve_handler(text)
                logger.debug(f"[handle_message] Selected handler: {handler.__class__.__name__}")

                use_cache = command not in self.NO_CACHE_COMMANDS
                # The platform is part of the key because Gradio output is
                # formatted differently (<br> instead of newlines).
                # NOTE(review): gemini_mode is not part of the key; if
                # DefaultCommand answers differently per mode, cached replies
                # can cross modes — TODO confirm.
                platform = "gradio" if is_gradio else "telegram"
                cache_key = f"{user_id}:{platform}:{text}:{handler.__class__.__name__}"
                if use_cache:
                    cached_parts = await self._get_cached_parts(cache_key)
                    if cached_parts is not None:
                        logger.info(f"[handle_message] Cache hit: key={cache_key}")
                        return await self._respond(chat_id, cached_parts, is_gradio)

                # Process command
                result = await handler.handle(self, chat_id, user_id, text, is_gradio)
                logger.debug(f"[handle_message] Raw result type: {type(result)}, result: {str(result)[:100]}...")
                if not result:
                    logger.warning(f"[handle_message] Handler failed: text={safe_text[:64]}...")
                    return await self._respond(chat_id, ["Lỗi xử lý lệnh."], is_gradio)

                # Format and (unless excluded) cache the response
                result = self.format_response(self._as_parts(result), is_gradio)
                if use_cache:
                    cache_data = {
                        "parts": result,
                        "metadata": {"count": len(result), "timestamp": datetime.now().isoformat()}
                    }
                    await self.cache.set(cache_key, cache_data, type="default", ttl=Settings.CACHE_TTL or 3600)
                return await self._respond(chat_id, result, is_gradio)
        except asyncio.TimeoutError:
            logger.error(f"[handle_message] Timeout: text={safe_text[:64]}...")
            return await self._respond(chat_id, ["Lỗi: Timeout khi xử lý"], is_gradio, best_effort=True)
        except Exception as e:
            logger.error(f"[handle_message] Error: {str(e)}", exc_info=True)
            return await self._respond(chat_id, ["Lỗi hệ thống, vui lòng thử lại sau."], is_gradio, best_effort=True)

    @staticmethod
    def _command_name(text: str) -> str:
        """Return the lower-cased first space-separated token of text.

        For slash commands a Telegram ``@botname`` suffix is dropped, so
        ``/auth@MyBot`` is treated as ``/auth``.
        """
        if not text:
            return ""
        token = text.split(" ")[0].lower()
        if token.startswith("/"):
            token = token.split("@", 1)[0]
        return token

    @classmethod
    def _redact(cls, text: Optional[str]) -> str:
        """Return text safe for logging: arguments of SENSITIVE_COMMANDS are masked."""
        if not text:
            return ""
        tokens = text.split(None, 1)
        if not tokens:
            return text
        head = tokens[0]
        if head.lower().split("@", 1)[0] in cls.SENSITIVE_COMMANDS:
            return f"{head} ***" if len(tokens) > 1 else head
        return text

    def _resolve_handler(self, text: str) -> Any:
        """Return the handler whose command is the longest case-insensitive prefix of text."""
        lowered = text.lower()
        # Longest names first so "/lottery_position" wins over "/lottery".
        for name in sorted(self.commands, key=len, reverse=True):
            if lowered.startswith(name.lower()):
                logger.debug(f"[handle_message] Matched command: {name}")
                return self.commands[name]
        return self.default_command

    async def _get_cached_parts(self, cache_key: str) -> Optional[List[str]]:
        """Return cached response parts for cache_key, or None if absent or malformed."""
        cached = await self.cache.get(cache_key, type="default")
        if not isinstance(cached, dict):
            return None
        parts = cached.get("parts")
        metadata = cached.get("metadata")
        if not isinstance(parts, list) or not isinstance(metadata, dict):
            return None
        # A count mismatch means the entry is inconsistent; ignore it.
        if len(parts) != metadata.get("count"):
            return None
        return parts

    @staticmethod
    def _as_parts(result: Any) -> List[Any]:
        """Wrap a handler result into a list of parts (str or scalar -> one part)."""
        if isinstance(result, str):
            return [result]
        if isinstance(result, (list, tuple)):
            return list(result)
        return [result]

    async def _respond(self, chat_id: int, parts: List[str], is_gradio: bool, *, best_effort: bool = False) -> List[str]:
        """Record parts as the last response, deliver them, and return them.

        With best_effort=True a delivery failure is logged instead of raised,
        so an error reply cannot itself turn into a second failure.
        """
        self._last_response = parts
        try:
            await self.post_message(chat_id, parts, is_gradio)
        except Exception as e:
            if not best_effort:
                raise
            logger.error(f"[handle_message] Could not deliver reply: {type(e).__name__}")
        return parts

    def format_response(self, parts: List[str], is_gradio: bool = False) -> List[str]:
        """Normalize handler output into a non-empty list of message strings.

        Non-string parts are converted with str() and stripped. None parts
        and parts that are empty or whitespace-only are dropped, because
        Telegram's sendMessage rejects empty text. For Gradio, newlines
        become ``<br>``. If nothing is left, a single placeholder is returned.
        """
        logger.debug(f"[format_response] parts_count={len(parts) if parts else 0}, is_gradio={is_gradio}")
        formatted_parts = []
        for part in parts or []:
            if part is None:
                continue
            text = part if isinstance(part, str) else str(part).strip()
            if not text.strip():
                continue
            if is_gradio:
                text = text.replace('\n', '<br>')
            formatted_parts.append(text)
        if not formatted_parts:
            formatted_parts = ["Không có nội dung."]
        logger.debug(f"[format_response] Formatted parts: {len(formatted_parts)}")
        return formatted_parts

    async def post_message(self, chat_id: int, parts: List[str], is_gradio: bool = False) -> None:
        """Deliver parts to the user, splitting any that exceed the size limit.

        For Gradio nothing is sent: the split parts are stored in
        ``_last_response`` (returned by get_last_response). For Telegram each
        part is sent via the Bot API ``sendMessage`` endpoint. Each part is
        retried on its own, so a failure part-way through does not resend
        parts that were already delivered.

        Raises:
            tenacity.RetryError: a part still failed after its retries; the
                error is logged with the bot token masked and re-raised.
        """
        logger.debug(f"[post_message] chat_id={chat_id}, parts_count={len(parts) if parts else 0}, is_gradio={is_gradio}")
        if not parts:
            logger.warning("[post_message] Cannot send empty parts")
            return
        final_parts = self._split_parts(parts, is_gradio)
        if is_gradio:
            self._last_response = final_parts
            logger.debug(f"[post_message] Gradio responses: {len(final_parts)} parts")
            return
        try:
            async with httpx.AsyncClient() as client:
                for part in final_parts:
                    await self._send_telegram_part(client, chat_id, part)
                    logger.debug(f"[post_message] Message sent: chat_id={chat_id}, text={part[:64]}...")
        except Exception as e:
            # httpx error strings include the request URL, which embeds the token.
            logger.error(f"[post_message] Failed to send message: {self._mask_token(str(e))}")
            raise

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    async def _send_telegram_part(self, client: httpx.AsyncClient, chat_id: int, text: str) -> None:
        """Send one message via the Telegram Bot API (up to 3 attempts, 2 s apart)."""
        async with asyncio.timeout(60):
            resp = await client.post(
                f"https://api.telegram.org/bot{Settings.TELEGRAM_TOKEN}/sendMessage",
                json={"chat_id": chat_id, "text": text},
                timeout=60.0
            )
            resp.raise_for_status()

    @staticmethod
    def _mask_token(message: str) -> str:
        """Replace the Telegram bot token in message with '***'."""
        token = Settings.TELEGRAM_TOKEN
        return message.replace(str(token), "***") if token else message

    @classmethod
    def _split_parts(cls, parts: List[str], is_gradio: bool) -> List[str]:
        """Split parts so each fits in the platform's budget, measured in UTF-8 bytes.

        Oversized parts are split on line separators (``<br>`` for Gradio,
        ``\\n`` otherwise); a single line that is itself too long is cut into
        pieces so no part ever exceeds the budget.
        """
        separator = "<br>" if is_gradio else "\n"
        separator_size = len(separator.encode('utf-8'))
        limit = cls.GRADIO_MAX_LENGTH if is_gradio else cls.TELEGRAM_MAX_LENGTH
        budget = limit - cls.LENGTH_MARGIN
        final_parts: List[str] = []
        for part in parts:
            if not isinstance(part, str):
                part = str(part)
            if len(part.encode('utf-8')) <= budget:
                final_parts.append(part)
                continue
            chunk: List[str] = []
            chunk_size = 0
            for line in part.split(separator):
                line_size = len(line.encode('utf-8')) + separator_size
                if chunk_size + line_size <= budget:
                    chunk.append(line)
                    chunk_size += line_size
                    continue
                if chunk:
                    final_parts.append(separator.join(chunk))
                if line_size > budget:
                    # One line alone is over budget: hard-cut it.
                    final_parts.extend(cls._cut_to_budget(line, budget))
                    chunk, chunk_size = [], 0
                else:
                    chunk, chunk_size = [line], line_size
            if chunk:
                final_parts.append(separator.join(chunk))
        return final_parts

    @staticmethod
    def _cut_to_budget(text: str, budget: int) -> List[str]:
        """Cut text into consecutive pieces of at most budget UTF-8 bytes, never splitting a code point."""
        pieces: List[str] = []
        current: List[str] = []
        size = 0
        for char in text:
            char_size = len(char.encode('utf-8'))
            if current and size + char_size > budget:
                pieces.append("".join(current))
                current, size = [], 0
            current.append(char)
            size += char_size
        if current:
            pieces.append("".join(current))
        return pieces

    def get_last_response(self) -> List[str]:
        """Return the parts most recently produced or delivered."""
        return self._last_response

    @staticmethod
    def _user_doc_id(user_id: int, is_gradio: bool) -> str:
        """Return the Firestore 'users' document id for a Telegram or Gradio user."""
        return f"gradio_{user_id}" if is_gradio else str(user_id)

    async def authenticate_user(self, user_id: int, password: str, is_gradio: bool = False) -> bool:
        """Mark the user as bot-authenticated if password matches Settings.BOT_PASSWORD.

        The comparison is constant-time. Authentication is refused outright
        when no bot password is configured, so an empty password can never
        match an empty setting.

        Returns:
            True when the password matched and the Firestore update succeeded.
        """
        import hmac

        logger.debug(f"[authenticate_user] user_id={user_id}, is_gradio={is_gradio}")
        expected = Settings.BOT_PASSWORD
        if not expected:
            logger.error("[authenticate_user] BOT_PASSWORD is not configured; refusing authentication")
            return False
        if not isinstance(password, str) or not hmac.compare_digest(
            password.encode('utf-8'), str(expected).encode('utf-8')
        ):
            logger.warning(f"[authenticate_user] Invalid password for user_id={user_id}")
            return False
        try:
            await self.db.set(
                {
                    "authenticated": True,
                    "bot_authenticated": True,
                    "last_auth": datetime.now().isoformat(),
                    "type": "gradio" if is_gradio else "telegram"
                },
                data_type="users",
                doc_id=self._user_doc_id(user_id, is_gradio),
                merge=True
            )
            self.authenticated_users[user_id] = True
            logger.info(f"[authenticate_user] User {user_id} authenticated with bot")
            return True
        except Exception as e:
            logger.error(f"[authenticate_user] Firestore error: {str(e)}")
            return False

    async def is_user_authenticated(self, user_id: int, is_gradio: bool = False) -> bool:
        """Return True if the user has both auth flags set in Firestore.

        A positive answer is cached in ``authenticated_users``. This method
        does not check session expiry; check_user_auth does. Firestore errors
        are logged and reported as False.
        """
        if self.authenticated_users.get(user_id, False):
            return True
        try:
            doc = await self.db.get(data_type="users", doc_id=self._user_doc_id(user_id, is_gradio))
            is_authenticated = bool(doc and doc.get("authenticated", False) and doc.get("bot_authenticated", False))
            if is_authenticated:
                self.authenticated_users[user_id] = True
            logger.debug(f"[is_user_authenticated] user_id={user_id}, is_gradio={is_gradio}, authenticated={is_authenticated}")
            return is_authenticated
        except Exception as e:
            logger.error(f"[is_user_authenticated] Firestore error: {str(e)}")
            return False

    async def check_user_auth(self, user_id: int, is_gradio: bool) -> Tuple[bool, str]:
        """Check that the user is registered, logged in, bot-authenticated and not expired.

        When the check fails, the user's in-memory ``authenticated_users``
        entry is dropped so is_user_authenticated cannot keep returning a
        stale True.

        Returns:
            (True, "") on success, otherwise (False, user-facing error message).
        """
        logger.debug(f"[check_user_auth] user_id={user_id}, is_gradio={is_gradio}")
        ok, message = await self._evaluate_session(user_id, is_gradio)
        if not ok:
            self.authenticated_users.pop(user_id, None)
        return ok, message

    async def _evaluate_session(self, user_id: int, is_gradio: bool) -> Tuple[bool, str]:
        """Run the Firestore-backed session checks behind check_user_auth.

        A session expires SESSION_TIMEOUT_SECONDS after ``last_auth``; a
        missing or unparseable ``last_auth`` counts as expired. On expiry both
        auth flags are cleared in Firestore.
        """
        doc_id = self._user_doc_id(user_id, is_gradio)
        user_doc = await self.db.get(data_type="users", doc_id=doc_id)
        if not user_doc or (is_gradio and user_doc.get("type") != "gradio"):
            logger.error(f"[check_user_auth] User not found or invalid type: doc_id={doc_id}")
            return False, "Lỗi: Vui lòng đăng ký trước."
        if not user_doc.get("authenticated", False):
            logger.error(f"[check_user_auth] User not authenticated: doc_id={doc_id}")
            return False, "Lỗi: Vui lòng đăng nhập trước."
        if not user_doc.get("bot_authenticated", False):
            logger.error(f"[check_user_auth] User not bot-authenticated: doc_id={doc_id}")
            return False, "Lỗi: Vui lòng xác thực bằng /auth <mật_khẩu_bot>."
        last_auth = self._parse_timestamp(user_doc.get("last_auth"))
        if last_auth is None or self._seconds_since(last_auth) > self.SESSION_TIMEOUT_SECONDS:
            await self.db.set(
                {"authenticated": False, "bot_authenticated": False},
                data_type="users",
                doc_id=doc_id,
                merge=True
            )
            logger.error(f"[check_user_auth] Session expired: doc_id={doc_id}")
            return False, "Lỗi: Phiên đăng nhập đã hết hạn, vui lòng xác thực lại bằng /auth."
        logger.debug(f"[check_user_auth] User authenticated: doc_id={doc_id}")
        return True, ""

    @staticmethod
    def _parse_timestamp(value: Any) -> Optional[datetime]:
        """Return value as a datetime (accepting datetime or ISO-8601 str), or None if unusable."""
        if isinstance(value, datetime):
            return value
        if not isinstance(value, str):
            return None
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return None

    @staticmethod
    def _seconds_since(moment: datetime) -> float:
        """Seconds elapsed since moment, using an aware 'now' when moment is timezone-aware."""
        now = datetime.now(moment.tzinfo) if moment.tzinfo is not None else datetime.now()
        return (now - moment).total_seconds()