# cai / main.py
# Committed by tachibanaa710 — "Update main.py" (commit 79da1f3, verified)
import asyncio
import time
import uuid
import logging
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from typing import Optional, Dict, List
from PyCharacterAI import get_client
from PyCharacterAI.exceptions import SessionClosedError, RequestError
import uvicorn
import os
# FastAPI application; every route and lifecycle hook below registers on it.
app = FastAPI()

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Character.AI token, read from the lower-case "token" environment variable
# (None when the variable is not set).
DEFAULT_TOKEN = os.getenv("token")
# Character and voice used when a request does not name its own.
DEFAULT_CHARACTER_ID = "smtV3Vyez6ODkwS8BErmBAdgGNj-1XWU73wIFVOY1hQ"
DEFAULT_VOICE_ID = "974fea59-7c26-411b-ae0d-64ff2c4e9666"
# Retry policy: up to MAX_RETRIES retries after the first attempt, waiting
# RETRY_DELAY * BACKOFF_MULTIPLIER ** attempt seconds between them.
MAX_RETRIES = 5
RETRY_DELAY = 1.0
BACKOFF_MULTIPLIER = 2.0

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Process-wide shared state
# ---------------------------------------------------------------------------
# Shared PyCharacterAI client; created on startup and replaced on reinit.
client = None
# One lock per chat_id so messages to the same chat are sent one at a time.
chat_locks: Dict[str, asyncio.Lock] = {}
# Serializes creation of new chats.
new_chat_lock = asyncio.Lock()
@app.on_event("startup")
async def startup_event():
    """Build the shared Character.AI client when the server starts up."""
    global client
    fresh_client = await get_client(token=DEFAULT_TOKEN)
    client = fresh_client
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared client's session on server shutdown, if a client exists."""
    if not client:
        return
    await client.close_session()
async def reinitialize_client():
    """
    Replace the shared ``client`` with a freshly created one.

    Closing the current client's session is best-effort: a failure there is
    logged and ignored, since the aim is to obtain a working client. Errors
    from ``get_client`` propagate to the caller; in that case ``client`` still
    refers to the previous instance (whose session may already be closed).
    """
    global client
    if client:
        try:
            await client.close_session()
        except Exception as e:
            # Narrowed from a bare ``except:`` so that task cancellation
            # (asyncio.CancelledError) and KeyboardInterrupt still propagate.
            logger.warning(f"Failed to close previous client session: {e}")
    client = await get_client(token=DEFAULT_TOKEN)
def get_chat_lock(chat_id):
    """
    Return the asyncio.Lock registered for ``chat_id`` in ``chat_locks``.

    A new lock is created and stored the first time a chat_id is seen, so
    later calls with the same id share it.
    """
    try:
        return chat_locks[chat_id]
    except KeyError:
        lock = chat_locks[chat_id] = asyncio.Lock()
        return lock
def is_retryable_error(exception):
    """
    Decide whether a failed call looks transient and is worth retrying.

    Args:
        exception: The exception raised by the failed call.
    Returns:
        bool: True when ``exception`` is a Character.AI request/session error,
        a connection error or a timeout, or when its lower-cased message
        contains one of the transient phrases listed below; False otherwise.
    """
    message = str(exception).lower()

    transient_types = (
        RequestError,
        SessionClosedError,
        ConnectionError,
        TimeoutError,
        asyncio.TimeoutError,
    )
    if isinstance(exception, transient_types):
        return True

    # NOTE(review): "maybe your token is invalid" is treated as transient,
    # presumably because the library reports dropped sessions with it —
    # retry_with_backoff also rebuilds the client on "token" errors. Confirm.
    transient_phrases = (
        "maybe your token is invalid",
        "session closed",
        "connection",
        "timeout",
        "request failed",
        "network",
        "server error",
        "internal server error",
        "bad gateway",
        "service unavailable",
        "gateway timeout",
        "rate limit",
        "too many requests",
        "temporary",
        "temporarily unavailable",
        "request timeout",
        "read timeout",
    )
    return any(phrase in message for phrase in transient_phrases)
