# Chatbot_capstone / groq_api.py
# Last commit: "repo init" (601d457) by Ezhil
import os
import json
import re
from groq import Groq
from dotenv import load_dotenv
import httpx
from tools import tools
from utils import execute_sql_query
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Shared Groq client used by chat_with_groq().
# NOTE: never print or log the API key itself -- doing so leaks the secret
# to stdout and to any captured logs.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"), http_client=httpx.Client())
_SYSTEM_PROMPT = (
    "You are a helpful assistant that can query a Supabase PostgreSQL database using SQL. "
    "Use the execute_sql_query function only when the user explicitly asks for data from the database (e.g., 'give me songs', 'find songs', 'top songs', 'give top songs'). "
    "For greetings like 'hi' or 'hello', respond with a simple greeting like 'Hello! How can I help you?' without querying the database. "
    "The database has a 'songs' table with columns: \"Track Name\", \"Artist Name(s)\", \"Valence\", \"Popularity\", etc. "
    "Always use quoted column names to handle case sensitivity and special characters (e.g., \"Track Name\" with quotes). "
    "Ensure there is a space after each quoted column name in the SELECT clause and a space before the FROM keyword (e.g., SELECT \"Track Name\", \"Artist Name(s)\" FROM with spaces). "
    "The \"Artist Name(s)\" column may contain multiple artists as a comma-separated string, so use ILIKE for partial matching (e.g., \"Artist Name(s)\" ILIKE '%artist_name%'). "
    "For queries like 'top X songs' or 'give top X songs', extract the number X (default to 10 if not specified) and use it in the LIMIT clause. "
    "For generic song queries, use: SELECT \"Track Name\", \"Artist Name(s)\" FROM songs LIMIT X. "
    "If the user specifies a sorting criterion (e.g., 'top 10 songs by popularity'), sort by the appropriate column (e.g., ORDER BY \"Popularity\" DESC). "
    "Always return SELECT \"Track Name\", \"Artist Name(s)\" in the query, not SELECT *. "
    # Double backslashes so the model sees the literal escape sequences; a
    # plain "\u003c" would be decoded by Python into "<" itself.
    "Generate complete and valid JSON and SQL queries, ensuring proper escaping of quotes, correct spacing, and using ASCII characters for operators (e.g., use < and >, not \\u003c or \\u003e)."
)

# "top 5 songs", "give top 5 songs" -> 5.
_TOP_N_PATTERN = re.compile(r"\btop\s+(\d+)")
# Tool calls the model sometimes writes into its text instead of tool_calls.
_TOOL_USE_BLOCK = re.compile(r"<tool-use>\n(.*)\n</tool-use>", re.DOTALL)
# Fallback extraction of sql_query when the arguments are not valid JSON.
_SQL_QUERY_ARG = re.compile(r'"sql_query"\s*:\s*"((?:[^"\\]|\\.)*)"')
# A quoted identifier followed by a comma; the lookahead leaves the next
# identifier unconsumed so runs like "A","B","C" are all fixed.
_COMMA_AFTER_IDENTIFIER = re.compile(r'("[^"]+")\s*,\s*(?=")')
# A quoted identifier glued to FROM. The flag is passed to compile()
# because an inline (?i) in mid-pattern is an error on Python 3.11+.
_IDENTIFIER_BEFORE_FROM = re.compile(r'("[^"]+")(FROM)', re.IGNORECASE)


def _parse_limit(user_input, default=10):
    """Return N from a "top N" phrase in *user_input*, else *default*."""
    match = _TOP_N_PATTERN.search(user_input.lower())
    return int(match.group(1)) if match else default


def _parse_tool_use_block(message_content):
    """Parse a ``<tool-use>`` JSON block from the model's text reply.

    Returns:
        list: Objects exposing ``.function.name`` and ``.function.arguments``
        (the same shape as API tool calls). The list is empty when there is
        no block, the JSON is invalid, or no entry names a function.
    """
    from types import SimpleNamespace

    match = _TOOL_USE_BLOCK.search(message_content)
    if not match:
        return []
    try:
        data = json.loads(match.group(1))
    except json.JSONDecodeError as e:
        print(f"Failed to parse /tool-use block: {e}")
        return []
    if not isinstance(data, dict):
        return []

    calls = []
    for entry in data.get("tool_calls") or []:
        function = entry.get("function") if isinstance(entry, dict) else None
        if not isinstance(function, dict) or "name" not in function:
            continue  # skip malformed entries instead of crashing
        calls.append(SimpleNamespace(function=SimpleNamespace(
            name=function["name"],
            arguments=function.get("arguments", ""),
        )))
    return calls


