Invicto69 committed on
Commit
6cfa9f4
·
verified ·
1 Parent(s): 89b8387

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (3) hide show
  1. app.py +38 -72
  2. utils.py +10 -1
  3. var.py +62 -0
app.py CHANGED
@@ -1,5 +1,7 @@
1
  from typing import Generator
2
- from utils import get_all_groq_model, validate_api_key, get_info, validate_uri
 
 
3
  import streamlit as st
4
  from groq import Groq
5
 
@@ -17,39 +19,19 @@ st.markdown("# SQL Chat")
17
  st.sidebar.title("Settings")
18
  api_key = st.sidebar.text_input("Groq API Key", type="password")
19
 
20
- models = []
21
-
22
- @st.cache_data
23
- def get_text_models(api_key):
24
- models = get_all_groq_model(api_key=api_key)
25
- vision_audio = [model for model in models if 'vision' in model or 'whisper' in model]
26
- models = [model for model in models if model not in vision_audio]
27
- return models
28
-
29
  # validating api_key
30
  if not validate_api_key(api_key):
31
  st.sidebar.error("Enter valid API Key")
 
32
  else:
33
  st.sidebar.success("API Key is valid")
34
- models = get_text_models(api_key)
35
-
36
- model = st.sidebar.selectbox("Select Model", models)
37
 
38
  if st.session_state.selected_model != model:
39
  st.session_state.messages = []
40
  st.session_state.selected_model = model
41
 
42
-
43
  uri = st.sidebar.text_input("Enter SQL Database URI")
44
- db_info = {'sql_dialect': '', 'tables': '', 'tables_schema': ''}
45
- markdown_info = """
46
- **SQL Dialect**: {sql_dialect}\n
47
- **Tables**: {tables}\n
48
- **Tables Schema**:
49
- ```sql
50
- {tables_schema}
51
- ```
52
- """
53
 
54
  if not validate_uri(uri):
55
  st.sidebar.error("Enter valid URI")
@@ -59,52 +41,20 @@ else:
59
  markdown_info = markdown_info.format(**db_info)
60
  with st.expander("SQL Database Info"):
61
  st.markdown(markdown_info)
 
62
 
63
- system_prompt = f"""
64
- You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
65
- You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
66
- correctness, efficiency, and security in your SQL queries.\
67
-
68
- ## SQL Database Info
69
- {markdown_info}
70
-
71
- ---
72
-
73
- ## Query Generation Guidelines
74
- 1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
75
- 2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
76
- 3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
77
- 4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
78
- 5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
79
- 6. **Commenting**: Include comments in complex queries to explain logic when needed.
80
-
81
- ---
82
-
83
- ## Expected Output Format
84
-
85
- The SQL query should be returned as a formatted code block:
86
-
87
- ```sql
88
- -- Get all completed orders with user details
89
- SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
90
- FROM orders
91
- JOIN users ON orders.user_id = users.id
92
- WHERE orders.status = 'completed'
93
- ORDER BY orders.created_at DESC;
94
- ```
95
-
96
- If the user's request is ambiguous, ask clarifying questions before generating the query.
97
- """
98
-
99
- if model is not None and validate_uri(uri):
100
  client = Groq(
101
  api_key=api_key,
102
  )
103
 
 
 
 
 
104
  # Display chat messages from history on app rerun
105
  for message in st.session_state.messages:
106
- avatar = '🤖' if message["role"] == "assistant" else '👨‍💻'
107
- with st.chat_message(message["role"], avatar=avatar):
108
  st.markdown(message["content"])
109
 
110
 
@@ -118,7 +68,7 @@ if model is not None and validate_uri(uri):
118
  if prompt := st.chat_input("Enter your prompt here..."):
119
  st.session_state.messages.append({"role": "user", "content": prompt})
120
 
121
- with st.chat_message("user", avatar='👨‍💻'):
122
  st.markdown(prompt)
123
 
124
  # Fetch response from Groq API
@@ -135,25 +85,41 @@ if model is not None and validate_uri(uri):
135
  "role": m["role"],
136
  "content": m["content"]
137
  }
138
- for m in st.session_state.messages
139
  ],
140
  max_tokens=3000,
141
  stream=True
142
  )
143
 
144
  # Use the generator function with st.write_stream
145
- with st.chat_message("SQL Assistant", avatar="🤖"):
146
  chat_responses_generator = generate_chat_responses(chat_completion)
