#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import logging  # Import logging to report failures when closing Gradio clients
import time  # Import the time module to work with timestamps for session expiration checks
import uuid  # Import uuid module to generate unique session identifiers
from typing import Dict, Tuple, Optional, Any  # Import type hints for better code clarity and validation

from gradio_client import Client  # Import Client from gradio_client to interact with the AI model API

from config import EXPIRE  # Import the EXPIRE constant which defines session timeout duration
# In-memory registry of active user sessions, keyed by session ID (a string).
# Each value is a 2-tuple:
#   [0] float - time.time() reading taken when the session was last created,
#               accessed or updated
#   [1] dict  - per-session state with the keys:
#         "model":   name of the AI model the session currently uses
#         "history": list tracking the conversation (inputs and responses)
#         "client":  Gradio Client instance used for this session's API calls
session_store: Dict[
    str,
    Tuple[float, Dict[str, Any]],
] = {}
def cleanup_expired_sessions():
    """
    Remove every session that has been idle for longer than EXPIRE seconds.

    A session counts as expired when the time elapsed since its stored
    last-update timestamp is strictly greater than EXPIRE. Each expired entry
    is removed from session_store and, if it holds a client, that client is
    closed on a best-effort basis.

    Errors raised while closing a client never abort the sweep; they are
    logged as warnings (with traceback) so that cleanup failures stay visible
    instead of being silently discarded.

    Returns:
    - None
    """
    now = time.time()  # Current time in seconds since the epoch

    # Collect the expired IDs up front so the dictionary is not mutated while
    # it is being iterated.
    expired_sessions = [
        sid for sid, (last_update, _) in session_store.items()
        if now - last_update > EXPIRE
    ]

    for sid in expired_sessions:
        # Drop the entry first so the store never keeps a reference to a
        # closed client; pop() with a default tolerates a missing key.
        entry = session_store.pop(sid, None)
        if entry is None:
            continue

        client = entry[1].get("client")
        if not client:
            continue

        try:
            client.close()  # Release the resources held by the Gradio client
        except Exception:
            # Best effort: report the failure and keep sweeping the remaining
            # expired sessions.
            logging.getLogger(__name__).warning(
                "Failed to close client for expired session %s", sid,
                exc_info=True,
            )
def create_client_for_model(model: str) -> Client:
    """
    Create a Gradio Client for "hadadrjt/ai" and switch it to the given model.

    The model is selected by calling the client's "/change_model" API with
    the model name passed as the `new` argument. If that call fails, the
    freshly created client is closed before the error is re-raised, so a
    failed switch does not leak an open client.

    Parameters:
    - model (str): The name of the AI model to initialize the client with.

    Returns:
    - Client: The client instance after the "/change_model" call succeeded.

    Raises:
    - Exception: Whatever the Client constructor or the "/change_model" call
      raises is propagated unchanged to the caller.
    """
    client = Client("hadadrjt/ai")
    try:
        # Ask the remote API to switch to the requested model
        client.predict(new=model, api_name="/change_model")
    except Exception:
        # Nobody else holds a reference to this client yet, so it must be
        # released here or it would stay open for the life of the process.
        try:
            client.close()
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to close client after model switch to %r failed",
                model, exc_info=True,
            )
        raise  # Re-raise the original "/change_model" failure
    return client
def get_or_create_session(session_id: Optional[str], model: str) -> str:
    """
    Return the ID of a live session for `model`, creating one when needed.

    Expired sessions are purged first via cleanup_expired_sessions(), so an
    ID that refers to an expired session is treated like an unknown one.

    - If `session_id` is empty/None or not present in session_store, a new
      session is created under a fresh UUID4 string with the requested model,
      an empty history and a new client.
    - If the session exists but uses a different model, a client for the new
      model is created first and only then is the old client closed. Should
      creating the new client fail, the session is left untouched and still
      holds its working client. Failures while closing the old client are
      logged and otherwise ignored.
    - In every successful case the session's timestamp is refreshed.

    Parameters:
    - session_id (Optional[str]): The unique identifier of the session to
      retrieve. If None, empty or unknown, a new session is created.
    - model (str): The AI model to be used for this session.

    Returns:
    - str: The session ID of the active or newly created session.

    Raises:
    - Exception: Anything raised by create_client_for_model() propagates.
    """
    cleanup_expired_sessions()  # Purge timed-out sessions before the lookup

    entry = session_store.get(session_id) if session_id else None

    if entry is None:
        # No usable session: start a new one under a fresh unique ID
        session_id = str(uuid.uuid4())
        client = create_client_for_model(model)
        session_store[session_id] = (time.time(), {
            "model": model,
            "history": [],
            "client": client,
        })
        return session_id

    _, data = entry
    if data["model"] != model:
        # Build the replacement before touching the old client: if this
        # raises, the session must not be left holding a closed client.
        new_client = create_client_for_model(model)
        old_client = data.get("client")
        data["model"] = model
        data["client"] = new_client
        if old_client:
            try:
                old_client.close()  # Release the previous model's client
            except Exception:
                # Best effort: the switch already succeeded, so only report
                logging.getLogger(__name__).warning(
                    "Failed to close previous client for session %s",
                    session_id, exc_info=True,
                )

    # Refresh the last-access time so the session stays active
    session_store[session_id] = (time.time(), data)
    return session_id