# FLOW2API/src/services/token_manager.py
# Page header captured together with the file (not part of the module):
#   xiaoh2018's picture
#   Upload 1108 files
#   33cfa2a verified
"""Token manager for Flow2API with AT auto-refresh"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import Optional, List
from ..core.database import Database
from ..core.models import Token, Project
from ..core.logger import debug_logger
from .flow_client import FlowClient
from .proxy_manager import ProxyManager
class TokenManager:
"""Token lifecycle manager with AT auto-refresh"""
def __init__(self, db: Database, flow_client: FlowClient):
self.db = db
self.flow_client = flow_client
self._lock = asyncio.Lock()
# ========== Token CRUD ==========
async def get_all_tokens(self) -> List[Token]:
"""Get all tokens"""
return await self.db.get_all_tokens()
async def get_active_tokens(self) -> List[Token]:
"""Get all active tokens"""
return await self.db.get_active_tokens()
async def get_token(self, token_id: int) -> Optional[Token]:
"""Get token by ID"""
return await self.db.get_token(token_id)
async def delete_token(self, token_id: int):
"""Delete token"""
await self.db.delete_token(token_id)
async def enable_token(self, token_id: int):
"""Enable a token and reset error count"""
# Enable the token
await self.db.update_token(token_id, is_active=True)
# Reset error count when enabling (only reset total error_count, keep today_error_count)
await self.db.reset_error_count(token_id)
async def disable_token(self, token_id: int):
"""Disable a token"""
await self.db.update_token(token_id, is_active=False)
# ========== Token添加 (支持Project创建) ==========
async def add_token(
self,
st: str,
project_id: Optional[str] = None,
project_name: Optional[str] = None,
remark: Optional[str] = None,
image_enabled: bool = True,
video_enabled: bool = True,
image_concurrency: int = -1,
video_concurrency: int = -1
) -> Token:
"""Add a new token
Args:
st: Session Token (必需)
project_id: 项目ID (可选,如果提供则直接使用,不创建新项目)
project_name: 项目名称 (可选,如果不提供则自动生成)
remark: 备注
image_enabled: 是否启用图片生成
video_enabled: 是否启用视频生成
image_concurrency: 图片并发限制
video_concurrency: 视频并发限制
Returns:
Token object
"""
# Step 1: 检查ST是否已存在
existing_token = await self.db.get_token_by_st(st)
if existing_token:
raise ValueError(f"Token 已存在(邮箱: {existing_token.email})")
# Step 2: 使用ST转换AT
debug_logger.log_info(f"[ADD_TOKEN] Converting ST to AT...")
try:
result = await self.flow_client.st_to_at(st)
at = result["access_token"]
expires = result.get("expires")
user_info = result.get("user", {})
email = user_info.get("email", "")
name = user_info.get("name", email.split("@")[0] if email else "")
# 解析过期时间
at_expires = None
if expires:
try:
at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00'))
except:
pass
except Exception as e:
raise ValueError(f"ST转AT失败: {str(e)}")
# Step 3: 查询余额
try:
credits_result = await self.flow_client.get_credits(at)
credits = credits_result.get("credits", 0)
user_paygate_tier = credits_result.get("userPaygateTier")
except:
credits = 0
user_paygate_tier = None
# Step 4: 处理Project ID和名称
if project_id:
# 用户提供了project_id,直接使用
debug_logger.log_info(f"[ADD_TOKEN] Using provided project_id: {project_id}")
if not project_name:
# 如果没有提供project_name,生成一个
now = datetime.now()
project_name = now.strftime("%b %d - %H:%M")
else:
# 用户没有提供project_id,需要创建新项目
if not project_name:
# 自动生成项目名称
now = datetime.now()
project_name = now.strftime("%b %d - %H:%M")
try:
project_id = await self.flow_client.create_project(st, project_name)
debug_logger.log_info(f"[ADD_TOKEN] Created new project: {project_name} (ID: {project_id})")
except Exception as e:
raise ValueError(f"创建项目失败: {str(e)}")
# Step 5: 创建Token对象
token = Token(
st=st,
at=at,
at_expires=at_expires,
email=email,
name=name,
remark=remark,
is_active=True,
credits=credits,
user_paygate_tier=user_paygate_tier,
current_project_id=project_id,
current_project_name=project_name,
image_enabled=image_enabled,
video_enabled=video_enabled,
image_concurrency=image_concurrency,
video_concurrency=video_concurrency
)
# Step 6: 保存到数据库
token_id = await self.db.add_token(token)
token.id = token_id
# Step 7: 保存Project到数据库
project = Project(
project_id=project_id,
token_id=token_id,
project_name=project_name,
tool_name="PINHOLE"
)
await self.db.add_project(project)
debug_logger.log_info(f"[ADD_TOKEN] Token added successfully (ID: {token_id}, Email: {email})")
return token
async def update_token(
self,
token_id: int,
st: Optional[str] = None,
at: Optional[str] = None,
at_expires: Optional[datetime] = None,
project_id: Optional[str] = None,
project_name: Optional[str] = None,
remark: Optional[str] = None,
image_enabled: Optional[bool] = None,
video_enabled: Optional[bool] = None,
image_concurrency: Optional[int] = None,
video_concurrency: Optional[int] = None
):
"""Update token (支持修改project_id和project_name)
当用户编辑保存token时,如果token未过期,自动清空429禁用状态
"""
update_fields = {}
if st is not None:
update_fields["st"] = st
if at is not None:
update_fields["at"] = at
if at_expires is not None:
update_fields["at_expires"] = at_expires
if project_id is not None:
update_fields["current_project_id"] = project_id
if project_name is not None:
update_fields["current_project_name"] = project_name
if remark is not None:
update_fields["remark"] = remark
if image_enabled is not None:
update_fields["image_enabled"] = image_enabled
if video_enabled is not None:
update_fields["video_enabled"] = video_enabled
if image_concurrency is not None:
update_fields["image_concurrency"] = image_concurrency
if video_concurrency is not None:
update_fields["video_concurrency"] = video_concurrency
# 检查token是否因429被禁用,如果是且未过期,则清空429状态
token = await self.db.get_token(token_id)
if token and token.ban_reason == "429_rate_limit":
# 检查token是否过期
is_expired = False
if token.at_expires:
now = datetime.now(timezone.utc)
if token.at_expires.tzinfo is None:
at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc)
else:
at_expires_aware = token.at_expires
is_expired = at_expires_aware <= now
# 如果未过期,清空429禁用状态
if not is_expired:
debug_logger.log_info(f"[UPDATE_TOKEN] Token {token_id} 编辑保存,清空429禁用状态")
update_fields["ban_reason"] = None
update_fields["banned_at"] = None
if update_fields:
await self.db.update_token(token_id, **update_fields)
# ========== AT自动刷新逻辑 (核心) ==========
async def is_at_valid(self, token_id: int) -> bool:
"""检查AT是否有效,如果无效或即将过期则自动刷新
Returns:
True if AT is valid or refreshed successfully
False if AT cannot be refreshed
"""
token = await self.db.get_token(token_id)
if not token:
return False
# 如果AT不存在,需要刷新
if not token.at:
debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT不存在,需要刷新")
return await self._refresh_at(token_id)
# 如果没有过期时间,假设需要刷新
if not token.at_expires:
debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT过期时间未知,尝试刷新")
return await self._refresh_at(token_id)
# 检查是否即将过期 (提前1小时刷新)
now = datetime.now(timezone.utc)
# 确保at_expires也是timezone-aware
if token.at_expires.tzinfo is None:
at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc)
else:
at_expires_aware = token.at_expires
time_until_expiry = at_expires_aware - now
if time_until_expiry.total_seconds() < 3600: # 1 hour (3600 seconds)
debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT即将过期 (剩余 {time_until_expiry.total_seconds():.0f} 秒),需要刷新")
return await self._refresh_at(token_id)
# AT有效
return True
async def _refresh_at(self, token_id: int) -> bool:
"""内部方法: 刷新AT
Returns:
True if refresh successful, False otherwise
"""
async with self._lock:
token = await self.db.get_token(token_id)
if not token:
return False
try:
debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: 开始刷新AT...")
# 使用ST转AT
result = await self.flow_client.st_to_at(token.st)
new_at = result["access_token"]
expires = result.get("expires")
# 解析过期时间
new_at_expires = None
if expires:
try:
new_at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00'))
except:
pass
# 更新数据库
await self.db.update_token(
token_id,
at=new_at,
at_expires=new_at_expires
)
debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: AT刷新成功")
debug_logger.log_info(f" - 新过期时间: {new_at_expires}")
# 同时刷新credits
try:
credits_result = await self.flow_client.get_credits(new_at)
await self.db.update_token(
token_id,
credits=credits_result.get("credits", 0)
)
except:
pass
return True
except Exception as e:
debug_logger.log_error(f"[AT_REFRESH] Token {token_id}: AT刷新失败 - {str(e)}")
# 刷新失败,禁用Token
await self.disable_token(token_id)
return False
async def ensure_project_exists(self, token_id: int) -> str:
"""确保Token有可用的Project
Returns:
project_id
"""
token = await self.db.get_token(token_id)
if not token:
raise ValueError("Token not found")
# 如果已有project_id,直接返回
if token.current_project_id:
return token.current_project_id
# 创建新Project
now = datetime.now()
project_name = now.strftime("%b %d - %H:%M")
try:
project_id = await self.flow_client.create_project(token.st, project_name)
debug_logger.log_info(f"[PROJECT] Created project for token {token_id}: {project_name}")
# 更新Token
await self.db.update_token(
token_id,
current_project_id=project_id,
current_project_name=project_name
)
# 保存Project到数据库
project = Project(
project_id=project_id,
token_id=token_id,
project_name=project_name
)
await self.db.add_project(project)
return project_id
except Exception as e:
raise ValueError(f"Failed to create project: {str(e)}")
# ========== Token使用统计 ==========
async def record_usage(self, token_id: int, is_video: bool = False):
"""Record token usage"""
await self.db.update_token(token_id, use_count=1, last_used_at=datetime.now())
if is_video:
await self.db.increment_token_stats(token_id, "video")
else:
await self.db.increment_token_stats(token_id, "image")
async def record_error(self, token_id: int):
"""Record token error and auto-disable if threshold reached"""
await self.db.increment_token_stats(token_id, "error")
# Check if should auto-disable token (based on consecutive errors)
stats = await self.db.get_token_stats(token_id)
admin_config = await self.db.get_admin_config()
if stats and stats.consecutive_error_count >= admin_config.error_ban_threshold:
debug_logger.log_warning(
f"[TOKEN_BAN] Token {token_id} consecutive error count ({stats.consecutive_error_count}) "
f"reached threshold ({admin_config.error_ban_threshold}), auto-disabling"
)
await self.disable_token(token_id)
async def record_success(self, token_id: int):
"""Record successful request (reset consecutive error count)
This method resets error_count to 0, which is used for auto-disable threshold checking.
Note: today_error_count and historical statistics are NOT reset.
"""
await self.db.reset_error_count(token_id)
async def ban_token_for_429(self, token_id: int):
"""因429错误立即禁用token
Args:
token_id: Token ID
"""
debug_logger.log_warning(f"[429_BAN] 禁用Token {token_id} (原因: 429 Rate Limit)")
await self.db.update_token(
token_id,
is_active=False,
ban_reason="429_rate_limit",
banned_at=datetime.now(timezone.utc)
)
async def auto_unban_429_tokens(self):
"""自动解禁因429被禁用的token
规则:
- 距离禁用时间12小时后自动解禁
- 仅解禁未过期的token
- 仅解禁因429被禁用的token
"""
all_tokens = await self.db.get_all_tokens()
now = datetime.now(timezone.utc)
for token in all_tokens:
# 跳过非429禁用的token
if token.ban_reason != "429_rate_limit":
continue
# 跳过未禁用的token
if token.is_active:
continue
# 跳过没有禁用时间的token
if not token.banned_at:
continue
# 检查token是否已过期
if token.at_expires:
# 确保时区一致
if token.at_expires.tzinfo is None:
at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc)
else:
at_expires_aware = token.at_expires
# 如果已过期,跳过
if at_expires_aware <= now:
debug_logger.log_info(f"[AUTO_UNBAN] Token {token.id} 已过期,跳过解禁")
continue
# 确保banned_at时区一致
if token.banned_at.tzinfo is None:
banned_at_aware = token.banned_at.replace(tzinfo=timezone.utc)
else:
banned_at_aware = token.banned_at
# 检查是否已过12小时
time_since_ban = now - banned_at_aware
if time_since_ban.total_seconds() >= 12 * 3600: # 12小时
debug_logger.log_info(
f"[AUTO_UNBAN] 解禁Token {token.id} (禁用时间: {banned_at_aware}, "
f"已过 {time_since_ban.total_seconds() / 3600:.1f} 小时)"
)
await self.db.update_token(
token.id,
is_active=True,
ban_reason=None,
banned_at=None
)
# 重置错误计数
await self.db.reset_error_count(token.id)
# ========== 余额刷新 ==========
async def refresh_credits(self, token_id: int) -> int:
"""刷新Token余额
Returns:
credits
"""
token = await self.db.get_token(token_id)
if not token:
return 0
# 确保AT有效
if not await self.is_at_valid(token_id):
return 0
# 重新获取token (AT可能已刷新)
token = await self.db.get_token(token_id)
try:
result = await self.flow_client.get_credits(token.at)
credits = result.get("credits", 0)
# 更新数据库
await self.db.update_token(token_id, credits=credits)
return credits
except Exception as e:
debug_logger.log_error(f"Failed to refresh credits for token {token_id}: {str(e)}")
return 0