Spaces:
Running
Running
import asyncio | |
import time | |
import uuid | |
import logging | |
from fastapi import FastAPI, HTTPException, Query | |
from pydantic import BaseModel | |
from typing import Optional, Dict, List | |
from PyCharacterAI import get_client | |
from PyCharacterAI.exceptions import SessionClosedError, RequestError | |
import uvicorn | |
import os | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Character AI token, read from the "token" environment variable (None if unset).
DEFAULT_TOKEN = os.getenv("token")
# Character and voice used when a request does not supply its own.
DEFAULT_CHARACTER_ID = "smtV3Vyez6ODkwS8BErmBAdgGNj-1XWU73wIFVOY1hQ"
DEFAULT_VOICE_ID = "974fea59-7c26-411b-ae0d-64ff2c4e9666"

# Retry policy for retry_with_backoff: up to MAX_RETRIES retries after the
# first attempt, waiting RETRY_DELAY seconds before the first retry and
# multiplying the wait by BACKOFF_MULTIPLIER after each further failure.
MAX_RETRIES, RETRY_DELAY, BACKOFF_MULTIPLIER = 5, 1.0, 2.0

# ---------------------------------------------------------------------------
# Shared state
# ---------------------------------------------------------------------------
# PyCharacterAI client shared by every request; None until startup_event or
# reinitialize_client assigns it.
client = None
# One asyncio.Lock per chat_id, handed out by get_chat_lock.
chat_locks: Dict[str, asyncio.Lock] = {}
# Held while a brand-new chat is being created.
new_chat_lock = asyncio.Lock()
async def startup_event():
    """Create the shared Character AI client with DEFAULT_TOKEN.

    The new client is stored in the module-level ``client`` variable.
    """
    global client
    fresh_client = await get_client(token=DEFAULT_TOKEN)
    client = fresh_client
async def shutdown_event():
    """Close the shared client's session, if a client was ever created."""
    if not client:
        return
    await client.close_session()
async def reinitialize_client():
    """Recreate the shared Character AI client, e.g. after its session closed.

    Closing the current client (if any) is best effort. An ``Exception``
    raised while closing is logged and ignored, so that a new client is still
    created. The new client is built with ``DEFAULT_TOKEN`` and stored in the
    module-level ``client``.

    Raises:
        Whatever ``get_client`` raises if the new session cannot be created.
        In that case ``client`` still refers to the previous client.
    """
    global client
    if client:
        try:
            await client.close_session()
        # Catch Exception rather than everything: asyncio.CancelledError,
        # KeyboardInterrupt and SystemExit must keep propagating.
        except Exception as exc:
            logger.warning(f"Ignoring error while closing old client session: {exc}")
    client = await get_client(token=DEFAULT_TOKEN)
def get_chat_lock(chat_id):
    """Return the asyncio.Lock cached for *chat_id*.

    A new lock is created and cached on first use.
    """
    try:
        return chat_locks[chat_id]
    except KeyError:
        created = chat_locks[chat_id] = asyncio.Lock()
        return created
def is_retryable_error(exception):
    """
    Classify an exception as transient (worth retrying) or permanent.

    Args:
        exception: The exception raised by a failed Character AI call.
    Returns:
        bool: True if the exception is an instance of one of the transient
        exception types below, or if its lower-cased message contains one of
        the transient markers below; False otherwise.
    """
    message = str(exception).lower()

    transient_types = (
        RequestError,
        SessionClosedError,
        ConnectionError,
        TimeoutError,
        asyncio.TimeoutError,
    )
    if isinstance(exception, transient_types):
        return True

    # Plain substring checks against the lower-cased message. This list also
    # contains "maybe your token is invalid", so those errors are retried too.
    transient_markers = (
        "maybe your token is invalid",
        "session closed",
        "connection",
        "timeout",
        "request failed",
        "network",
        "server error",
        "internal server error",
        "bad gateway",
        "service unavailable",
        "gateway timeout",
        "rate limit",
        "too many requests",
        "temporary",
        "temporarily unavailable",
        "request timeout",
        "read timeout",
    )
    return any(marker in message for marker in transient_markers)
