Spaces:
Sleeping
Sleeping
File size: 9,362 Bytes
601d457 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import os
import json
import re
from groq import Groq
from dotenv import load_dotenv
import httpx
from tools import tools
from utils import execute_sql_query
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file, if present.
load_dotenv()

# Initialize the Groq client shared by every chat request in this module.
# NOTE: the API key is intentionally never printed -- echoing secrets to
# stdout leaks them into hosting/container logs.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"), http_client=httpx.Client())
# System prompt sent with every request. The literal "\\u003c"/"\\u003e" is
# escaped on purpose: a plain "\u003c" in Python source is already "<", which
# would make the instruction read "use < and >, not < or >".
_SYSTEM_PROMPT = (
    "You are a helpful assistant that can query a Supabase PostgreSQL database using SQL. "
    "Use the execute_sql_query function only when the user explicitly asks for data from the database (e.g., 'give me songs', 'find songs', 'top songs', 'give top songs'). "
    "For greetings like 'hi' or 'hello', respond with a simple greeting like 'Hello! How can I help you?' without querying the database. "
    "The database has a 'songs' table with columns: \"Track Name\", \"Artist Name(s)\", \"Valence\", \"Popularity\", etc. "
    "Always use quoted column names to handle case sensitivity and special characters (e.g., \"Track Name\" with quotes). "
    "Ensure there is a space after each quoted column name in the SELECT clause and a space before the FROM keyword (e.g., SELECT \"Track Name\", \"Artist Name(s)\" FROM with spaces). "
    "The \"Artist Name(s)\" column may contain multiple artists as a comma-separated string, so use ILIKE for partial matching (e.g., \"Artist Name(s)\" ILIKE '%artist_name%'). "
    "For queries like 'top X songs' or 'give top X songs', extract the number X (default to 10 if not specified) and use it in the LIMIT clause. "
    "For generic song queries, use: SELECT \"Track Name\", \"Artist Name(s)\" FROM songs LIMIT X. "
    "If the user specifies a sorting criterion (e.g., 'top 10 songs by popularity'), sort by the appropriate column (e.g., ORDER BY \"Popularity\" DESC). "
    "Always return SELECT \"Track Name\", \"Artist Name(s)\" in the query, not SELECT *. "
    "Generate complete and valid JSON and SQL queries, ensuring proper escaping of quotes, correct spacing, and using ASCII characters for operators (e.g., use < and >, not \\u003c or \\u003e)."
)

# "top 5 songs" / "give top 5 songs" -> 5 (matched against the lower-cased input).
_TOP_N_PATTERN = re.compile(r'(?:top|give top) (\d+)')
# Tool calls some model outputs embed as text instead of message.tool_calls.
_TOOL_USE_PATTERN = re.compile(r'<tool-use>\n(.*)\n</tool-use>', re.DOTALL)
# Fallback extractor for the sql_query value when the arguments are not valid JSON.
_SQL_ARG_PATTERN = re.compile(r'"sql_query"\s*:\s*"((?:[^"\\]|\\.)*)"')
# A quoted column followed by a comma and another quoted column; the lookahead
# lets every comma in a column list be normalized in a single pass.
_COLUMN_COMMA_PATTERN = re.compile(r'("[^"]+")\s*,\s*(?=")')
# A quoted column glued to FROM. IGNORECASE is passed as a flag because an
# inline "(?i)" in the middle of a pattern raises re.error on Python 3.11+.
_GLUED_FROM_PATTERN = re.compile(r'("[^"]+")(FROM)', re.IGNORECASE)


def _extract_limit(user_input, default=10):
    """Return N from a 'top N' / 'give top N' request, or `default` if absent."""
    match = _TOP_N_PATTERN.search(user_input.lower())
    return int(match.group(1)) if match else default


def _parse_tool_use_block(message_content):
    """Return (name, arguments) pairs from a textual <tool-use> JSON block.

    Returns an empty list when there is no block, its JSON is invalid, or it
    contains no well-formed tool calls.
    """
    block = _TOOL_USE_PATTERN.search(message_content)
    if not block:
        return []
    try:
        payload = json.loads(block.group(1))
    except json.JSONDecodeError as e:
        print(f"Failed to parse /tool-use block: {e}")
        return []
    if not isinstance(payload, dict):
        return []
    calls = []
    for entry in payload.get("tool_calls") or []:
        function = entry.get("function") if isinstance(entry, dict) else None
        if isinstance(function, dict) and "name" in function:
            calls.append((function["name"], function.get("arguments", "")))
    return calls


