Synced repo using 'sync_with_huggingface' Github Action
Browse files
app.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
from typing import Generator
|
2 |
-
from utils import
|
|
|
|
|
3 |
import streamlit as st
|
4 |
from groq import Groq
|
5 |
|
@@ -17,39 +19,19 @@ st.markdown("# SQL Chat")
|
|
17 |
st.sidebar.title("Settings")
|
18 |
api_key = st.sidebar.text_input("Groq API Key", type="password")
|
19 |
|
20 |
-
models = []
|
21 |
-
|
22 |
-
@st.cache_data
|
23 |
-
def get_text_models(api_key):
|
24 |
-
models = get_all_groq_model(api_key=api_key)
|
25 |
-
vision_audio = [model for model in models if 'vision' in model or 'whisper' in model]
|
26 |
-
models = [model for model in models if model not in vision_audio]
|
27 |
-
return models
|
28 |
-
|
29 |
# validating api_key
|
30 |
if not validate_api_key(api_key):
|
31 |
st.sidebar.error("Enter valid API Key")
|
|
|
32 |
else:
|
33 |
st.sidebar.success("API Key is valid")
|
34 |
-
|
35 |
-
|
36 |
-
model = st.sidebar.selectbox("Select Model", models)
|
37 |
|
38 |
if st.session_state.selected_model != model:
|
39 |
st.session_state.messages = []
|
40 |
st.session_state.selected_model = model
|
41 |
|
42 |
-
|
43 |
uri = st.sidebar.text_input("Enter SQL Database URI")
|
44 |
-
db_info = {'sql_dialect': '', 'tables': '', 'tables_schema': ''}
|
45 |
-
markdown_info = """
|
46 |
-
**SQL Dialect**: {sql_dialect}\n
|
47 |
-
**Tables**: {tables}\n
|
48 |
-
**Tables Schema**:
|
49 |
-
```sql
|
50 |
-
{tables_schema}
|
51 |
-
```
|
52 |
-
"""
|
53 |
|
54 |
if not validate_uri(uri):
|
55 |
st.sidebar.error("Enter valid URI")
|
@@ -59,52 +41,20 @@ else:
|
|
59 |
markdown_info = markdown_info.format(**db_info)
|
60 |
with st.expander("SQL Database Info"):
|
61 |
st.markdown(markdown_info)
|
|
|
62 |
|
63 |
-
|
64 |
-
You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
|
65 |
-
You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
|
66 |
-
correctness, efficiency, and security in your SQL queries.\
|
67 |
-
|
68 |
-
## SQL Database Info
|
69 |
-
{markdown_info}
|
70 |
-
|
71 |
-
---
|
72 |
-
|
73 |
-
## Query Generation Guidelines
|
74 |
-
1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
|
75 |
-
2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
|
76 |
-
3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
|
77 |
-
4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
|
78 |
-
5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
|
79 |
-
6. **Commenting**: Include comments in complex queries to explain logic when needed.
|
80 |
-
|
81 |
-
---
|
82 |
-
|
83 |
-
## Expected Output Format
|
84 |
-
|
85 |
-
The SQL query should be returned as a formatted code block:
|
86 |
-
|
87 |
-
```sql
|
88 |
-
-- Get all completed orders with user details
|
89 |
-
SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
|
90 |
-
FROM orders
|
91 |
-
JOIN users ON orders.user_id = users.id
|
92 |
-
WHERE orders.status = 'completed'
|
93 |
-
ORDER BY orders.created_at DESC;
|
94 |
-
```
|
95 |
-
|
96 |
-
If the user's request is ambiguous, ask clarifying questions before generating the query.
|
97 |
-
"""
|
98 |
-
|
99 |
-
if model is not None and validate_uri(uri):
|
100 |
client = Groq(
|
101 |
api_key=api_key,
|
102 |
)
|
103 |
|
|
|
|
|
|
|
|
|
104 |
# Display chat messages from history on app rerun
|
105 |
for message in st.session_state.messages:
|
106 |
-
|
107 |
-
with st.chat_message(message["role"], avatar=avatar):
|
108 |
st.markdown(message["content"])
|
109 |
|
110 |
|
@@ -118,7 +68,7 @@ if model is not None and validate_uri(uri):
|
|
118 |
if prompt := st.chat_input("Enter your prompt here..."):
|
119 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
120 |
|
121 |
-
with st.chat_message("user", avatar=
|
122 |
st.markdown(prompt)
|
123 |
|
124 |
# Fetch response from Groq API
|
@@ -135,25 +85,41 @@ if model is not None and validate_uri(uri):
|
|
135 |
"role": m["role"],
|
136 |
"content": m["content"]
|
137 |
}
|
138 |
-
for m in st.session_state.messages
|
139 |
],
|
140 |
max_tokens=3000,
|
141 |
stream=True
|
142 |
)
|
143 |
|
144 |
# Use the generator function with st.write_stream
|
145 |
-
with st.chat_message("SQL Assistant", avatar="
|
146 |
chat_responses_generator = generate_chat_responses(chat_completion)
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
except Exception as e:
|
149 |
st.error(e, icon="🚨")
|
150 |
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
153 |
st.session_state.messages.append(
|
154 |
-
{"role": "assistant", "content":
|
155 |
else:
|
156 |
-
# Handle the case where
|
157 |
-
combined_response = "\n".join(str(item) for item in
|
158 |
st.session_state.messages.append(
|
159 |
-
{"role": "assistant", "content": combined_response})
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import Generator
|
2 |
+
from utils import validate_api_key, get_info, validate_uri, extract_code_blocks
|
3 |
+
from langchain_community.utilities import SQLDatabase
|
4 |
+
from var import system_prompt, markdown_info, query_output, groq_models
|
5 |
import streamlit as st
|
6 |
from groq import Groq
|
7 |
|
|
|
19 |
st.sidebar.title("Settings")
|
20 |
api_key = st.sidebar.text_input("Groq API Key", type="password")
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
# validating api_key
|
23 |
if not validate_api_key(api_key):
|
24 |
st.sidebar.error("Enter valid API Key")
|
25 |
+
model = st.sidebar.selectbox("Select Model", groq_models, disabled=True)
|
26 |
else:
|
27 |
st.sidebar.success("API Key is valid")
|
28 |
+
model = st.sidebar.selectbox("Select Model", groq_models, index=0)
|
|
|
|
|
29 |
|
30 |
if st.session_state.selected_model != model:
|
31 |
st.session_state.messages = []
|
32 |
st.session_state.selected_model = model
|
33 |
|
|
|
34 |
uri = st.sidebar.text_input("Enter SQL Database URI")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
if not validate_uri(uri):
|
37 |
st.sidebar.error("Enter valid URI")
|
|
|
41 |
markdown_info = markdown_info.format(**db_info)
|
42 |
with st.expander("SQL Database Info"):
|
43 |
st.markdown(markdown_info)
|
44 |
+
system_prompt = system_prompt.format(markdown_info = markdown_info)
|
45 |
|
46 |
+
if validate_api_key(api_key) and validate_uri(uri):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
client = Groq(
|
48 |
api_key=api_key,
|
49 |
)
|
50 |
|
51 |
+
db = SQLDatabase.from_uri(uri)
|
52 |
+
|
53 |
+
avatar = {"user": '👨💻', "assistant": '🤖', "executor": '🛢'}
|
54 |
+
|
55 |
# Display chat messages from history on app rerun
|
56 |
for message in st.session_state.messages:
|
57 |
+
with st.chat_message(message["role"], avatar=avatar[message["role"]]):
|
|
|
58 |
st.markdown(message["content"])
|
59 |
|
60 |
|
|
|
68 |
if prompt := st.chat_input("Enter your prompt here..."):
|
69 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
70 |
|
71 |
+
with st.chat_message("user", avatar=avatar["user"]):
|
72 |
st.markdown(prompt)
|
73 |
|
74 |
# Fetch response from Groq API
|
|
|
85 |
"role": m["role"],
|
86 |
"content": m["content"]
|
87 |
}
|
88 |
+
for m in st.session_state.messages[-8:]
|
89 |
],
|
90 |
max_tokens=3000,
|
91 |
stream=True
|
92 |
)
|
93 |
|
94 |
# Use the generator function with st.write_stream
|
95 |
+
with st.chat_message("SQL Assistant", avatar=avatar["assistant"]):
|
96 |
chat_responses_generator = generate_chat_responses(chat_completion)
|
97 |
+
llm_response = st.write_stream(chat_responses_generator)
|
98 |
+
|
99 |
+
with st.chat_message("SQL Executor", avatar=avatar["executor"]):
|
100 |
+
query = extract_code_blocks(llm_response)
|
101 |
+
result = db.run(query[0])
|
102 |
+
query_response = st.write(query_output.format(result=result))
|
103 |
+
|
104 |
except Exception as e:
|
105 |
st.error(e, icon="🚨")
|
106 |
|
107 |
+
if len(str(result)) > 1000:
|
108 |
+
query_output_truncated = query_output.format(result=result)[:500]+query_output.format(result=result)[-500:]
|
109 |
+
else:
|
110 |
+
query_output_truncated = query_output.format(result=result)
|
111 |
+
|
112 |
+
# Append the llm response to session_state.messages
|
113 |
+
if isinstance(llm_response, str):
|
114 |
st.session_state.messages.append(
|
115 |
+
{"role": "assistant", "content": llm_response + query_output_truncated})
|
116 |
else:
|
117 |
+
# Handle the case where llm_response is not a string
|
118 |
+
combined_response = "\n".join(str(item) for item in llm_response)
|
119 |
st.session_state.messages.append(
|
120 |
+
{"role": "assistant", "content": combined_response + query_output_truncated})
|
121 |
+
|
122 |
+
st.sidebar.button("Clear Chat History", on_click=lambda: st.session_state.messages.clear())
|
123 |
+
|
124 |
+
else:
|
125 |
+
st.error("Please enter valid Groq API Key and URI in the sidebar.")
|
utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import requests
|
2 |
from langchain_community.utilities import SQLDatabase
|
3 |
from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
|
|
|
4 |
|
5 |
def get_all_groq_model(api_key:str=None) -> list:
|
6 |
if api_key is None:
|
@@ -44,5 +45,13 @@ def get_info(uri:str) -> dict[str, str] | None:
|
|
44 |
tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
|
45 |
return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
|
46 |
|
|
|
|
|
|
|
|
|
|
|
47 |
if __name__ == "__main__":
|
48 |
-
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
from langchain_community.utilities import SQLDatabase
|
3 |
from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
|
4 |
+
import re
|
5 |
|
6 |
def get_all_groq_model(api_key:str=None) -> list:
|
7 |
if api_key is None:
|
|
|
45 |
tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
|
46 |
return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
|
47 |
|
48 |
+
def extract_code_blocks(text):
|
49 |
+
pattern = r"```(?:\w+)?\n(.*?)\n```"
|
50 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
51 |
+
return matches
|
52 |
+
|
53 |
if __name__ == "__main__":
|
54 |
+
models = (get_all_groq_model("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"))
|
55 |
+
vision_audio = [model for model in models if 'vision' in model or 'whisper' in model]
|
56 |
+
models = [model for model in models if model not in vision_audio]
|
57 |
+
print(models)
|
var.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
groq_models = ['llama-3.3-70b-versatile', 'gemma2-9b-it', 'llama-3.2-3b-preview', 'deepseek-r1-distill-llama-70b', 'qwen-2.5-coder-32b',
|
2 |
+
'mixtral-8x7b-32768', 'llama-3.1-8b-instant', 'llama-3.2-1b-preview', 'allam-2-7b', 'qwen-qwq-32b', 'llama3-70b-8192',
|
3 |
+
'mistral-saba-24b', 'deepseek-r1-distill-qwen-32b', 'qwen-2.5-32b', 'llama-3.3-70b-specdec', 'llama3-8b-8192', 'llama-guard-3-8b']
|
4 |
+
|
5 |
+
db_info = {'sql_dialect': '', 'tables': '', 'tables_schema': ''}
|
6 |
+
|
7 |
+
markdown_info = """
|
8 |
+
**SQL Dialect**: {sql_dialect}\n
|
9 |
+
**Tables**: {tables}\n
|
10 |
+
**Tables Schema**:
|
11 |
+
```sql
|
12 |
+
{tables_schema}
|
13 |
+
```
|
14 |
+
"""
|
15 |
+
|
16 |
+
system_prompt = """
|
17 |
+
You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
|
18 |
+
You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
|
19 |
+
correctness, efficiency, and security in your SQL queries.\
|
20 |
+
|
21 |
+
## SQL Database Info
|
22 |
+
{markdown_info}
|
23 |
+
|
24 |
+
---
|
25 |
+
|
26 |
+
## Query Generation Guidelines
|
27 |
+
1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
|
28 |
+
2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
|
29 |
+
3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
|
30 |
+
4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
|
31 |
+
5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
|
32 |
+
6. **Commenting**: Include comments in complex queries to explain logic when needed.
|
33 |
+
7. **Result**: Don't return the result of the query, just the SQL query.
|
34 |
+
8. **Optimal**: Try to generate query which is optimal and not brute force.
|
35 |
+
9. **Single query**: Generate a best single SQL query for the user input.'
|
36 |
+
10. **Comment**: Include comments in the query to explain the logic behind it.
|
37 |
+
|
38 |
+
---
|
39 |
+
|
40 |
+
## Expected Output Format
|
41 |
+
|
42 |
+
The SQL query should be returned as a formatted code block:
|
43 |
+
|
44 |
+
```sql
|
45 |
+
-- Get all completed orders with user details
|
46 |
+
-- Comment explaining the logic.
|
47 |
+
SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
|
48 |
+
FROM orders
|
49 |
+
JOIN users ON orders.user_id = users.id
|
50 |
+
WHERE orders.status = 'completed'
|
51 |
+
ORDER BY orders.created_at DESC;
|
52 |
+
```
|
53 |
+
|
54 |
+
If the user's request is ambiguous, ask clarifying questions before generating the query.
|
55 |
+
"""
|
56 |
+
|
57 |
+
query_output = """
|
58 |
+
**The result of query execution:**
|
59 |
+
```sql
|
60 |
+
{result}
|
61 |
+
```
|
62 |
+
"""
|