async def retry_with_backoff(func, *args, max_retries=MAX_RETRIES, base_delay=RETRY_DELAY, **kwargs):
    """
    Await an async function, retrying transient failures with exponential backoff.

    ``func`` is attempted up to ``max_retries + 1`` times. Suppose the attempt
    with 0-based index ``k`` fails with an error that ``is_retryable_error``
    accepts. The call then waits ``base_delay * BACKOFF_MULTIPLIER ** k``
    seconds before trying again. If the error is a ``SessionClosedError``, or
    its message mentions "token", the shared client is recreated with
    ``reinitialize_client`` before that wait.

    Args:
        func: The coroutine function to call.
        *args: Positional arguments passed to ``func``.
        max_retries: Number of retries after the first attempt; must be >= 0.
        base_delay: Delay in seconds before the first retry.
        **kwargs: Keyword arguments passed to ``func``.
    Returns:
        The result of the first successful call to ``func``.
    Raises:
        ValueError: If ``max_retries`` is negative.
        Exception: The error from the failing attempt, re-raised unchanged
            when it is not retryable or when every attempt has failed.
    """
    if max_retries < 0:
        # Without this guard range() would be empty: func would never be
        # called and the caller would silently receive None.
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    for attempt in range(max_retries + 1):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            if not is_retryable_error(e):
                logger.error(f"Non-retryable error occurred: {str(e)}")
                raise
            if attempt == max_retries:
                logger.error(f"All {max_retries + 1} attempts failed. Last error: {str(e)}")
                raise
            delay = base_delay * (BACKOFF_MULTIPLIER ** attempt)
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {delay:.2f} seconds...")
            if isinstance(e, SessionClosedError) or "token" in str(e).lower():
                logger.info("Session/token issue detected, reinitializing client...")
                await reinitialize_client()
            await asyncio.sleep(delay)
    # Unreachable: with max_retries >= 0, the final iteration either returns
    # or re-raises.
async def search_characters(
    query: str = Query(..., description="Character name or keyword to search for"),
    token: Optional[str] = Query(None, description="API token for authentication."),
    limit: int = Query(10, description="Maximum number of results to return"),
):
    """
    Search for characters by name or keyword.
    Returns a list of characters matching the search criteria.

    At most ``limit`` results are returned (none when ``limit`` is zero or
    negative), together with the total number of matches found.
    """
    # NOTE(review): ``token`` is accepted for API compatibility but is not used.
    # Every request goes through the shared module-level ``client``, which is
    # created with DEFAULT_TOKEN.
    try:
        if client is None:
            await reinitialize_client()

        async def search_operation():
            # ``client`` is a global looked up at call time, so a retry that
            # recreated the client uses the new instance.
            return await client.character.search_characters(query)

        characters = await retry_with_backoff(search_operation)
        # Clamp at 0: a negative slice bound would drop items from the END
        # (e.g. [:-3]) instead of returning nothing.
        results = [
            {
                "id": char.character_id,
                "name": char.name,
                "greeting": char.greeting,
                "description": char.description,
                "avatar_url": char.avatar,
            }
            for char in characters[:max(limit, 0)]
        ]
        return {"query": query, "results": results, "total_found": len(characters), "returned": len(results)}
    except SessionClosedError:
        await reinitialize_client()
        raise HTTPException(
            status_code=500,
            detail="Session was closed after retries. Please try your request again."
        )
    except RequestError as e:
        raise HTTPException(status_code=500, detail=f"Character AI API error after retries: {str(e)}")
    except Exception as e:
        # Unexpected failure: keep the traceback in the server log.
        logger.exception("Unexpected error while searching characters")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
async def send_message(
    message: str,
    token: Optional[str] = Query(None, description="API token for authentication."),
    character_id: str = Query(DEFAULT_CHARACTER_ID, description="Character ID for the chat session."),
    chat_id: Optional[str] = Query(None, description="ID of the existing chat session, if available."),
    voice_id: str = Query(DEFAULT_VOICE_ID, description="Voice ID for generating speech."),
    voice: bool = Query(True, description="Set to true to generate voice, false to skip voice generation.")
):
    """
    Send a message to the character. If no chat_id is provided, initialize a new chat session.
    Optionally generate voice for the response.

    The response contains chat_id, author, response and voice_url. When a new
    chat was created it also contains greeting_message. Voice generation is
    best effort: voice_url is null when voice is disabled or generation fails.
    """
    # NOTE(review): ``token`` is accepted for API compatibility but is not used.
    # All requests share the module-level ``client`` created with DEFAULT_TOKEN.
    try:
        if client is None:
            await reinitialize_client()

        greeting_text = None
        if not chat_id:
            chat_id, greeting_text = await _create_chat_with_greeting(character_id)

        # Only one message at a time is processed per chat.
        async with get_chat_lock(chat_id):
            async def send_message_operation():
                return await client.chat.send_message(character_id, chat_id, message)

            answer = await retry_with_backoff(send_message_operation)
            response_text = answer.get_primary_candidate().text
            speech_url = await _generate_speech_url(chat_id, answer, voice_id) if voice else None

            response_data = {
                "chat_id": chat_id,
                "author": answer.author_name,
                "response": response_text,
                "voice_url": speech_url,
            }
            if greeting_text:
                response_data["greeting_message"] = greeting_text
            return response_data
    except HTTPException:
        # Already a client-facing error (e.g. chat creation failed). Letting the
        # generic handler below re-wrap it would bury its detail under
        # "Error processing request: ...".
        raise
    except SessionClosedError:
        await reinitialize_client()
        raise HTTPException(
            status_code=500,
            detail="Session was closed after retries. Please try your request again."
        )
    except RequestError as e:
        raise HTTPException(status_code=500, detail=f"Character AI API error after retries: {str(e)}")
    except Exception as e:
        # Unexpected failure: keep the traceback in the server log.
        logger.exception("Unexpected error while sending message")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")