def _extract_sql_query(arguments):
    """Return the ``sql_query`` argument of an execute_sql_query call.

    Args:
        arguments: A JSON-encoded string (API tool calls) or an already
            decoded dict (entries parsed from a ``<tool-use>`` block).

    Returns:
        str: The query, stripped of surrounding whitespace and trailing
        semicolons, or "" when none can be found.
    """
    if isinstance(arguments, dict):
        data = arguments
    else:
        try:
            data = json.loads(arguments)
        except (TypeError, ValueError):
            data = None

    if isinstance(data, dict):
        sql_query = data.get("sql_query")
        if not isinstance(sql_query, str):
            sql_query = ""
    else:
        # Malformed JSON from the model: fall back to pulling the value out.
        match = _SQL_QUERY_ARG.search(str(arguments))
        if not match:
            print("Failed to extract sql_query from arguments_str")
            return ""
        sql_query = match.group(1)
    return sql_query.strip().rstrip(";").strip()


def _normalize_sql(sql_query):
    """Repair formatting glitches the model commonly makes in generated SQL."""
    # Leftover escaped quotes around identifiers.
    sql_query = sql_query.replace('\\"', '"')
    # Literal "\u003e"/"\u003c" escape text (raw strings: match the 6-char
    # sequence, not the already-decoded character).
    sql_query = sql_query.replace(r"\u003e", ">").replace(r"\u003c", "<")
    # Doubled end anchor seen in generated numeric regex filters.
    sql_query = sql_query.replace("^[0-9.]+$$", "^[0-9.]+$")
    sql_query = _COMMA_AFTER_IDENTIFIER.sub(r"\1, ", sql_query)
    sql_query = _IDENTIFIER_BEFORE_FROM.sub(r"\1 FROM", sql_query)
    return sql_query


def _format_song_rows(rows, limit):
    """Render query rows as a numbered "Top N Songs" list, at most *limit* long.

    Rows are read with ``.get("Track Name")`` / ``.get("Artist Name(s)")``
    (presumably dicts from execute_sql_query -- confirm against utils).
    """
    if not rows:
        return "🔍 No results found for the query."
    shown = rows[:limit]
    lines = [f"Top {len(shown)} Songs:"]
    for position, row in enumerate(shown, 1):
        track_name = row.get("Track Name", "Unknown Track")
        artist_names = row.get("Artist Name(s)", "Unknown Artist")
        lines.append(f"{position}. {track_name} by {artist_names}")
    return "\n".join(lines).strip()


def _run_sql_tool_call(tool_call, limit):
    """Validate and execute the SQL from one execute_sql_query tool call."""
    arguments = getattr(tool_call.function, "arguments", None)
    print(f"Raw arguments_str: {arguments}")
    try:
        sql_query = _extract_sql_query(arguments)
    except Exception as e:
        return f"⚠️ Error parsing tool call arguments: {str(e)} - Raw JSON: {arguments}"
    if not sql_query:
        return "⚠️ No SQL query provided."
    print(f"Extracted SQL query: {sql_query}")

    sql_query = _normalize_sql(sql_query)
    # The query is model-generated and therefore steerable by user input:
    # allow a single SELECT only. Any ';' left after stripping the trailing
    # ones means stacked statements (e.g. "SELECT 1; DELETE ...").
    # A read-only DB role remains the real safeguard.
    if not sql_query.upper().startswith("SELECT") or ";" in sql_query:
        return f"⚠️ Invalid SQL query: {sql_query}"

    print(f"Executing SQL Query: {sql_query}")
    result = execute_sql_query(sql_query)
    if not isinstance(result, list):
        return result  # error message from execute_sql_query
    return _format_song_rows(result, limit)


def chat_with_groq(user_input):
    """
    Processes user input using Groq API and executes SQL queries on Supabase when needed.
    Args:
        user_input (str): The user's query.
    Returns:
        str: Response from the chatbot. This is either the formatted song
        list, the model's text reply, or a warning/error message. Exceptions
        are caught and returned as "Error: ..." rather than raised.
    """
    try:
        limit = _parse_limit(user_input)
        response = groq_client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_input},
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=4096,
        )
        print(f"Full response: {response}")  # debug

        message = response.choices[0].message
        tool_calls = getattr(message, "tool_calls", None)
        message_content = getattr(message, "content", None)

        # Some replies embed the tool call in the text instead of tool_calls.
        if not tool_calls and message_content:
            tool_calls = _parse_tool_use_block(message_content)

        # The first execute_sql_query call decides the reply.
        for tool_call in tool_calls or []:
            if tool_call.function.name == "execute_sql_query":
                return _run_sql_tool_call(tool_call, limit)

        # Fallback for no tool calls (e.g., greetings)
        if message_content and not tool_calls:
            # A <tool-use> block with no calls is the model's way of greeting.
            if '<tool-use>' in message_content and '"tool_calls": []' in message_content:
                return "Hello! How can I help you?"
            return message_content.strip()
        return "I'm sorry, I couldn't process your request. (No message content or tool calls found)"
    except Exception as e:
        print(e)
        return f"Error: {str(e)}"