Spaces:
Sleeping
Sleeping
import os | |
import json | |
import re | |
from groq import Groq | |
from dotenv import load_dotenv | |
import httpx | |
from tools import tools | |
from utils import execute_sql_query | |
# Load environment variables from a .env file, if present; GROQ_API_KEY is
# read from the environment below.
load_dotenv()

# Initialize the Groq client. The API key is deliberately NOT printed: writing a
# secret to stdout leaks it into application/hosting logs.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"), http_client=httpx.Client())
_SYSTEM_PROMPT = (
    "You are a helpful assistant that can query a Supabase PostgreSQL database using SQL. "
    "Use the execute_sql_query function only when the user explicitly asks for data from the database (e.g., 'give me songs', 'find songs', 'top songs', 'give top songs'). "
    "For greetings like 'hi' or 'hello', respond with a simple greeting like 'Hello! How can I help you?' without querying the database. "
    "The database has a 'songs' table with columns: \"Track Name\", \"Artist Name(s)\", \"Valence\", \"Popularity\", etc. "
    "Always use quoted column names to handle case sensitivity and special characters (e.g., \"Track Name\" with quotes). "
    "Ensure there is a space after each quoted column name in the SELECT clause and a space before the FROM keyword (e.g., SELECT \"Track Name\", \"Artist Name(s)\" FROM with spaces). "
    "The \"Artist Name(s)\" column may contain multiple artists as a comma-separated string, so use ILIKE for partial matching (e.g., \"Artist Name(s)\" ILIKE '%artist_name%'). "
    "For queries like 'top X songs' or 'give top X songs', extract the number X (default to 10 if not specified) and use it in the LIMIT clause. "
    "For generic song queries, use: SELECT \"Track Name\", \"Artist Name(s)\" FROM songs LIMIT X. "
    "If the user specifies a sorting criterion (e.g., 'top 10 songs by popularity'), sort by the appropriate column (e.g., ORDER BY \"Popularity\" DESC). "
    "Always return SELECT \"Track Name\", \"Artist Name(s)\" in the query, not SELECT *. "
    # The backslashes are doubled so the model sees the literal escape text
    # "\u003c"; a single backslash would render as "<" and contradict itself.
    "Generate complete and valid JSON and SQL queries, ensuring proper escaping of quotes, correct spacing, and using ASCII characters for operators (e.g., use < and >, not \\u003c or \\u003e)."
)

# "top 5", "give top 5", "top   20" ... in the lower-cased user input.
_LIMIT_RE = re.compile(r"\btop\s+(\d+)")

# Tool calls the model sometimes writes into its text instead of returning
# structured tool_calls.
_TOOL_USE_RE = re.compile(r"<tool-use>\s*(.*?)\s*</tool-use>", re.DOTALL)

# Fallback extractor for the sql_query value when the arguments are not valid JSON.
_SQL_ARG_RE = re.compile(r'"sql_query"\s*:\s*"((?:[^"\\]|\\.)*)"')

# A quoted identifier followed by a comma and another quoted identifier. The
# lookahead (instead of consuming the second identifier) lets every column in
# a list be fixed, not just every other pair.
_QUOTED_COMMA_RE = re.compile(r'("[^"]+")\s*,\s*(?=")')

# A quoted identifier glued directly to FROM. The case-insensitive flag must be
# passed here: an inline "(?i)" in the middle of a pattern is an error on
# Python 3.11+.
_QUOTED_FROM_RE = re.compile(r'("[^"]+")(FROM)', re.IGNORECASE)


def _extract_limit(user_input, default=10):
    """Return N from phrases like "top N" / "give top N", else *default* (also for N == 0)."""
    match = _LIMIT_RE.search(user_input.lower())
    if match and int(match.group(1)) > 0:
        return int(match.group(1))
    return default


def _normalize_tool_calls(raw_calls):
    """Convert tool calls into a list of ``(name, arguments)`` pairs.

    Accepts SDK objects (``call.function.name``) as well as plain dicts parsed
    from a ``<tool-use>`` block (``call["function"]["name"]``). Malformed
    entries are skipped instead of raising.
    """
    pairs = []
    for call in raw_calls or []:
        if isinstance(call, dict):
            function = call.get("function")
            if not isinstance(function, dict):
                continue
            name, arguments = function.get("name"), function.get("arguments")
        else:
            function = getattr(call, "function", None)
            name = getattr(function, "name", None)
            arguments = getattr(function, "arguments", None)
        if name:
            pairs.append((name, arguments))
    return pairs


def _parse_tool_use_block(message_content):
    """Return the raw ``tool_calls`` list from a ``<tool-use>`` JSON block in the text.

    Returns an empty list when there is no block or it cannot be parsed.
    """
    match = _TOOL_USE_RE.search(message_content)
    if not match:
        return []
    try:
        data = json.loads(match.group(1))
    except json.JSONDecodeError as e:
        print(f"Failed to parse /tool-use block: {e}")
        return []
    if not isinstance(data, dict):
        return []
    calls = data.get("tool_calls")
    return calls if isinstance(calls, list) else []