147
- full_response = st.write_stream(chat_responses_generator)
 
 
 
 
 
 
148
  except Exception as e:
149
  st.error(e, icon="🚨")
150
 
151
- # Append the full response to session_state.messages
152
- if isinstance(full_response, str):
 
 
 
 
 
153
  st.session_state.messages.append(
154
- {"role": "assistant", "content": full_response})
155
  else:
156
- # Handle the case where full_response is not a string
157
- combined_response = "\n".join(str(item) for item in full_response)
158
  st.session_state.messages.append(
159
- {"role": "assistant", "content": combined_response})
 
 
 
 
 
 
1
  from typing import Generator
2
+ from utils import validate_api_key, get_info, validate_uri, extract_code_blocks
3
+ from langchain_community.utilities import SQLDatabase
4
+ from var import system_prompt, markdown_info, query_output, groq_models
5
  import streamlit as st
6
  from groq import Groq
7
 
 
19
  st.sidebar.title("Settings")
20
  api_key = st.sidebar.text_input("Groq API Key", type="password")
21
 
 
 
 
 
 
 
 
 
 
22
  # validating api_key
23
  if not validate_api_key(api_key):
24
  st.sidebar.error("Enter valid API Key")
25
+ model = st.sidebar.selectbox("Select Model", groq_models, disabled=True)
26
  else:
27
  st.sidebar.success("API Key is valid")
28
+ model = st.sidebar.selectbox("Select Model", groq_models, index=0)
 
 
29
 
30
  if st.session_state.selected_model != model:
31
  st.session_state.messages = []
32
  st.session_state.selected_model = model
33
 
 
34
  uri = st.sidebar.text_input("Enter SQL Database URI")
 
 
 
 
 
 
 
 
 
35
 
36
  if not validate_uri(uri):
37
  st.sidebar.error("Enter valid URI")
 
41
  markdown_info = markdown_info.format(**db_info)
42
  with st.expander("SQL Database Info"):
43
  st.markdown(markdown_info)
44
+ system_prompt = system_prompt.format(markdown_info = markdown_info)
45
 
46
+ if validate_api_key(api_key) and validate_uri(uri):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  client = Groq(
48
  api_key=api_key,
49
  )
50
 
51
+ db = SQLDatabase.from_uri(uri)
52
+
53
+ avatar = {"user": '👨‍💻', "assistant": '🤖', "executor": '🛢'}
54
+
55
  # Display chat messages from history on app rerun
56
  for message in st.session_state.messages:
57
+ with st.chat_message(message["role"], avatar=avatar[message["role"]]):
 
58
  st.markdown(message["content"])
59
 
60
 
 
68
  if prompt := st.chat_input("Enter your prompt here..."):
69
  st.session_state.messages.append({"role": "user", "content": prompt})
70
 
71
+ with st.chat_message("user", avatar=avatar["user"]):
72
  st.markdown(prompt)
73
 
74
  # Fetch response from Groq API
 
85
  "role": m["role"],
86
  "content": m["content"]
87
  }
88
+ for m in st.session_state.messages[-8:]
89
  ],
90
  max_tokens=3000,
91
  stream=True
92
  )
93
 
94
  # Use the generator function with st.write_stream
95
+ with st.chat_message("SQL Assistant", avatar=avatar["assistant"]):
96
  chat_responses_generator = generate_chat_responses(chat_completion)
97
+ llm_response = st.write_stream(chat_responses_generator)
98
+
99
+ with st.chat_message("SQL Executor", avatar=avatar["executor"]):
100
+ query = extract_code_blocks(llm_response)
101
+ result = db.run(query[0])
102
+ query_response = st.write(query_output.format(result=result))
103
+
104
  except Exception as e:
105
  st.error(e, icon="🚨")
106
 
107
+ if len(str(result)) > 1000:
108
+ query_output_truncated = query_output.format(result=result)[:500]+query_output.format(result=result)[-500:]
109
+ else:
110
+ query_output_truncated = query_output.format(result=result)
111
+
112
+ # Append the llm response to session_state.messages
113
+ if isinstance(llm_response, str):
114
  st.session_state.messages.append(
115
+ {"role": "assistant", "content": llm_response + query_output_truncated})
116
  else:
