Upload 10 files
Dockerfile
CHANGED
@@ -6,7 +6,7 @@ WORKDIR /app
 RUN apt-get update && apt-get install -y curl
 
 # # Download the model
-RUN curl -Lo qwen2.5-coder-3b-instruct-q4_k_m.gguf https://huggingface.co/
+RUN curl -Lo qwen2.5-coder-3b-instruct-q4_k_m.gguf https://huggingface.co/DebopamC/Text-to-SQL__Qwen2.5-Coder-3B-FineTuned/resolve/main/Text-to-SQL-Qwen2.5-Coder-3B-FineTuned.gguf?download=true
 
 # Install build tools required for llama-cpp-python
 RUN apt-get update && apt-get install -y build-essential
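The new RUN line bakes the fine-tuned GGUF into the image at build time. For comparison, the same fetch can be written with the huggingface_hub client; this is a minimal sketch, not part of the commit, and it assumes the repo id and filename taken from the URL above:

# Hypothetical alternative to the curl step; requires `pip install huggingface_hub`.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="DebopamC/Text-to-SQL__Qwen2.5-Coder-3B-FineTuned",
    filename="Text-to-SQL-Qwen2.5-Coder-3B-FineTuned.gguf",
    local_dir=".",  # place the file alongside the app, mirroring curl -Lo
)
# The Dockerfile saves the download as qwen2.5-coder-3b-instruct-q4_k_m.gguf,
# the name llm_logic.py expects, so a rename would be needed with this route.
print(path)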
utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (180 Bytes)

utils/__pycache__/handle_sql_commands.cpython-312.pyc
ADDED
Binary file (1.17 kB)

utils/__pycache__/llm_logic.cpython-312.pyc
ADDED
Binary file (9.97 kB)

utils/__pycache__/sql_utils.cpython-312.pyc
ADDED
Binary file (3.29 kB)
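(The four files above are CPython 3.12 bytecode caches generated when the utils package is imported; they are derived build artifacts rather than source files.)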
utils/llm_logic.py
CHANGED
@@ -4,6 +4,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
 import streamlit as st
 import multiprocessing
 from langchain_community.chat_models import ChatLlamaCpp
+from langchain_google_genai import ChatGoogleGenerativeAI
 
 local_model = "qwen2.5-coder-3b-instruct-q4_k_m.gguf"
 
@@ -29,7 +30,7 @@ stop = [
 ]
 
 
-def get_llm():
+def get_local_llm():
     cache_llm = ChatLlamaCpp(
         temperature=0.0,
         model_path=local_model,
@@ -45,7 +46,20 @@ def get_llm():
     return cache_llm
 
 
-
+local_llm = get_local_llm()
+
+
+def get_gemini_llm():
+    gemini = ChatGoogleGenerativeAI(
+        model="gemini-1.5-flash",
+        temperature=0,
+        max_tokens=None,
+        timeout=None,
+        max_retries=2,
+    )
+    return gemini
+
+
+gemini_llm = get_gemini_llm()
 
 
 db_schema = """### **customers**
@@ -151,7 +165,7 @@ QUESTION: {question}
 """
 
 
-def classify_question(question: str, use_default_schema: bool = True):
+def classify_question(question: str, llm, use_default_schema: bool = True):
     classification_system_prompt_local = classification_system_prompt  # Initialize here
     if use_default_schema:
         classification_system_prompt_local = classification_system_prompt_local.format(
@@ -170,8 +184,14 @@ def classify_question(question: str, use_default_schema: bool = True):
     return response.content.strip().upper()
 
 
-def generate_llm_response(prompt: str, use_default_schema: bool = True):
-    question_type = classify_question(prompt, use_default_schema)
+def generate_llm_response(prompt: str, llm: str, use_default_schema: bool = True):
+
+    if llm == "gemini":
+        llm = gemini_llm
+    else:
+        llm = local_llm
+
+    question_type = classify_question(prompt, llm, use_default_schema)
     chosen_schema = None
     if use_default_schema:
         chosen_schema = db_schema
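Taken together, llm_logic.py now holds two module-level clients — local_llm (ChatLlamaCpp over the GGUF file) and gemini_llm (ChatGoogleGenerativeAI) — and generate_llm_response dispatches between them on a string flag. A minimal call-site sketch, not code from this commit, assuming GOOGLE_API_KEY is set in the environment (langchain_google_genai reads it by default):

# Sketch: exercising the new dual-backend entry point.
# "gemini" selects the hosted model; any other value falls back to the
# local GGUF-backed ChatLlamaCpp instance.
from utils.llm_logic import generate_llm_response

final = ""
for response_so_far in generate_llm_response(
    "How many customers are there in each city?", "gemini", use_default_schema=True
):
    final = response_so_far  # the generator yields the response accumulated so far
print(final)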
🤖SQL_Agent.py
CHANGED
@@ -18,29 +18,12 @@ st.set_page_config(
     initial_sidebar_state="expanded",
 )
 
-default_db_questions = {
-    "easy": [
-        "Retrieve all customer IDs and their corresponding cities from the `customers` table.",
-        "List all products along with their category names from the `products` table.",
-        "Fetch the order IDs and their purchase timestamps from the `orders` table.",
-        "Display the distinct payment types available in the `payments` table.",
-        "Find the total number of rows in the `customers` table.",
-    ],
-    "medium": [
-        "Retrieve the total payment value for each order from the `payments` table, grouped by `order_id`.",
-        "Find all orders where the total shipping charges (sum of `shipping_charges`) exceed 100.",
-        "List the names of cities and the number of customers in each city, sorted in descending order of the number of customers.",
-    ],
-    "hard": [
-        "Write a query to find the total revenue (sum of `price` + `shipping_charges`) generated for each product category in the `order_items` table, joined with the `products` table.",
-        "Identify the top 5 products with the highest total sales value (sum of `price`) across all orders.",
-    ],
-}
-
-
+default_db_questions = {}
 default_dfs = load_data()
 selected_df = default_dfs
 use_default_schema = True
+llm_option = "gemini"
+
 
 st.markdown(
     """
@@ -139,14 +122,33 @@ st.caption(
 )
 
 
-col1, col2 = st.columns([
+col1, col2, col3 = st.columns([1, 1, 1], vertical_alignment="top")
 with col1:
     # Button to refresh the conversation
     if st.button("Start New Conversation", type="primary"):
         st.session_state.chat_history = []
         st.session_state.conversation_turns = 0
         st.rerun()
+    if selected_df == default_dfs:
+        with st.popover("Default Database Queries 📚 - Trial"):
+            default_db_questions = load_defaultdb_queries()
+            st.markdown(default_db_questions)
 with col2:
+    llm_option_radio = st.radio(
+        "Choose LLM Model",
+        ["Gemini 1.5-Flash", "FineTuned Qwen2.5-Coder-3B for SQL"],
+        captions=[
+            "Used via API",
+            "Run Locally on this Server. Extremely Slow because of Free vCPUs",
+        ],
+        label_visibility="collapsed",
+    )
+    if llm_option_radio == "Gemini 1.5-Flash":
+        llm_option = "gemini"
+    else:
+        llm_option = "qwen"
+
+with col3:
     disabled_selection = True
     if (
         "uploaded_dataframes" in st.session_state
@@ -170,10 +172,6 @@ with col2:
         selected_df = default_dfs
         # print(selected_df)
         use_default_schema = True
-if selected_df == default_dfs:
-    with st.popover("Default Database Queries 📚 - Trial"):
-        default_db_questions = load_defaultdb_queries()
-        st.markdown(default_db_questions)
 
 # Initialize chat history in session state
 if "chat_history" not in st.session_state:
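Note the control flow behind llm_option: Streamlit re-runs the script top to bottom on every interaction, so the module-level llm_option = "gemini" added in the first hunk is only a per-run default that the radio in col2 immediately overrides. The committed if/else is equivalent to this condensed form (a sketch, not the committed code):

# llm_option_radio comes from the st.radio call above.
llm_option = "gemini" if llm_option_radio == "Gemini 1.5-Flash" else "qwen"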
@@ -234,11 +232,18 @@ if st.session_state.conversation_turns < MAX_TURNS:
         with st.chat_message("assistant"):
             message_placeholder = st.empty()
             full_response = ""
+            spinner_text = ""
+            if llm_option == "gemini":
+                spinner_text = (
+                    "Using Gemini-1.5-Flash to run your query. Please wait...😊"
+                )
+            else:
+                spinner_text = "I know it is taking a lot of time. To run the model I'm using `Free` small vCPUs provided by `HuggingFace Spaces` for deployment. Thank you so much for your patience😊"
             with st.spinner(
-                "I know it is taking a lot of time. To run the model I'm using `Free` small vCPUs provided by `HuggingFace Spaces` for deployment. Thank you so much for your patience😊"
+                spinner_text,
             ):
                 for response_so_far in generate_llm_response(
-                    prompt, use_default_schema
+                    prompt, llm_option, use_default_schema
                ):
                     # Remove <sql> and </sql> tags for streaming display
                     streaming_response = response_so_far.replace("<sql>", "").replace(