Spaces:
Running
Running
from datetime import datetime, timezone | |
from typing import Dict, Optional, Tuple | |
from utils.logger import logger | |
from utils.config import config, EnvMode | |
# Monthly agent-run allowance, keyed by Stripe price ID.
# Each entry maps to {'name': <tier name>, 'minutes': <monthly limit in minutes>}.
SUBSCRIPTION_TIERS = {
    price_id: {'name': tier_name, 'minutes': monthly_minutes}
    for price_id, tier_name, monthly_minutes in (
        ('price_1RGJ9GG6l1KZGqIroxSqgphC', 'free', 8),
        ('price_1RGJ9LG6l1KZGqIrd9pwzeNW', 'base', 300),
        ('price_1RGJ9JG6l1KZGqIrVUU4ZRv6', 'extra', 2400),
    )
}
async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
    """Fetch the most recent active subscription row for an account.

    Queries ``basejump.billing_subscriptions`` for rows belonging to
    ``account_id`` with status ``'active'``, ordered by ``created``
    descending, and keeps at most one.

    Returns:
        The newest matching row, or ``None`` when the account has no
        active subscription.
    """
    newest_active = (
        client.schema('basejump')
        .from_('billing_subscriptions')
        .select('*')
        .eq('account_id', account_id)
        .eq('status', 'active')
        .order('created', desc=True)
        .limit(1)
    )
    response = await newest_active.execute()
    rows = response.data
    return rows[0] if rows else None
async def calculate_monthly_usage(client, account_id: str) -> float: | |
"""Calculate total agent run minutes for the current month for an account.""" | |
# Get start of current month in UTC | |
now = datetime.now(timezone.utc) | |
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc) | |
# First get all threads for this account | |
threads_result = await client.table('threads') \ | |
.select('thread_id') \ | |
.eq('account_id', account_id) \ | |
.execute() | |
if not threads_result.data: | |
return 0.0 | |
thread_ids = [t['thread_id'] for t in threads_result.data] | |
# Then get all agent runs for these threads in current month | |
runs_result = await client.table('agent_runs') \ | |
.select('started_at, completed_at') \ | |
.in_('thread_id', thread_ids) \ | |
.gte('started_at', start_of_month.isoformat()) \ | |
.execute() | |
if not runs_result.data: | |
return 0.0 | |
# Calculate total minutes | |
total_seconds = 0 | |
now_ts = now.timestamp() | |
for run in runs_result.data: | |
start_time = datetime.fromisoformat(run['started_at'].replace('Z', '+00:00')).timestamp() | |
if run['completed_at']: | |
end_time = datetime.fromisoformat(run['completed_at'].replace('Z', '+00:00')).timestamp() | |
else: | |
# For running jobs, use current time | |
end_time = now_ts | |
total_seconds += (end_time - start_time) | |
return total_seconds / 60 # Convert to minutes | |
async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
    """Decide whether an account is allowed to start another agent run.

    In local mode billing is skipped and a placeholder plan dict is returned.
    Otherwise the account's active subscription is looked up, falling back to
    the free tier when it has none. Its price ID is mapped to a tier in
    ``SUBSCRIPTION_TIERS``, and this month's usage is compared with that
    tier's minute allowance.

    Returns:
        Tuple[bool, str, Optional[Dict]]: ``(can_run, message, subscription_info)``.
        ``can_run`` is False when the price ID is not a known tier, or when
        the month's usage has reached the tier's limit.
    """
    if config.ENV_MODE == EnvMode.LOCAL:
        logger.info("Running in local development mode - billing checks are disabled")
        local_plan = {
            "price_id": "local_dev",
            "plan_name": "Local Development",
            "minutes_limit": "no limit",
        }
        return True, "Local development mode - billing disabled", local_plan

    # Staging/production: accounts with no active subscription use the free tier.
    subscription = await get_account_subscription(client, account_id) or {
        'price_id': 'price_1RGJ9GG6l1KZGqIroxSqgphC',  # Free tier
        'plan_name': 'free',
    }

    tier = SUBSCRIPTION_TIERS.get(subscription['price_id'])
    if not tier:
        return False, "Invalid subscription tier", subscription

    used_minutes = await calculate_monthly_usage(client, account_id)
    if used_minutes >= tier['minutes']:
        return (
            False,
            f"Monthly limit of {tier['minutes']} minutes reached. Please upgrade your plan or wait until next month.",
            subscription,
        )
    return True, "OK", subscription
# Helper: resolve the account that owns a given thread.
async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
    """Return the ``account_id`` stored on the ``threads`` row for ``thread_id``.

    Returns:
        The owning account ID, or ``None`` when no such thread row exists.
    """
    response = await (
        client.table('threads')
        .select('account_id')
        .eq('thread_id', thread_id)
        .limit(1)
        .execute()
    )
    if not response.data:
        return None
    return response.data[0]['account_id']