117
+ # Handle the case where llm_response is not a string
118
+ combined_response = "\n".join(str(item) for item in llm_response)
119
  st.session_state.messages.append(
120
+ {"role": "assistant", "content": combined_response + query_output_truncated})
121
+
122
+ st.sidebar.button("Clear Chat History", on_click=lambda: st.session_state.messages.clear())
123
+
124
+ else:
125
+ st.error("Please enter valid Groq API Key and URI in the sidebar.")
utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import requests
2
  from langchain_community.utilities import SQLDatabase
3
  from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
 
4
 
5
  def get_all_groq_model(api_key:str=None) -> list:
6
  if api_key is None:
@@ -44,5 +45,13 @@ def get_info(uri:str) -> dict[str, str] | None:
44
  tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
45
  return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
46
 
 
 
 
 
 
47
  if __name__ == "__main__":
48
- print(get_all_groq_model())
 
 
 
 
1
  import requests
2
  from langchain_community.utilities import SQLDatabase
3
  from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
4
+ import re
5
 
6
  def get_all_groq_model(api_key:str=None) -> list:
7
  if api_key is None:
 
45
  tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
46
  return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
47
 
48
+ def extract_code_blocks(text):
49
+ pattern = r"```(?:\w+)?\n(.*?)\n```"
50
+ matches = re.findall(pattern, text, re.DOTALL)
51
+ return matches
52
+
53
  if __name__ == "__main__":
54
+ models = (get_all_groq_model("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"))
55
+ vision_audio = [model for model in models if 'vision' in model or 'whisper' in model]
56
+ models = [model for model in models if model not in vision_audio]
57
+ print(models)
var.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared constants and prompt templates used by the SQL Chat app.

# Model identifiers offered in the app's "Select Model" sidebar dropdown.
# NOTE(review): this is a hard-coded snapshot; presumably it mirrors the Groq
# model listing at the time of writing - verify it is still current.
groq_models = [
    'llama-3.3-70b-versatile',
    'gemma2-9b-it',
    'llama-3.2-3b-preview',
    'deepseek-r1-distill-llama-70b',
    'qwen-2.5-coder-32b',
    'mixtral-8x7b-32768',
    'llama-3.1-8b-instant',
    'llama-3.2-1b-preview',
    'allam-2-7b',
    'qwen-qwq-32b',
    'llama3-70b-8192',
    'mistral-saba-24b',
    'deepseek-r1-distill-qwen-32b',
    'qwen-2.5-32b',
    'llama-3.3-70b-specdec',
    'llama3-8b-8192',
    'llama-guard-3-8b',
]

# Empty placeholder carrying exactly the keys that `markdown_info` expects.
db_info = {'sql_dialect': '', 'tables': '', 'tables_schema': ''}

# Markdown template describing the connected database; fill it with
# `markdown_info.format(**db_info)`. The explicit "\n" escapes add an extra
# newline so Markdown renders each field as its own paragraph.
markdown_info = """
**SQL Dialect**: {sql_dialect}\n
**Tables**: {tables}\n
**Tables Schema**:
```sql
{tables_schema}
```
"""

# System prompt template for the LLM; fill it with
# `system_prompt.format(markdown_info=...)`. Because it goes through
# str.format, any literal brace added to this text must be doubled ({{ }}).
# A trailing backslash joins the line with the next one inside the string.
system_prompt = """
You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
correctness, efficiency, and security in your SQL queries.\

## SQL Database Info
{markdown_info}

---

## Query Generation Guidelines
1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
6. **Commenting**: Include comments in complex queries to explain logic when needed.
7. **Result**: Don't return the result of the query, just the SQL query.
8. **Optimal**: Try to generate query which is optimal and not brute force.
9. **Single query**: Generate a best single SQL query for the user input.
10. **Comment**: Include comments in the query to explain the logic behind it.

---

## Expected Output Format

The SQL query should be returned as a formatted code block:

```sql
-- Get all completed orders with user details
-- Comment explaining the logic.
SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
FROM orders
JOIN users ON orders.user_id = users.id
WHERE orders.status = 'completed'
ORDER BY orders.created_at DESC;
```

If the user's request is ambiguous, ask clarifying questions before generating the query.
"""

# Markdown template used to display a query's execution result; fill it with
# `query_output.format(result=...)`.
query_output = """
**The result of query execution:**
```sql
{result}
```
"""