def _extract_sql_query(arguments):
    """Return the `sql_query` argument of an execute_sql_query call, or ''.

    `arguments` is normally a JSON string (SDK tool calls) but may already be a
    dict (textual <tool-use> blocks). json.loads decodes every JSON escape,
    including \\u003c / \\u003e and \\n. A regex is only a fallback for malformed JSON.
    """
    if isinstance(arguments, dict):
        return str(arguments.get("sql_query") or "")
    arguments_str = str(arguments or "")
    print(f"Raw arguments_str: {arguments_str}")
    try:
        parsed = json.loads(arguments_str)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return str(parsed.get("sql_query") or "")
    match = _SQL_ARG_PATTERN.search(arguments_str)
    if not match:
        print("Failed to extract sql_query from arguments_str")
        return ""
    # Undo by hand the escapes json.loads would normally have decoded.
    return (match.group(1)
            .replace('\\u003e', '>')
            .replace('\\u003c', '<')
            .replace('\\"', '"'))


def _clean_sql_query(sql_query):
    """Normalize escaping and spacing quirks seen in model-generated SQL."""
    sql_query = sql_query.strip().rstrip(';').rstrip()
    # Models sometimes double-escape quotes / operators inside the JSON string.
    sql_query = sql_query.replace('\\"', '"')
    sql_query = sql_query.replace('\\u003e', '>').replace('\\u003c', '<')
    # Fix a doubled "$" anchor observed in generated regex predicates.
    sql_query = sql_query.replace('^[0-9.]+$$', '^[0-9.]+$')
    sql_query = _COLUMN_COMMA_PATTERN.sub(r'\1, ', sql_query)
    sql_query = _GLUED_FROM_PATTERN.sub(r'\1 FROM', sql_query)
    return sql_query


def _format_songs(rows, limit):
    """Render up to `limit` result rows as a numbered 'Track by Artist' list."""
    shown = rows[:limit]
    lines = [f"Top {len(shown)} Songs:"]
    for i, row in enumerate(shown, 1):
        track_name = row.get("Track Name", "Unknown Track")
        artist_names = row.get("Artist Name(s)", "Unknown Artist")
        lines.append(f"{i}. {track_name} by {artist_names}")
    return "\n".join(lines).strip()


def _run_sql_tool_call(arguments, limit):
    """Extract, sanitize and execute the SQL of one tool call; return chat text."""
    try:
        sql_query = _extract_sql_query(arguments)
    except Exception as e:
        return f"⚠️ Error parsing tool call arguments: {str(e)} - Raw JSON: {arguments}"
    if not sql_query:
        return "⚠️ No SQL query provided."
    print(f"Extracted SQL query: {sql_query}")
    sql_query = _clean_sql_query(sql_query)
    print(f"Cleaned SQL query: {sql_query}")
    # The SQL is model output steered by user text. Only a single SELECT is
    # allowed, so piggy-backed statements ("SELECT 1; DROP ...") are rejected.
    # (A ';' inside a string literal is rejected too; that is an accepted false positive.)
    if not sql_query.upper().startswith("SELECT") or ';' in sql_query:
        return f"⚠️ Invalid SQL query: {sql_query}"
    print(f"Executing SQL Query: {sql_query}")
    result = execute_sql_query(sql_query)
    if not isinstance(result, list):
        return result  # Error message from execute_sql_query
    if not result:
        return "🔍 No results found for the query."
    return _format_songs(result, limit)


def chat_with_groq(user_input):
    """
    Process user input with the Groq API and run SQL on Supabase when the model asks to.

    The model either answers directly (e.g. greetings) or requests the
    `execute_sql_query` tool. Tool calls are read from `message.tool_calls`,
    or, failing that, from a textual `<tool-use>` JSON block in the content.

    Args:
        user_input (str): The user's query. A "top N" phrase caps the number
            of songs listed (default 10).
    Returns:
        str: The chatbot's reply. Failures are reported as text, never raised.
    """
    try:
        limit = _extract_limit(user_input)
        response = groq_client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_input},
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=4096,
        )
        print(f"Full response: {response}")  # Debug the entire response
        message = response.choices[0].message
        message_content = getattr(message, 'content', None)
        sdk_tool_calls = getattr(message, 'tool_calls', None)

        if sdk_tool_calls:
            calls = [(tc.function.name, tc.function.arguments) for tc in sdk_tool_calls]
        elif message_content:
            calls = _parse_tool_use_block(message_content)
        else:
            calls = []

        for name, arguments in calls:
            if name == "execute_sql_query":
                return _run_sql_tool_call(arguments, limit)

        # Fallback for no tool calls (e.g., greetings)
        if message_content and not calls:
            # A <tool-use> block that explicitly lists no tool calls means "just chat".
            if '<tool-use>' in message_content and '"tool_calls": []' in message_content:
                return "Hello! How can I help you?"
            return message_content.strip()
        return "I'm sorry, I couldn't process your request. (No message content or tool calls found)"
    except Exception as e:
        print(e)
        return f"Error: {str(e)}"