Spaces:
Running
Running
import os | |
import hashlib | |
import secrets | |
from typing import Dict, Optional | |
from datetime import datetime, timedelta | |
import logging | |
# Logger named after this module (``__name__``), shared by everything below.
logger = logging.getLogger(name=__name__)
class APIKeyManager:
    """Manages API key authentication and per-minute rate limiting.

    All state lives in memory on the instance: keys created with
    :meth:`generate_new_api_key` and all usage counters are lost when the
    process exits.

    Pre-provisioned keys are read from the ``API_KEY_1`` / ``API_KEY_2``
    environment variables and recorded as belonging to ``user1`` / ``user2``.
    Unset or empty variables register no key. The per-minute request limit
    comes from ``RATE_LIMIT`` (default 10).
    """

    # strftime/strptime format of the per-minute bucket keys in ``rate_limits``.
    _MINUTE_FORMAT = "%Y-%m-%d-%H-%M"
    # (environment variable, user name) pairs for pre-provisioned keys.
    _ENV_KEYS = (("API_KEY_1", "user1"), ("API_KEY_2", "user2"))
    # Used when RATE_LIMIT is unset or is not an integer.
    _DEFAULT_RATE_LIMIT = 10
    # Buckets older than this are pruned on every rate-limit check.
    _BUCKET_RETENTION = timedelta(minutes=5)

    def __init__(self):
        # {api_key: {"user", "created", "last_used", "request_count"}}
        self.api_keys = {}
        for env_var, user in self._ENV_KEYS:
            api_key = os.getenv(env_var)
            if api_key:
                self.api_keys[api_key] = self._new_key_record(user)
            else:
                # Deliberately no placeholder fallback: a default key written
                # in source code would be a publicly known credential, and an
                # empty string must never be accepted as a key.
                logging.getLogger(__name__).warning(
                    "%s is not set; no API key registered for %s", env_var, user
                )
        self.rate_limits = {}  # {api_key: {minute: count}}
        self.max_requests_per_minute = self._read_rate_limit()

    @classmethod
    def _read_rate_limit(cls) -> int:
        """Return ``RATE_LIMIT`` as an int, or the default if unset/not an integer."""
        raw = os.getenv("RATE_LIMIT")
        if raw is None:
            return cls._DEFAULT_RATE_LIMIT
        try:
            return int(raw)
        except ValueError:
            # A module-level instance is constructed at import time, so raising
            # here would make the whole module unimportable.
            logging.getLogger(__name__).warning(
                "Invalid RATE_LIMIT %r; using default of %d",
                raw,
                cls._DEFAULT_RATE_LIMIT,
            )
            return cls._DEFAULT_RATE_LIMIT

    @staticmethod
    def _new_key_record(user: str) -> Dict:
        """Return a fresh bookkeeping record for a key owned by ``user``."""
        return {
            "user": user,
            "created": datetime.now(),
            "last_used": None,
            "request_count": 0,
        }

    def validate_api_key(self, api_key: str) -> Optional[str]:
        """Validate ``api_key`` and return the user it belongs to.

        On success the key's ``last_used`` timestamp is set to now and its
        ``request_count`` is incremented.

        Returns:
            The owning user name, or ``None`` if the key is unknown.
        """
        record = self.api_keys.get(api_key)
        if record is None:
            return None
        record["last_used"] = datetime.now()
        record["request_count"] += 1
        return record["user"]

    def check_rate_limit(self, api_key: str) -> bool:
        """Record one request for ``api_key`` if it is under its per-minute limit.

        Requests are counted in fixed calendar-minute buckets. The key is not
        checked against ``api_keys``; any string gets its own counter.

        Returns:
            ``True`` if the request is allowed (and was counted), ``False``
            if the key has already used ``max_requests_per_minute`` requests
            in the current minute.
        """
        now = datetime.now()
        current_minute = now.strftime(self._MINUTE_FORMAT)
        buckets = self.rate_limits.setdefault(api_key, {})
        self._prune_buckets(buckets, now - self._BUCKET_RETENTION)

        current_count = buckets.get(current_minute, 0)
        if current_count >= self.max_requests_per_minute:
            return False
        buckets[current_minute] = current_count + 1
        return True

    @classmethod
    def _prune_buckets(cls, buckets: Dict[str, int], cutoff: datetime) -> None:
        """Delete buckets older than ``cutoff`` or whose key cannot be parsed."""
        stale = []
        for minute_key in buckets:
            try:
                if datetime.strptime(minute_key, cls._MINUTE_FORMAT) < cutoff:
                    stale.append(minute_key)
            except ValueError:
                stale.append(minute_key)
        # Deleting is done after iteration; mutating a dict while iterating it
        # raises RuntimeError.
        for minute_key in stale:
            del buckets[minute_key]

    def get_api_key_stats(self, api_key: str) -> Optional[Dict]:
        """Return a copy of the key's record plus current rate-limit figures.

        The returned dict adds ``current_minute_requests`` and ``rate_limit``
        to the stored fields. Returns ``None`` for an unknown key.
        """
        record = self.api_keys.get(api_key)
        if record is None:
            return None
        stats = record.copy()
        current_minute = datetime.now().strftime(self._MINUTE_FORMAT)
        stats["current_minute_requests"] = self.rate_limits.get(api_key, {}).get(
            current_minute, 0
        )
        stats["rate_limit"] = self.max_requests_per_minute
        return stats

    def generate_new_api_key(self, user: str) -> str:
        """Create, register, and return a new random API key for ``user``."""
        api_key = secrets.token_urlsafe(32)
        self.api_keys[api_key] = self._new_key_record(user)
        return api_key

    def revoke_api_key(self, api_key: str) -> bool:
        """Remove ``api_key`` and its rate-limit counters.

        Returns:
            ``True`` if the key existed, ``False`` otherwise.
        """
        if api_key not in self.api_keys:
            return False
        del self.api_keys[api_key]
        self.rate_limits.pop(api_key, None)
        return True

    def list_api_keys(self) -> Dict:
        """List all API keys with their stats, without revealing the keys.

        Each entry is keyed by the first 8 hex digits of the key's SHA-256
        digest. NOTE: two keys sharing that prefix would collide, and only
        the later one would appear in the result.
        """
        result = {}
        for api_key, info in self.api_keys.items():
            key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:8]
            last_used = info["last_used"]
            result[key_hash] = {
                "user": info["user"],
                "created": info["created"].isoformat(),
                "last_used": last_used.isoformat() if last_used else None,
                "request_count": info["request_count"],
            }
        return result
# Process-wide singleton. It is built once, when this module is first
# imported, so every importer shares the same key registry and counters.
api_key_manager: APIKeyManager = APIKeyManager()
def get_api_key_manager() -> APIKeyManager:
    """Return the module-level ``api_key_manager`` instance.

    Every call hands back the same object rather than constructing a new one.
    """
    shared_manager = api_key_manager
    return shared_manager