File size: 9,362 Bytes
601d457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import json
import re
from groq import Groq
from dotenv import load_dotenv
import httpx
from tools import tools
from utils import execute_sql_query

# Load environment variables (python-dotenv's load_dotenv, e.g. GROQ_API_KEY).
load_dotenv()

# Initialize Groq client.
# The API key is read from the environment and deliberately never printed:
# stdout/log output is not a safe place for secrets.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"), http_client=httpx.Client())


# Number of result rows shown when the user does not ask for "top N".
_DEFAULT_SONG_LIMIT = 10

# System prompt steering the model towards well-formed SQL against the
# ``songs`` table. The final sentence refers to the literal escape text
# \u003c / \u003e, hence the doubled backslashes in the source literal.
_SYSTEM_PROMPT = (
    "You are a helpful assistant that can query a Supabase PostgreSQL database using SQL. "
    "Use the execute_sql_query function only when the user explicitly asks for data from the database (e.g., 'give me songs', 'find songs', 'top songs', 'give top songs'). "
    "For greetings like 'hi' or 'hello', respond with a simple greeting like 'Hello! How can I help you?' without querying the database. "
    "The database has a 'songs' table with columns: \"Track Name\", \"Artist Name(s)\", \"Valence\", \"Popularity\", etc. "
    "Always use quoted column names to handle case sensitivity and special characters (e.g., \"Track Name\" with quotes). "
    "Ensure there is a space after each quoted column name in the SELECT clause and a space before the FROM keyword (e.g., SELECT \"Track Name\", \"Artist Name(s)\" FROM with spaces). "
    "The \"Artist Name(s)\" column may contain multiple artists as a comma-separated string, so use ILIKE for partial matching (e.g., \"Artist Name(s)\" ILIKE '%artist_name%'). "
    "For queries like 'top X songs' or 'give top X songs', extract the number X (default to 10 if not specified) and use it in the LIMIT clause. "
    "For generic song queries, use: SELECT \"Track Name\", \"Artist Name(s)\" FROM songs LIMIT X. "
    "If the user specifies a sorting criterion (e.g., 'top 10 songs by popularity'), sort by the appropriate column (e.g., ORDER BY \"Popularity\" DESC). "
    "Always return SELECT \"Track Name\", \"Artist Name(s)\" in the query, not SELECT *. "
    "Generate complete and valid JSON and SQL queries, ensuring proper escaping of quotes, correct spacing, and using ASCII characters for operators (e.g., use < and >, not \\u003c or \\u003e)."
)


def _extract_limit(user_input, default=_DEFAULT_SONG_LIMIT):
    """Return N from a phrase like "top N" / "give top N" in *user_input*, else *default*.

    The number only caps how many rows are displayed; it is not sent to the
    model, which chooses the SQL ``LIMIT`` itself.
    """
    # "give top N" contains "top N", so one pattern covers both phrasings.
    match = re.search(r"top (\d+)", user_input.lower())
    return int(match.group(1)) if match else default


def _parse_tool_use_block(message_content):
    """Recover tool calls that the model wrote inline as a ``<tool-use>`` JSON block.

    Some responses carry no structured ``tool_calls`` and instead embed a JSON
    object with a ``tool_calls`` list between ``<tool-use>`` and ``</tool-use>``
    lines of the message text.

    Returns:
        list[tuple]: ``(function_name, arguments)`` pairs. The list is empty when
        no block is present or it cannot be parsed. ``arguments`` is whatever the
        model wrote: usually a JSON string, possibly an already-decoded dict.
    """
    block = re.search(r"<tool-use>\n(.*)\n</tool-use>", message_content, re.DOTALL)
    if not block:
        return []
    try:
        payload = json.loads(block.group(1))
        return [
            (call["function"]["name"], call["function"]["arguments"])
            for call in payload.get("tool_calls", [])
        ]
    except (ValueError, AttributeError, KeyError, TypeError) as e:
        # ValueError covers json.JSONDecodeError. The others cover JSON that
        # parses but lacks the expected {"tool_calls": [{"function": ...}]} shape.
        print(f"Failed to parse /tool-use block: {e}")
        return []


def _extract_sql_query(arguments):
    """Return the ``sql_query`` string from a tool call's arguments, or "".

    ``arguments`` is normally a JSON-encoded string. It may already be a dict
    when recovered from an inline ``<tool-use>`` block.
    """
    if isinstance(arguments, dict):
        value = arguments.get("sql_query")
        return value if isinstance(value, str) else ""
    if not isinstance(arguments, str):
        return ""

    # Proper JSON decoding handles spacing and every escape (\", \n, \u003e, ...).
    try:
        parsed = json.loads(arguments)
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        value = parsed.get("sql_query")
        return value if isinstance(value, str) else ""

    # The model emitted invalid JSON: salvage the string with a regex that
    # honours backslash escapes. The escapes it still carries are undone in
    # _clean_sql_query.
    match = re.search(r'"sql_query"\s*:\s*"((?:[^"\\]|\\.)*)"', arguments)
    if not match:
        print(f"Failed to extract sql_query from arguments_str: {arguments}")
        return ""
    return match.group(1)