async def retry_with_backoff(func, *args, max_retries=MAX_RETRIES, base_delay=RETRY_DELAY, **kwargs):
    """
    Await ``func(*args, **kwargs)``, retrying transient failures with exponential backoff.

    A failure is retried only when ``is_retryable_error`` accepts it. Before
    retry number ``attempt + 1`` the coroutine sleeps
    ``base_delay * BACKOFF_MULTIPLIER ** attempt`` seconds. If the failure is a
    ``SessionClosedError`` or its message mentions "token", the shared client
    is rebuilt with ``reinitialize_client`` before sleeping. A failure of that
    rebuild is logged and does not abort the remaining attempts.

    Args:
        func: The async function to call.
        *args: Positional arguments forwarded to ``func``.
        max_retries: Retries after the first attempt, so at most
            ``max_retries + 1`` calls. Negative values are treated as 0.
        base_delay: Delay in seconds before the first retry.
        **kwargs: Keyword arguments forwarded to ``func``.
    Returns:
        The result of the first successful call to ``func``.
    Raises:
        The first non-retryable exception, or the last exception once all
        attempts have failed.
    """
    # Clamp so a negative value still makes one call instead of silently
    # skipping the loop and returning None without ever calling ``func``.
    attempts = max(0, max_retries) + 1
    for attempt in range(attempts):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            if not is_retryable_error(e):
                logger.error(f"Non-retryable error occurred: {str(e)}")
                raise
            if attempt == attempts - 1:
                logger.error(f"All {attempts} attempts failed. Last error: {str(e)}")
                raise
            delay = base_delay * (BACKOFF_MULTIPLIER ** attempt)
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {delay:.2f} seconds...")
            if isinstance(e, SessionClosedError) or "token" in str(e).lower():
                logger.info("Session/token issue detected, reinitializing client...")
                try:
                    await reinitialize_client()
                except Exception as reinit_error:
                    # Keep retrying: the next attempt fails again and triggers
                    # another rebuild, instead of aborting the retry loop here.
                    logger.warning(f"Client reinitialization failed: {reinit_error}")
            await asyncio.sleep(delay)