async def _create_chat_with_greeting(character_id):
    """Create a new chat with *character_id*; return ``(chat_id, greeting)``.

    ``greeting`` is a dict with the greeting's "author" and "text". A fresh
    per-chat lock is registered in ``chat_locks`` for the new chat.

    Raises:
        HTTPException: 500 if the chat could not be created after retries.
    """
    # New chats are created one at a time across all requests.
    async with new_chat_lock:
        try:
            async def create_chat_operation():
                return await client.chat.create_chat(character_id)

            chat, greeting_message = await retry_with_backoff(create_chat_operation)
            new_chat_id = chat.chat_id
            chat_locks[new_chat_id] = asyncio.Lock()
            greeting = {
                "author": greeting_message.author_name,
                "text": greeting_message.get_primary_candidate().text,
            }
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to create new chat after retries: {str(e)}"
            ) from e
    return new_chat_id, greeting


async def _generate_speech_url(chat_id, answer, voice_id):
    """Return a speech URL for *answer*'s primary candidate, or None on failure.

    Speech is best effort: errors (after retries) are logged, not raised.
    """
    try:
        async def generate_speech_operation():
            return await client.utils.generate_speech(
                chat_id,
                answer.turn_id,
                answer.get_primary_candidate().candidate_id,
                voice_id,
                return_url=True
            )

        return await retry_with_backoff(
            generate_speech_operation,
            max_retries=5,
            base_delay=0.5
        )
    except Exception as e:
        logger.warning(f"Voice generation failed after retries: {e}")
        return None
async def start_cleanup_task(interval: float = 3600):
    """Start a background task that periodically prunes idle per-chat locks.

    Every *interval* seconds the task removes entries from ``chat_locks``
    whose lock is not currently held, so the dict no longer grows with every
    chat ever seen. An evicted chat simply gets a fresh lock from
    ``get_chat_lock`` on its next request.

    Args:
        interval: Seconds between sweeps. Defaults to one hour.
    Returns:
        The created ``asyncio.Task``. A reference is also stored in the
        module-level ``_cleanup_task`` so the task stays referenced.
    """
    global _cleanup_task

    async def cleanup_locks():
        while True:
            await asyncio.sleep(interval)
            # The sweep contains no await, so no other coroutine can touch
            # chat_locks while it runs.
            # NOTE(review): asyncio.Lock.locked() is False in the instant after a
            # release while a woken waiter has not resumed yet. Evicting such a
            # lock lets the next request for that chat get a new lock. Confirm
            # this narrow window is acceptable.
            idle_ids = [lock_id for lock_id, lock in chat_locks.items() if not lock.locked()]
            for lock_id in idle_ids:
                del chat_locks[lock_id]
            if idle_ids:
                logger.info(f"Removed {len(idle_ids)} idle chat lock(s)")

    _cleanup_task = asyncio.create_task(cleanup_locks())
    return _cleanup_task
def root():
    """Return the API's welcome payload."""
    greeting = "Welcome to the Character AI API with FastAPI!"
    return dict(message=greeting)
def health_check():
    """Report liveness and whether the shared client has been created."""
    initialized = client is not None
    return dict(status="healthy", client_initialized=initialized)
if __name__ == "__main__":
    # NOTE(review): uvicorn's timeout_keep_alive is measured in seconds, so
    # 60000 keeps idle keep-alive connections open for roughly 16.7 hours.
    # Confirm this is intended (uvicorn's default is 5).
    # NOTE(review): each of the 8 worker processes has its own ``client`` and
    # ``chat_locks``. Per-chat locking therefore does not serialize requests
    # for the same chat that land on different workers.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "workers": 8,
        "timeout_keep_alive": 60000,
    }
    uvicorn.run("main:app", **server_options)