def _extract_sql_query(arguments):
    """Return the ``sql_query`` argument of an execute_sql_query call ("" if absent).

    *arguments* is normally a JSON-encoded string, but is already a dict when it
    came from a ``<tool-use>`` block. Surrounding whitespace and trailing
    semicolons are removed.
    """
    print(f"Raw arguments_str: {arguments}")
    sql_query = None
    if isinstance(arguments, dict):
        sql_query = arguments.get("sql_query")
    elif isinstance(arguments, str):
        try:
            parsed = json.loads(arguments)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            sql_query = parsed.get("sql_query")
        else:
            # Malformed JSON from the model: pull the string value out directly.
            match = _SQL_ARG_RE.search(arguments)
            if match:
                sql_query = match.group(1).replace('\\"', '"')
    if not isinstance(sql_query, str):
        print("Failed to extract sql_query from arguments_str")
        return ""
    return sql_query.strip().rstrip(";").rstrip()


def _clean_sql_query(sql_query):
    """Repair common formatting artifacts in model-generated SQL."""
    # Leftover JSON-escaped quotes.
    sql_query = sql_query.replace('\\"', '"')
    # Literal six-character escapes "\u003e" / "\u003c" left in the text. The
    # backslash is doubled so this matches the escape text; '\u003e' in Python
    # source is already ">" and would make the replace a no-op.
    sql_query = sql_query.replace("\\u003e", ">").replace("\\u003c", "<")
    # Presumably a recurring model artifact: a doubled "$" anchor in a regex literal.
    sql_query = sql_query.replace("^[0-9.]+$$", "^[0-9.]+$")
    # Normalize spacing between quoted columns and before FROM.
    sql_query = _QUOTED_COMMA_RE.sub(r"\1, ", sql_query)
    sql_query = _QUOTED_FROM_RE.sub(r"\1 FROM", sql_query)
    return sql_query


def _format_song_rows(rows, limit):
    """Render up to *limit* rows as a numbered "Track by Artist" list.

    Rows are read with ``.get``; presumably dicts keyed by column name — confirm
    against execute_sql_query.
    """
    shown = rows[:limit]
    lines = [f"Top {len(shown)} Songs:"]
    for i, row in enumerate(shown, 1):
        track_name = row.get("Track Name", "Unknown Track")
        artist_names = row.get("Artist Name(s)", "Unknown Artist")
        lines.append(f"{i}. {track_name} by {artist_names}")
    return "\n".join(lines).strip()


def _run_sql_tool_call(arguments, limit):
    """Validate and execute the SQL of one execute_sql_query tool call; return the reply text."""
    sql_query = _extract_sql_query(arguments)
    if not sql_query:
        return "⚠️ No SQL query provided."
    sql_query = _clean_sql_query(sql_query)
    print(f"Executing SQL Query: {sql_query}")
    # Only a single SELECT statement is allowed. The SQL is written by the model
    # from untrusted user input, so stacked statements ("SELECT ...; DROP ...")
    # are refused.
    # NOTE(review): this is a coarse guard (it also rejects ';' inside string
    # literals); a read-only database role is the real defence — confirm how
    # execute_sql_query connects.
    if not sql_query.lstrip().upper().startswith("SELECT") or ";" in sql_query:
        return f"⚠️ Invalid SQL query: {sql_query}"
    result = execute_sql_query(sql_query)
    if not isinstance(result, list):
        return result  # Non-list results are treated as an error message from execute_sql_query.
    if not result:
        return "🔍 No results found for the query."
    return _format_song_rows(result, limit)


def chat_with_groq(user_input):
    """
    Answer the user with the Groq chat model, running a SQL query when the model requests one.

    The model may return structured ``tool_calls`` or write a ``<tool-use>``
    JSON block into its text; both are handled. The first
    ``execute_sql_query`` call is run through ``execute_sql_query`` (from
    utils) and its rows are formatted as a numbered song list.

    Args:
        user_input (str): The user's query.

    Returns:
        str: Response from the chatbot. Failures are reported as text
        (prefixed with "⚠️" or "Error:") rather than raised.
    """
    try:
        limit = _extract_limit(user_input)
        response = groq_client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_input},
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=4096,
        )
        print(f"Full response: {response}")  # Debug the entire response

        message = response.choices[0].message
        message_content = getattr(message, "content", None)
        tool_calls = _normalize_tool_calls(getattr(message, "tool_calls", None))
        # Some responses carry the tool call as a <tool-use> block in the text.
        if not tool_calls and message_content:
            tool_calls = _normalize_tool_calls(_parse_tool_use_block(message_content))

        for name, arguments in tool_calls:
            if name == "execute_sql_query":
                return _run_sql_tool_call(arguments, limit)

        # Fallback for no tool calls (e.g., greetings).
        if message_content and not tool_calls:
            if "<tool-use>" in message_content and '"tool_calls": []' in message_content:
                return "Hello! How can I help you?"
            return message_content.strip()
        return "I'm sorry, I couldn't process your request. (No message content or tool calls found)"
    except Exception as e:
        # Top-level boundary: report any failure to the user as text.
        print(e)
        return f"Error: {str(e)}"