@app.get("/search_characters")
async def search_characters(
    query: str = Query(..., description="Character name or keyword to search for"),
    token: Optional[str] = Query(None, description="API token for authentication."),
    limit: int = Query(10, description="Maximum number of results to return")
):
    """
    Search for characters by name or keyword.
    Returns a list of characters matching the search criteria.
    """
    global client
    # NOTE(review): ``token`` is accepted for API compatibility only; the shared
    # ``client`` is always created from DEFAULT_TOKEN, so the value is unused.

    # A negative limit would otherwise slice from the end (``[:-1]`` returns
    # every match except the last) rather than returning no results.
    limit = max(0, limit)
    try:
        if client is None:
            await reinitialize_client()

        async def search_operation():
            return await client.character.search_characters(query)

        characters = await retry_with_backoff(search_operation)
        results = [
            {
                "id": char.character_id,
                "name": char.name,
                "greeting": char.greeting,
                "description": char.description,
                "avatar_url": char.avatar,
            }
            for char in characters[:limit]
        ]
        return {"query": query, "results": results, "total_found": len(characters), "returned": len(results)}
    except SessionClosedError as e:
        await reinitialize_client()
        raise HTTPException(
            status_code=500,
            detail="Session was closed after retries. Please try your request again."
        ) from e
    except RequestError as e:
        raise HTTPException(status_code=500, detail=f"Character AI API error after retries: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
@app.get("/send_message")
async def send_message(
    message: str,
    token: Optional[str] = Query(None, description="API token for authentication."),
    character_id: str = Query(DEFAULT_CHARACTER_ID, description="Character ID for the chat session."),
    chat_id: Optional[str] = Query(None, description="ID of the existing chat session, if available."),
    voice_id: str = Query(DEFAULT_VOICE_ID, description="Voice ID for generating speech."),
    voice: bool = Query(True, description="Set to true to generate voice, false to skip voice generation.")
):
    """
    Send a message to the character. If no chat_id is provided, initialize a new chat session.
    Optionally generate voice for the response.
    """
    global client
    # NOTE(review): ``token`` is accepted for API compatibility only; the shared
    # ``client`` is always created from DEFAULT_TOKEN, so the value is unused.
    try:
        if client is None:
            await reinitialize_client()

        greeting_text = None
        if not chat_id:
            # Chat creation is serialized across all requests by one lock.
            async with new_chat_lock:
                try:
                    async def create_chat_operation():
                        return await client.chat.create_chat(character_id)

                    chat, greeting_message = await retry_with_backoff(create_chat_operation)
                    chat_id = chat.chat_id
                    chat_locks[chat_id] = asyncio.Lock()
                    greeting_text = {
                        "author": greeting_message.author_name,
                        "text": greeting_message.get_primary_candidate().text,
                    }
                except Exception as e:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to create new chat after retries: {str(e)}"
                    ) from e

        # Messages to the same chat are sent one at a time.
        chat_lock = get_chat_lock(chat_id)
        async with chat_lock:
            async def send_message_operation():
                return await client.chat.send_message(character_id, chat_id, message)

            answer = await retry_with_backoff(send_message_operation)
            response_text = answer.get_primary_candidate().text
            speech_url = None
            if voice:
                try:
                    async def generate_speech_operation():
                        return await client.utils.generate_speech(
                            chat_id,
                            answer.turn_id,
                            answer.get_primary_candidate().candidate_id,
                            voice_id,
                            return_url=True
                        )

                    speech_url = await retry_with_backoff(
                        generate_speech_operation,
                        max_retries=5,
                        base_delay=0.5
                    )
                except Exception as e:
                    # Voice is best-effort: the text reply is still returned.
                    logger.warning(f"Voice generation failed after retries: {e}")

            response_data = {
                "chat_id": chat_id,
                "author": answer.author_name,
                "response": response_text,
                # Stays None when voice is disabled or generation failed.
                "voice_url": speech_url,
            }
            if greeting_text:
                response_data["greeting_message"] = greeting_text
            return response_data
    except HTTPException:
        # Already carries its intended status and detail (e.g. the chat-creation
        # failure above); without this clause the generic ``except Exception``
        # below would re-wrap it inside an "Error processing request" detail.
        raise
    except SessionClosedError as e:
        await reinitialize_client()
        raise HTTPException(
            status_code=500,
            detail="Session was closed after retries. Please try your request again."
        ) from e
    except RequestError as e:
        raise HTTPException(status_code=500, detail=f"Character AI API error after retries: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
@app.on_event("startup")
async def start_cleanup_task():
    """
    Start a background task that prunes idle per-chat locks once an hour.

    Per-chat locks are added to ``chat_locks`` when chats are created and on
    first use of a chat_id; without pruning the dict only ever grows.
    """
    global _chat_lock_cleanup_task

    async def cleanup_locks():
        while True:
            await asyncio.sleep(3600)
            try:
                # No ``await`` inside this loop, so no other task can touch
                # ``chat_locks`` while it is being pruned.
                removed = 0
                for lock_id, lock in list(chat_locks.items()):
                    # Skip locks that are held, or that have a waiter which has
                    # been woken but has not yet re-acquired (locked() is False
                    # in that window). ``_waiters`` is a CPython implementation
                    # detail, so it is read defensively.
                    if lock.locked() or getattr(lock, "_waiters", None):
                        continue
                    del chat_locks[lock_id]
                    removed += 1
                if removed:
                    logger.info(f"Removed {removed} idle chat lock(s)")
            except Exception as e:
                # Keep the periodic task alive; a failed sweep is retried next hour.
                logger.warning(f"Chat lock cleanup failed: {e}")

    # Hold a strong reference: the event loop keeps only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-run.
    _chat_lock_cleanup_task = asyncio.create_task(cleanup_locks())
@app.get("/")
def root():
    # Landing endpoint: a static greeting confirming the service is reachable.
    welcome = "Welcome to the Character AI API with FastAPI!"
    return {"message": welcome}
@app.get("/health")
def health_check():
    # Always reports "healthy"; client_initialized tells whether the shared
    # Character.AI client object currently exists.
    initialized = client is not None
    return {"status": "healthy", "client_initialized": initialized}
if __name__ == "__main__":
    # Defaults reproduce the original hard-coded settings; the HOST, PORT and
    # WORKERS environment variables can override them per deployment.
    # NOTE(review): with workers > 1 every process has its own ``client``,
    # ``chat_locks`` and ``new_chat_lock``, so per-chat serialization only
    # holds within one worker. Use WORKERS=1 if sends to the same chat must
    # never overlap.
    # NOTE(review): timeout_keep_alive is in seconds, so 60000 keeps idle
    # connections open for roughly 16.7 hours — confirm this is intended.
    uvicorn.run(
        "main:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
        workers=int(os.getenv("WORKERS", "8")),
        timeout_keep_alive=60000,
    )