def _clean_sql_query(sql_query):
    """Repair formatting glitches commonly seen in model-generated SQL."""
    # Undo JSON escapes left behind by double-escaped or regex-salvaged output.
    # The doubled backslashes match the literal text  \"  and  \u003e / \u003c,
    # not the characters those escapes would decode to.
    sql_query = sql_query.replace('\\"', '"')
    sql_query = sql_query.replace("\\u003e", ">").replace("\\u003c", "<")
    # Collapse a doubled end anchor in a numeric regex literal. Presumably this is
    # because "$$" would open a PostgreSQL dollar-quoted string.
    sql_query = sql_query.replace("^[0-9.]+$$", "^[0-9.]+$")
    # Normalise "col","col" to "col", "col" in the column list.
    sql_query = re.sub(r'("[^"]+")\s*,\s*("[^"]+")', r"\1, \2", sql_query)
    # Insert a space between a quoted column and FROM, case-insensitively. The
    # flag is passed via ``flags=``: an inline (?i) mid-pattern is an error on
    # Python 3.11+.
    sql_query = re.sub(r'("[^"]+")(FROM)', r"\1 FROM", sql_query, flags=re.IGNORECASE)
    # Drop surrounding whitespace and any trailing statement terminator.
    return sql_query.strip().rstrip(";").rstrip()


def _format_songs(rows, limit):
    """Render up to *limit* result rows as a numbered "Track by Artist" list."""
    shown = rows[:limit]
    lines = [f"Top {len(shown)} Songs:"]
    for position, row in enumerate(shown, 1):
        track_name = row.get("Track Name", "Unknown Track")
        artist_names = row.get("Artist Name(s)", "Unknown Artist")
        lines.append(f"{position}. {track_name} by {artist_names}")
    return "\n".join(lines).strip()


def _run_sql_tool_call(arguments, limit):
    """Extract, clean, validate, execute and format one ``execute_sql_query`` call."""
    sql_query = _extract_sql_query(arguments)
    if not sql_query:
        return "⚠️ No SQL query provided."

    sql_query = _clean_sql_query(sql_query)

    # Only a single statement starting with SELECT may reach the database. The
    # SQL is model-generated from untrusted user input, so anything still
    # containing ';' is rejected. Otherwise a stacked statement
    # ("SELECT ...; DROP TABLE songs") could ride along behind the prefix check.
    # This also rejects ';' inside string literals, an accepted trade-off.
    if not sql_query.upper().startswith("SELECT") or ";" in sql_query:
        return f"⚠️ Invalid SQL query: {sql_query}"

    print(f"Executing SQL Query: {sql_query}")
    result = execute_sql_query(sql_query)

    if not isinstance(result, list):
        return result  # Non-list results are treated as an error message from execute_sql_query
    if not result:
        return "🔍 No results found for the query."
    return _format_songs(result, limit)


def chat_with_groq(user_input):
    """
    Processes user input using Groq API and executes SQL queries on Supabase when needed.

    The model is offered the ``tools`` definitions. When it requests
    ``execute_sql_query``, either through structured ``tool_calls`` or an inline
    ``<tool-use>`` block, the generated SQL is cleaned, checked, run through
    ``execute_sql_query`` and formatted as a numbered song list. Otherwise the
    model's own text reply is returned.

    Args:
        user_input (str): The user's query. A phrase like "top N" caps how many
            result rows are displayed (default 10).

    Returns:
        str: Response from the chatbot. Failures are reported as text rather
        than raised.
    """
    try:
        limit = _extract_limit(user_input)

        response = groq_client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_input},
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=4096,
        )

        message = response.choices[0].message
        message_content = getattr(message, "content", None)
        sdk_tool_calls = getattr(message, "tool_calls", None)

        # Normalise both tool-call sources to (name, arguments) pairs.
        if sdk_tool_calls:
            tool_calls = [
                (call.function.name, call.function.arguments) for call in sdk_tool_calls
            ]
        elif message_content:
            tool_calls = _parse_tool_use_block(message_content)
        else:
            tool_calls = []

        # The first execute_sql_query call decides the answer.
        for name, arguments in tool_calls:
            if name == "execute_sql_query":
                return _run_sql_tool_call(arguments, limit)

        # Fallback for no tool calls (e.g., greetings)
        if message_content and not tool_calls:
            # An inline <tool-use> block with an empty call list is treated as a greeting.
            if '<tool-use>' in message_content and '"tool_calls": []' in message_content:
                return "Hello! How can I help you?"
            return message_content.strip()
        return "I'm sorry, I couldn't process your request. (No message content or tool calls found)"

    except Exception as e:
        # Top-level boundary: report any failure to the caller as text.
        print(e)
        return f"Error: {str(e)}"