Spaces:
Sleeping
Sleeping
import pandas as pd | |
import openai | |
import os | |
import io | |
# Load tool data from CSV | |
import numpy as np | |
def sanitize_json(obj):
    """Recursively convert pandas/numpy scalars in *obj* to native Python types.

    Dicts (both keys and values), lists and tuples are walked recursively so
    the result can be handed to ``json.dumps``. NumPy scalars (``np.generic``,
    which includes ``np.bool_``) are unwrapped with ``.item()``. Any other
    object exposing a callable ``item()`` is unwrapped the same way; if that
    call fails (e.g. a multi-element pandas Series) the object's ``str()`` is
    returned instead. Everything else is returned unchanged.

    Args:
        obj: Arbitrary value, typically a dict/list built from DataFrame rows.

    Returns:
        The same structure with numpy/pandas scalars replaced by builtins.
        Lists stay lists; tuples (including namedtuples) become plain tuples.
    """
    if isinstance(obj, dict):
        # Keys are sanitized too: json.dumps rejects numpy scalar keys.
        return {sanitize_json(k): sanitize_json(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_json(v) for v in obj]
    if isinstance(obj, tuple):
        # Tuples can appear as MultiIndex keys in DataFrame.to_dict() output.
        return tuple(sanitize_json(v) for v in obj)
    # np.bool_ is a subclass of np.generic, so this covers every numpy scalar.
    if isinstance(obj, np.generic):
        return obj.item()
    item = getattr(obj, 'item', None)
    if callable(item):
        try:
            return item()
        except Exception:
            return str(obj)
    return obj
def load_tool_data(csv_path=None):
    """Load the AI tool catalogue from a CSV file into a cleaned DataFrame.

    Lines whose first non-whitespace character is ``#`` are treated as
    comments and skipped before parsing.

    After parsing:
      * the string columns ``ToolName``, ``MediaType``, ``Cost``,
        ``APIEndpoint`` and ``Strengths`` (when present) are whitespace-stripped,
        and missing cells become ``''`` rather than the literal ``'nan'``;
      * ``IsOpenSource`` / ``CommercialUse`` (when present) are mapped from
        ``'true'``/``'false'`` (any case) to booleans, and any other value
        becomes NaN;
      * rows whose ``ToolName`` is empty or starts with ``#`` are dropped.

    Args:
        csv_path (str): Path to the CSV. Defaults to ``ai_tools.csv`` in the
            same directory as this module.

    Returns:
        pd.DataFrame: The cleaned tool table (original row index preserved).

    Raises:
        KeyError: If the CSV has no ``ToolName`` column.
    """
    if csv_path is None:
        csv_path = os.path.join(os.path.dirname(__file__), 'ai_tools.csv')
    # Drop full-line comments ourselves; read_csv(comment='#') would also cut
    # '#' characters that appear mid-line inside field values.
    with open(csv_path, 'r', encoding='utf-8') as f:
        lines = [line for line in f if not line.strip().startswith('#')]
    df = pd.read_csv(io.StringIO(''.join(lines)))

    for col in ['ToolName', 'MediaType', 'Cost', 'APIEndpoint', 'Strengths']:
        if col in df.columns:
            # fillna first: astype(str) would otherwise turn missing cells
            # into the string 'nan', which defeats the empty-name filter below.
            df[col] = df[col].fillna('').astype(str).str.strip()

    for col in ['IsOpenSource', 'CommercialUse']:
        if col in df.columns:
            df[col] = (
                df[col].astype(str).str.strip().str.lower()
                .map({'true': True, 'false': False})
            )

    names = df['ToolName'].str.strip()
    df = df[(names != '') & ~names.str.startswith('#')]

    print(f"[DEBUG] Available tools: {df['ToolName'].tolist()}")
    return df
def _filter_by_user_criteria(tools_df, criteria):
    """Return the rows of *tools_df* that satisfy every entry in *criteria*.

    ``MediaType`` (handled by the caller) and keys that are not columns of
    *tools_df* are ignored. Matching rules per column:
      * ``IsOpenSource`` / ``CommercialUse``: bool, or 'true'/'false' string;
      * ``Cost``: case-insensitive literal substring match;
      * ``ToolName`` / ``HostingPlatform`` given as a list: any-of match;
      * anything else: case-insensitive exact string match.
    """
    bool_columns = ('IsOpenSource', 'CommercialUse')
    for key, value in criteria.items():
        if key == 'MediaType' or key not in tools_df.columns:
            continue
        col = tools_df[key]
        if key in bool_columns:
            # Accept True/False (bool) or 'true'/'false' (str, any case).
            if isinstance(value, str):
                wanted = value.strip().lower() == 'true'
            else:
                wanted = bool(value)
            tools_df = tools_df[col == wanted]
        elif key == 'Cost':
            # regex=False: user text such as "$5" or "(free)" must be matched
            # literally, not compiled as a regular expression.
            needle = str(value).lower()
            tools_df = tools_df[
                col.astype(str).str.lower().str.contains(needle, regex=False)
            ]
        elif key in ('ToolName', 'HostingPlatform') and isinstance(value, list):
            allowed = [str(v).lower() for v in value]
            tools_df = tools_df[col.astype(str).str.lower().isin(allowed)]
        else:
            tools_df = tools_df[col.astype(str).str.lower() == str(value).lower()]
    return tools_df


def _extract_json_object(content):
    """Parse the span from the first '{' to the last '}' of *content* as JSON.

    Raises:
        ValueError: If no such span exists or it is not valid JSON.
        AttributeError: If *content* is not a string (e.g. None).
    """
    import json
    start = content.find('{')
    end = content.rfind('}') + 1
    if start == -1 or end <= start:
        raise ValueError("no JSON object found in response")
    return json.loads(content[start:end])


def universal_router_agent(prompt, tool_data_df, user_criteria=None, media_type=None, user_id=None, debug=False):
    """Select the most suitable AI tool for *prompt*.

    Tools are filtered by media type and by *user_criteria*. The choice is
    then made in one of two ways:
      * automatically, by highest ELO, when ``Quality`` is 5 and an ``ELO``
        column exists;
      * otherwise, by asking GPT-4.1 (via ``openai.chat.completions.create``)
        to pick among the remaining tools.

    Args:
        prompt (str): The user's creative prompt.
        tool_data_df (pd.DataFrame): DataFrame of available AI tools.
        user_criteria (dict): Optional criteria
            (e.g. {"Cost": "Free", "IsOpenSource": True}).
            - ``MediaType`` is used when *media_type* is not given.
            - ``Quality`` is parsed as an int: values >= 4 sort candidates by
              ``ELO`` descending, and 5 returns the top-ELO tool without
              calling GPT. It is not forwarded to GPT.
        media_type (str): Optional media type ("image", "video", "music").
        user_id (str): Optional user id, copied into GPT results as 'UserId'.
        debug (bool): When True, return a ``(result, debug_info)`` tuple.

    Returns:
        dict | None: Details of the selected tool. This is either the
        auto-selected row or GPT's JSON object. It is an
        ``{"error": ..., "raw": ...}`` dict if GPT's reply cannot be parsed,
        and None if no tool survives filtering.

    Raises:
        ValueError: If no media type is given directly or via user_criteria.
    """
    # Step 1: Determine media type if not provided.
    if not media_type and user_criteria and 'MediaType' in user_criteria:
        media_type = user_criteria['MediaType']
    if not media_type:
        # Optionally, GPT could classify the prompt into a media type here.
        raise ValueError("Media type must be specified.")

    # Step 2: Filter tools by media type.
    filtered_tools = tool_data_df[
        tool_data_df['MediaType'].astype(str).str.lower() == media_type.lower()
    ]
    print(f"[DEBUG] Tools after media type filter: {filtered_tools['ToolName'].tolist()}")

    # Step 3: Apply user criteria filters.
    quality_for_sort = None
    if user_criteria:
        if 'Quality' in user_criteria:
            try:
                quality_for_sort = int(user_criteria['Quality'])
            except (TypeError, ValueError, OverflowError):
                quality_for_sort = None
            # Drop Quality (a copy, not the caller's dict) so it doesn't
            # confuse the GPT prompt.
            user_criteria = {k: v for k, v in user_criteria.items() if k != 'Quality'}
        filtered_tools = _filter_by_user_criteria(filtered_tools, user_criteria)
        print(f"[DEBUG] Tools after criteria filter: {filtered_tools['ToolName'].tolist()}")

    # Step 3b: If Quality >= 4, sort by ELO descending. Quality == 5 returns
    # the top-ELO tool directly. That needs an ELO column, so it is nested
    # under the same check (otherwise top_row['ELO'] would raise KeyError).
    if (
        quality_for_sort is not None
        and quality_for_sort >= 4
        and 'ELO' in filtered_tools.columns
    ):
        filtered_tools = filtered_tools.sort_values('ELO', ascending=False)
        print("[DEBUG] Sorted by ELO due to high Quality constraint.")
        if quality_for_sort == 5 and not filtered_tools.empty:
            top_row = filtered_tools.iloc[0]
            elo = top_row['ELO']
            result = sanitize_json({
                'ToolName': str(top_row['ToolName']),
                'Cost': str(top_row['Cost']),
                'IsOpenSource': bool(top_row['IsOpenSource']),
                'CommercialUse': bool(top_row['CommercialUse']),
                'APIEndpoint': str(top_row['APIEndpoint']),
                'Strengths': str(top_row['Strengths']),
                'ELO': float(elo) if hasattr(elo, '__float__') else elo,
            })
            if debug:
                return (result, {'auto_selected': True,
                                 'filtered_tools': sanitize_json(filtered_tools.to_dict())})
            return result

    debug_info = {}
    if debug:
        debug_info['filtered_tools'] = sanitize_json(filtered_tools.to_dict())
    if filtered_tools.empty:
        return (None, debug_info) if debug else None

    # Step 4: Prepare tool info for GPT selection.
    tool_options = [
        sanitize_json({
            'ToolName': row['ToolName'],
            'Cost': row['Cost'],
            'IsOpenSource': row['IsOpenSource'],
            'CommercialUse': row['CommercialUse'],
            'APIEndpoint': row['APIEndpoint'],
            'Strengths': row['Strengths'],
        })
        for _, row in filtered_tools.iterrows()
    ]

    # Step 5: Use GPT-4.1 to select the best tool.
    system_prompt = (
        "You are a universal AI router agent. Given a user prompt and a list of available AI tools for a specific media type (image, video, or music), "
        "select the most appropriate tool based on the user's prompt, tool strengths, and any user criteria. "
        "Return the name of the tool, cost, expected time (if known), and a quality expectation (high, medium, low). "
        "If multiple tools are similar, prefer the one with the highest quality. "
        "Be concise and only output a JSON object with the following fields: ToolName, Cost, ExpectedTime, QualityExpectation, APIEndpoint, Strengths."
    )
    user_message = (
        f"User Prompt: {prompt}\n\n"
        f"Available Tools: {tool_options}\n\n"
        f"User Criteria: {user_criteria if user_criteria else 'None'}"
    )
    response = openai.chat.completions.create(
        model="gpt-4.1",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        max_tokens=300,
    )

    # Bound before the try so the error path can always report it, even if
    # reading response.choices itself fails.
    content = None
    try:
        content = response.choices[0].message.content
        result = _extract_json_object(content)
        if user_id:
            result['UserId'] = user_id
        result = sanitize_json(result)
    except Exception as e:
        # Boundary: any malformed reply becomes an error dict, not a crash.
        err = {"error": f"Could not parse GPT response: {str(e)}", "raw": content}
        return (err, debug_info) if debug else err
    return (result, debug_info) if debug else result