| """Token manager for Flow2API with AT auto-refresh""" |
| import asyncio |
| from datetime import datetime, timedelta, timezone |
| from typing import Optional, List |
| from ..core.database import Database |
| from ..core.models import Token, Project |
| from ..core.logger import debug_logger |
| from .flow_client import FlowClient |
| from .proxy_manager import ProxyManager |
|
|
|
|
| class TokenManager: |
| """Token lifecycle manager with AT auto-refresh""" |
|
|
| def __init__(self, db: Database, flow_client: FlowClient): |
| self.db = db |
| self.flow_client = flow_client |
| self._lock = asyncio.Lock() |
|
|
| |
|
|
| async def get_all_tokens(self) -> List[Token]: |
| """Get all tokens""" |
| return await self.db.get_all_tokens() |
|
|
| async def get_active_tokens(self) -> List[Token]: |
| """Get all active tokens""" |
| return await self.db.get_active_tokens() |
|
|
| async def get_token(self, token_id: int) -> Optional[Token]: |
| """Get token by ID""" |
| return await self.db.get_token(token_id) |
|
|
| async def delete_token(self, token_id: int): |
| """Delete token""" |
| await self.db.delete_token(token_id) |
|
|
| async def enable_token(self, token_id: int): |
| """Enable a token and reset error count""" |
| |
| await self.db.update_token(token_id, is_active=True) |
| |
| await self.db.reset_error_count(token_id) |
|
|
| async def disable_token(self, token_id: int): |
| """Disable a token""" |
| await self.db.update_token(token_id, is_active=False) |
|
|
| |
|
|
| async def add_token( |
| self, |
| st: str, |
| project_id: Optional[str] = None, |
| project_name: Optional[str] = None, |
| remark: Optional[str] = None, |
| image_enabled: bool = True, |
| video_enabled: bool = True, |
| image_concurrency: int = -1, |
| video_concurrency: int = -1 |
| ) -> Token: |
| """Add a new token |
| |
| Args: |
| st: Session Token (必需) |
| project_id: 项目ID (可选,如果提供则直接使用,不创建新项目) |
| project_name: 项目名称 (可选,如果不提供则自动生成) |
| remark: 备注 |
| image_enabled: 是否启用图片生成 |
| video_enabled: 是否启用视频生成 |
| image_concurrency: 图片并发限制 |
| video_concurrency: 视频并发限制 |
| |
| Returns: |
| Token object |
| """ |
| |
| existing_token = await self.db.get_token_by_st(st) |
| if existing_token: |
| raise ValueError(f"Token 已存在(邮箱: {existing_token.email})") |
|
|
| |
| debug_logger.log_info(f"[ADD_TOKEN] Converting ST to AT...") |
| try: |
| result = await self.flow_client.st_to_at(st) |
| at = result["access_token"] |
| expires = result.get("expires") |
| user_info = result.get("user", {}) |
| email = user_info.get("email", "") |
| name = user_info.get("name", email.split("@")[0] if email else "") |
|
|
| |
| at_expires = None |
| if expires: |
| try: |
| at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) |
| except: |
| pass |
|
|
| except Exception as e: |
| raise ValueError(f"ST转AT失败: {str(e)}") |
|
|
| |
| try: |
| credits_result = await self.flow_client.get_credits(at) |
| credits = credits_result.get("credits", 0) |
| user_paygate_tier = credits_result.get("userPaygateTier") |
| except: |
| credits = 0 |
| user_paygate_tier = None |
|
|
| |
| if project_id: |
| |
| debug_logger.log_info(f"[ADD_TOKEN] Using provided project_id: {project_id}") |
| if not project_name: |
| |
| now = datetime.now() |
| project_name = now.strftime("%b %d - %H:%M") |
| else: |
| |
| if not project_name: |
| |
| now = datetime.now() |
| project_name = now.strftime("%b %d - %H:%M") |
|
|
| try: |
| project_id = await self.flow_client.create_project(st, project_name) |
| debug_logger.log_info(f"[ADD_TOKEN] Created new project: {project_name} (ID: {project_id})") |
| except Exception as e: |
| raise ValueError(f"创建项目失败: {str(e)}") |
|
|
| |
| token = Token( |
| st=st, |
| at=at, |
| at_expires=at_expires, |
| email=email, |
| name=name, |
| remark=remark, |
| is_active=True, |
| credits=credits, |
| user_paygate_tier=user_paygate_tier, |
| current_project_id=project_id, |
| current_project_name=project_name, |
| image_enabled=image_enabled, |
| video_enabled=video_enabled, |
| image_concurrency=image_concurrency, |
| video_concurrency=video_concurrency |
| ) |
|
|
| |
| token_id = await self.db.add_token(token) |
| token.id = token_id |
|
|
| |
| project = Project( |
| project_id=project_id, |
| token_id=token_id, |
| project_name=project_name, |
| tool_name="PINHOLE" |
| ) |
| await self.db.add_project(project) |
|
|
| debug_logger.log_info(f"[ADD_TOKEN] Token added successfully (ID: {token_id}, Email: {email})") |
| return token |
|
|
| async def update_token( |
| self, |
| token_id: int, |
| st: Optional[str] = None, |
| at: Optional[str] = None, |
| at_expires: Optional[datetime] = None, |
| project_id: Optional[str] = None, |
| project_name: Optional[str] = None, |
| remark: Optional[str] = None, |
| image_enabled: Optional[bool] = None, |
| video_enabled: Optional[bool] = None, |
| image_concurrency: Optional[int] = None, |
| video_concurrency: Optional[int] = None |
| ): |
| """Update token (支持修改project_id和project_name) |
| |
| 当用户编辑保存token时,如果token未过期,自动清空429禁用状态 |
| """ |
| update_fields = {} |
|
|
| if st is not None: |
| update_fields["st"] = st |
| if at is not None: |
| update_fields["at"] = at |
| if at_expires is not None: |
| update_fields["at_expires"] = at_expires |
| if project_id is not None: |
| update_fields["current_project_id"] = project_id |
| if project_name is not None: |
| update_fields["current_project_name"] = project_name |
| if remark is not None: |
| update_fields["remark"] = remark |
| if image_enabled is not None: |
| update_fields["image_enabled"] = image_enabled |
| if video_enabled is not None: |
| update_fields["video_enabled"] = video_enabled |
| if image_concurrency is not None: |
| update_fields["image_concurrency"] = image_concurrency |
| if video_concurrency is not None: |
| update_fields["video_concurrency"] = video_concurrency |
|
|
| |
| token = await self.db.get_token(token_id) |
| if token and token.ban_reason == "429_rate_limit": |
| |
| is_expired = False |
| if token.at_expires: |
| now = datetime.now(timezone.utc) |
| if token.at_expires.tzinfo is None: |
| at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) |
| else: |
| at_expires_aware = token.at_expires |
| is_expired = at_expires_aware <= now |
|
|
| |
| if not is_expired: |
| debug_logger.log_info(f"[UPDATE_TOKEN] Token {token_id} 编辑保存,清空429禁用状态") |
| update_fields["ban_reason"] = None |
| update_fields["banned_at"] = None |
|
|
| if update_fields: |
| await self.db.update_token(token_id, **update_fields) |
|
|
| |
|
|
| async def is_at_valid(self, token_id: int) -> bool: |
| """检查AT是否有效,如果无效或即将过期则自动刷新 |
| |
| Returns: |
| True if AT is valid or refreshed successfully |
| False if AT cannot be refreshed |
| """ |
| token = await self.db.get_token(token_id) |
| if not token: |
| return False |
|
|
| |
| if not token.at: |
| debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT不存在,需要刷新") |
| return await self._refresh_at(token_id) |
|
|
| |
| if not token.at_expires: |
| debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT过期时间未知,尝试刷新") |
| return await self._refresh_at(token_id) |
|
|
| |
| now = datetime.now(timezone.utc) |
| |
| if token.at_expires.tzinfo is None: |
| at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) |
| else: |
| at_expires_aware = token.at_expires |
|
|
| time_until_expiry = at_expires_aware - now |
|
|
| if time_until_expiry.total_seconds() < 3600: |
| debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT即将过期 (剩余 {time_until_expiry.total_seconds():.0f} 秒),需要刷新") |
| return await self._refresh_at(token_id) |
|
|
| |
| return True |
|
|
|
|
| async def _refresh_at(self, token_id: int) -> bool: |
| """内部方法: 刷新AT |
| |
| 如果 AT 刷新失败(ST 可能过期),会尝试通过浏览器自动刷新 ST, |
| 然后重试 AT 刷新。 |
| |
| Returns: |
| True if refresh successful, False otherwise |
| """ |
| async with self._lock: |
| token = await self.db.get_token(token_id) |
| if not token: |
| return False |
|
|
| |
| result = await self._do_refresh_at(token_id, token.st) |
| if result: |
| return True |
|
|
| |
| debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: 第一次 AT 刷新失败,尝试自动更新 ST...") |
| |
| new_st = await self._try_refresh_st(token_id, token) |
| if new_st: |
| |
| debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: ST 已更新,重试 AT 刷新...") |
| result = await self._do_refresh_at(token_id, new_st) |
| if result: |
| return True |
|
|
| |
| debug_logger.log_error(f"[AT_REFRESH] Token {token_id}: 所有刷新尝试失败,禁用 Token") |
| await self.disable_token(token_id) |
| return False |
|
|
| async def _do_refresh_at(self, token_id: int, st: str) -> bool: |
| """执行 AT 刷新的核心逻辑 |
| |
| Args: |
| token_id: Token ID |
| st: Session Token |
| |
| Returns: |
| True if refresh successful AND AT is valid, False otherwise |
| """ |
| try: |
| debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: 开始刷新AT...") |
|
|
| |
| result = await self.flow_client.st_to_at(st) |
| new_at = result["access_token"] |
| expires = result.get("expires") |
|
|
| |
| new_at_expires = None |
| if expires: |
| try: |
| new_at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) |
| except: |
| pass |
|
|
| |
| await self.db.update_token( |
| token_id, |
| at=new_at, |
| at_expires=new_at_expires |
| ) |
|
|
| debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: AT刷新成功") |
| debug_logger.log_info(f" - 新过期时间: {new_at_expires}") |
|
|
| |
| try: |
| credits_result = await self.flow_client.get_credits(new_at) |
| await self.db.update_token( |
| token_id, |
| credits=credits_result.get("credits", 0) |
| ) |
| debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: AT 验证成功(余额: {credits_result.get('credits', 0)})") |
| return True |
| except Exception as verify_err: |
| |
| error_msg = str(verify_err) |
| if "401" in error_msg or "UNAUTHENTICATED" in error_msg: |
| debug_logger.log_warning(f"[AT_REFRESH] Token {token_id}: AT 验证失败 (401),ST 可能已过期") |
| return False |
| else: |
| |
| debug_logger.log_warning(f"[AT_REFRESH] Token {token_id}: AT 验证时发生非认证错误: {error_msg}") |
| return True |
|
|
| except Exception as e: |
| debug_logger.log_error(f"[AT_REFRESH] Token {token_id}: AT刷新失败 - {str(e)}") |
| return False |
|
|
| async def _try_refresh_st(self, token_id: int, token) -> Optional[str]: |
| """尝试通过浏览器刷新 Session Token |
| |
| 使用常驻 tab 获取新的 __Secure-next-auth.session-token |
| |
| Args: |
| token_id: Token ID |
| token: Token 对象 |
| |
| Returns: |
| 新的 ST 字符串,如果失败返回 None |
| """ |
| try: |
| from ..core.config import config |
|
|
| |
| if config.captcha_method != "personal": |
| debug_logger.log_info(f"[ST_REFRESH] 非 personal 模式,跳过 ST 自动刷新") |
| return None |
|
|
| if not token.current_project_id: |
| debug_logger.log_warning(f"[ST_REFRESH] Token {token_id} 没有 project_id,无法刷新 ST") |
| return None |
|
|
| debug_logger.log_info(f"[ST_REFRESH] Token {token_id}: 尝试通过浏览器刷新 ST...") |
|
|
| from .browser_captcha_personal import BrowserCaptchaService |
| service = await BrowserCaptchaService.get_instance(self.db) |
|
|
| new_st = await service.refresh_session_token(token.current_project_id) |
| if new_st and new_st != token.st: |
| |
| await self.db.update_token(token_id, st=new_st) |
| debug_logger.log_info(f"[ST_REFRESH] Token {token_id}: ST 已自动更新") |
| return new_st |
| elif new_st == token.st: |
| debug_logger.log_warning(f"[ST_REFRESH] Token {token_id}: 获取到的 ST 与原 ST 相同,可能登录已失效") |
| return None |
| else: |
| debug_logger.log_warning(f"[ST_REFRESH] Token {token_id}: 无法获取新 ST") |
| return None |
|
|
| except Exception as e: |
| debug_logger.log_error(f"[ST_REFRESH] Token {token_id}: 刷新 ST 失败 - {str(e)}") |
| return None |
|
|
| async def ensure_project_exists(self, token_id: int) -> str: |
| """确保Token有可用的Project |
| |
| Returns: |
| project_id |
| """ |
| token = await self.db.get_token(token_id) |
| if not token: |
| raise ValueError("Token not found") |
|
|
| |
| if token.current_project_id: |
| return token.current_project_id |
|
|
| |
| now = datetime.now() |
| project_name = now.strftime("%b %d - %H:%M") |
|
|
| try: |
| project_id = await self.flow_client.create_project(token.st, project_name) |
| debug_logger.log_info(f"[PROJECT] Created project for token {token_id}: {project_name}") |
|
|
| |
| await self.db.update_token( |
| token_id, |
| current_project_id=project_id, |
| current_project_name=project_name |
| ) |
|
|
| |
| project = Project( |
| project_id=project_id, |
| token_id=token_id, |
| project_name=project_name |
| ) |
| await self.db.add_project(project) |
|
|
| return project_id |
|
|
| except Exception as e: |
| raise ValueError(f"Failed to create project: {str(e)}") |
|
|
| |
|
|
| async def record_usage(self, token_id: int, is_video: bool = False): |
| """Record token usage""" |
| await self.db.update_token(token_id, use_count=1, last_used_at=datetime.now()) |
|
|
| if is_video: |
| await self.db.increment_token_stats(token_id, "video") |
| else: |
| await self.db.increment_token_stats(token_id, "image") |
|
|
| async def record_error(self, token_id: int): |
| """Record token error and auto-disable if threshold reached""" |
| await self.db.increment_token_stats(token_id, "error") |
|
|
| |
| stats = await self.db.get_token_stats(token_id) |
| admin_config = await self.db.get_admin_config() |
|
|
| if stats and stats.consecutive_error_count >= admin_config.error_ban_threshold: |
| debug_logger.log_warning( |
| f"[TOKEN_BAN] Token {token_id} consecutive error count ({stats.consecutive_error_count}) " |
| f"reached threshold ({admin_config.error_ban_threshold}), auto-disabling" |
| ) |
| await self.disable_token(token_id) |
|
|
| async def record_success(self, token_id: int): |
| """Record successful request (reset consecutive error count) |
| |
| This method resets error_count to 0, which is used for auto-disable threshold checking. |
| Note: today_error_count and historical statistics are NOT reset. |
| """ |
| await self.db.reset_error_count(token_id) |
|
|
| async def ban_token_for_429(self, token_id: int): |
| """因429错误立即禁用token |
| |
| Args: |
| token_id: Token ID |
| """ |
| debug_logger.log_warning(f"[429_BAN] 禁用Token {token_id} (原因: 429 Rate Limit)") |
| await self.db.update_token( |
| token_id, |
| is_active=False, |
| ban_reason="429_rate_limit", |
| banned_at=datetime.now(timezone.utc) |
| ) |
|
|
| async def auto_unban_429_tokens(self): |
| """自动解禁因429被禁用的token |
| |
| 规则: |
| - 距离禁用时间12小时后自动解禁 |
| - 仅解禁未过期的token |
| - 仅解禁因429被禁用的token |
| """ |
| all_tokens = await self.db.get_all_tokens() |
| now = datetime.now(timezone.utc) |
|
|
| for token in all_tokens: |
| |
| if token.ban_reason != "429_rate_limit": |
| continue |
|
|
| |
| if token.is_active: |
| continue |
|
|
| |
| if not token.banned_at: |
| continue |
|
|
| |
| if token.at_expires: |
| |
| if token.at_expires.tzinfo is None: |
| at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) |
| else: |
| at_expires_aware = token.at_expires |
|
|
| |
| if at_expires_aware <= now: |
| debug_logger.log_info(f"[AUTO_UNBAN] Token {token.id} 已过期,跳过解禁") |
| continue |
|
|
| |
| if token.banned_at.tzinfo is None: |
| banned_at_aware = token.banned_at.replace(tzinfo=timezone.utc) |
| else: |
| banned_at_aware = token.banned_at |
|
|
| |
| time_since_ban = now - banned_at_aware |
| if time_since_ban.total_seconds() >= 12 * 3600: |
| debug_logger.log_info( |
| f"[AUTO_UNBAN] 解禁Token {token.id} (禁用时间: {banned_at_aware}, " |
| f"已过 {time_since_ban.total_seconds() / 3600:.1f} 小时)" |
| ) |
| await self.db.update_token( |
| token.id, |
| is_active=True, |
| ban_reason=None, |
| banned_at=None |
| ) |
| |
| await self.db.reset_error_count(token.id) |
|
|
| |
|
|
| async def refresh_credits(self, token_id: int) -> int: |
| """刷新Token余额 |
| |
| Returns: |
| credits |
| """ |
| token = await self.db.get_token(token_id) |
| if not token: |
| return 0 |
|
|
| |
| if not await self.is_at_valid(token_id): |
| return 0 |
|
|
| |
| token = await self.db.get_token(token_id) |
|
|
| try: |
| result = await self.flow_client.get_credits(token.at) |
| credits = result.get("credits", 0) |
|
|
| |
| await self.db.update_token(token_id, credits=credits) |
|
|
| return credits |
| except Exception as e: |
| debug_logger.log_error(f"Failed to refresh credits for token {token_id}: {str(e)}") |
| return 0 |
|
|