|
|
"""Load balancing module""" |
|
|
import random
from datetime import datetime
from typing import List, Optional

from ..core.models import Token
from ..core.config import config
from .token_manager import TokenManager
from .token_lock import TokenLock
|
|
|
|
|
class LoadBalancer:
    """Token load balancer with random selection and image generation lock.

    Candidates come from ``TokenManager.get_active_tokens()``. They are
    optionally narrowed by the video and/or image filters, and one of the
    survivors is picked uniformly at random with ``random.choice``.
    """

    # Active tokens whose expiry_time is at most this many hours away
    # (including tokens that are already past expiry) are handed to
    # ``TokenManager.auto_refresh_expiring_token`` before selection.
    AUTO_REFRESH_THRESHOLD_HOURS = 24

    def __init__(self, token_manager: TokenManager):
        """Create a balancer over the tokens managed by *token_manager*.

        Args:
            token_manager: Source of tokens. It is also used for auto-refresh,
                Sora2 quota refresh, and direct DB reads via
                ``token_manager.db``.
        """
        self.token_manager = token_manager
        # Image-generation lock. select_token() skips tokens this lock reports
        # as locked. Its timeout follows the configured image timeout.
        self.token_lock = TokenLock(lock_timeout=config.image_timeout)

    async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False) -> Optional[Token]:
        """
        Select a token using random load balancing

        Args:
            for_image_generation: If True, only select tokens that are not locked
                for image generation and have image_enabled=True. The lock is
                only checked here, not acquired.
            for_video_generation: If True, filter out tokens with Sora2 quota
                exhausted (sora2_cooldown_until not expired), tokens that don't
                support Sora2, and tokens with video_enabled=False

        When both flags are set, the video filter runs first and the image
        filter is applied to its result.

        Returns:
            Selected token or None if no available tokens
        """
        if config.at_auto_refresh_enabled:
            await self._auto_refresh_expiring_tokens()

        # Re-read after any refresh so refreshed tokens are considered.
        candidates = await self.token_manager.get_active_tokens()

        # Each filter returns an empty list for empty input without making
        # any calls, so an empty intermediate result falls through to None.
        if for_video_generation:
            candidates = await self._filter_video_tokens(candidates)

        if for_image_generation:
            candidates = await self._filter_image_tokens(candidates)

        if not candidates:
            return None
        return random.choice(candidates)

    async def _auto_refresh_expiring_tokens(self) -> None:
        """Request auto-refresh for every active token close to (or past) expiry."""
        for token in await self.token_manager.get_all_tokens():
            if not (token.is_active and token.expiry_time):
                continue
            # NOTE(review): compares against naive datetime.now(); this assumes
            # expiry_time is a naive local timestamp. Confirm against the Token
            # model, because mixing aware and naive datetimes raises TypeError.
            hours_until_expiry = (token.expiry_time - datetime.now()).total_seconds() / 3600
            if hours_until_expiry <= self.AUTO_REFRESH_THRESHOLD_HOURS:
                await self.token_manager.auto_refresh_expiring_token(token.id)

    async def _filter_video_tokens(self, tokens: List[Token]) -> List[Token]:
        """Keep tokens that are video-enabled, Sora2-capable and not in Sora2 cooldown."""
        available: List[Token] = []
        for token in tokens:
            if not token.video_enabled or not token.sora2_supported:
                continue

            current = token
            if current.sora2_cooldown_until and current.sora2_cooldown_until <= datetime.now():
                # Cooldown has elapsed: let the manager refresh the remaining
                # Sora2 quota, then re-read the token from the DB so the check
                # below sees the post-refresh state (which may set a new cooldown).
                await self.token_manager.refresh_sora2_remaining_if_cooldown_expired(current.id)
                current = await self.token_manager.db.get_token(current.id)
                if not current:
                    # Token disappeared from the DB; it cannot be selected.
                    continue

            if current.sora2_cooldown_until and current.sora2_cooldown_until > datetime.now():
                continue

            available.append(current)
        return available

    async def _filter_image_tokens(self, tokens: List[Token]) -> List[Token]:
        """Keep image-enabled tokens whose image-generation lock is not currently held."""
        # Short-circuit: is_locked() is only awaited for image-enabled tokens.
        return [
            token
            for token in tokens
            if token.image_enabled and not await self.token_lock.is_locked(token.id)
        ]
|
|
|