Commit 2bc59e7 (verified) by DebopamC · Parent: ad1817e

Upload 10 files
Dockerfile CHANGED
@@ -6,7 +6,7 @@ WORKDIR /app
 RUN apt-get update && apt-get install -y curl
 
 # # Download the model
-RUN curl -Lo qwen2.5-coder-3b-instruct-q4_k_m.gguf https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct-GGUF/resolve/main/qwen2.5-coder-0.5b-instruct-q4_k_m.gguf?download=true
+RUN curl -Lo qwen2.5-coder-3b-instruct-q4_k_m.gguf https://huggingface.co/DebopamC/Text-to-SQL__Qwen2.5-Coder-3B-FineTuned/resolve/main/Text-to-SQL-Qwen2.5-Coder-3B-FineTuned.gguf?download=true
 
 # Install build tools required for llama-cpp-python
 RUN apt-get update && apt-get install -y build-essential
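The replaced RUN line had a mismatch worth noting: it downloaded the 0.5B quant from the Qwen repo but saved it as qwen2.5-coder-3b-instruct-q4_k_m.gguf; the new line pulls the fine-tuned 3B GGUF from DebopamC/Text-to-SQL__Qwen2.5-Coder-3B-FineTuned, so the filename and the model now agree. A quick way to verify the download (a minimal sketch, not part of this commit; assumes llama-cpp-python is installed, which the build tools below support):

    # Hypothetical smoke test: confirm the GGUF loads and can generate.
    from llama_cpp import Llama

    llm = Llama(model_path="qwen2.5-coder-3b-instruct-q4_k_m.gguf", n_ctx=2048)
    out = llm("-- SQL: count rows in customers\n", max_tokens=32)
    print(out["choices"][0]["text"])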
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (180 Bytes)

utils/__pycache__/handle_sql_commands.cpython-312.pyc ADDED
Binary file (1.17 kB)

utils/__pycache__/llm_logic.cpython-312.pyc ADDED
Binary file (9.97 kB)

utils/__pycache__/sql_utils.cpython-312.pyc ADDED
Binary file (3.29 kB)
 
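These __pycache__ binaries are compiled bytecode from a local Python 3.12 run, not source files. If shipping them is unintentional, an ignore rule would keep them out of future uploads (a hypothetical suggestion, not part of this commit):

    # .gitignore — hypothetical entries to exclude compiled bytecode
    __pycache__/
    *.pyc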
utils/llm_logic.py CHANGED
@@ -4,6 +4,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
 import streamlit as st
 import multiprocessing
 from langchain_community.chat_models import ChatLlamaCpp
+from langchain_google_genai import ChatGoogleGenerativeAI
 
 local_model = "qwen2.5-coder-3b-instruct-q4_k_m.gguf"
 
@@ -29,7 +30,7 @@ stop = [
 ]
 
 
-def get_llm():
+def get_local_llm():
     cache_llm = ChatLlamaCpp(
         temperature=0.0,
         model_path=local_model,
@@ -45,7 +46,20 @@ def get_llm():
     return cache_llm
 
 
-llm = get_llm()
+local_llm = get_local_llm()
+
+
+def get_gemini_llm():
+    gemini = ChatGoogleGenerativeAI(
+        model="gemini-1.5-flash",
+        temperature=0,
+        max_tokens=None,
+        timeout=None,
+        max_retries=2,
+    )
+    return gemini
+
+gemini_llm = get_gemini_llm()
 
 
 db_schema = """### **customers**
@@ -151,7 +165,7 @@ QUESTION: {question}
 """
 
 
-def classify_question(question: str, use_default_schema: bool = True):
+def classify_question(question: str, llm, use_default_schema: bool = True):
     classification_system_prompt_local = classification_system_prompt  # Initialize here
     if use_default_schema:
         classification_system_prompt_local = classification_system_prompt_local.format(
@@ -170,8 +184,14 @@ def classify_question(question: str, use_default_schema: bool = True):
     return response.content.strip().upper()
 
 
-def generate_llm_response(prompt: str, use_default_schema: bool = True):
-    question_type = classify_question(prompt, use_default_schema)
+def generate_llm_response(prompt: str, llm: str, use_default_schema: bool = True):
+
+    if llm == "gemini":
+        llm = gemini_llm
+    else:
+        llm = local_llm
+
+    question_type = classify_question(prompt, llm, use_default_schema)
     chosen_schema = None
     if use_default_schema:
         chosen_schema = db_schema
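With these changes, callers pick a backend by passing a string: "gemini" resolves to the ChatGoogleGenerativeAI instance, and any other value falls back to the local ChatLlamaCpp model. A minimal caller sketch under that assumption (the question text is illustrative; the function is iterated because it streams partial responses, as the Streamlit app below does):

    # Hypothetical caller: stream partial responses from the Gemini backend.
    for partial in generate_llm_response(
        "List the top 5 products by total sales value.",
        llm="gemini",  # anything else, e.g. "qwen", selects the local GGUF model
        use_default_schema=True,
    ):
        print(partial)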
🤖SQL_Agent.py CHANGED
@@ -18,29 +18,12 @@ st.set_page_config(
     initial_sidebar_state="expanded",
 )
 
-default_db_questions = {
-    "easy": [
-        "Retrieve all customer IDs and their corresponding cities from the `customers` table.",
-        "List all products along with their category names from the `products` table.",
-        "Fetch the order IDs and their purchase timestamps from the `orders` table.",
-        "Display the distinct payment types available in the `payments` table.",
-        "Find the total number of rows in the `customers` table.",
-    ],
-    "medium": [
-        "Retrieve the total payment value for each order from the `payments` table, grouped by `order_id`.",
-        "Find all orders where the total shipping charges (sum of `shipping_charges`) exceed 100.",
-        "List the names of cities and the number of customers in each city, sorted in descending order of the number of customers.",
-    ],
-    "hard": [
-        "Write a query to find the total revenue (sum of `price` + `shipping_charges`) generated for each product category in the `order_items` table, joined with the `products` table.",
-        "Identify the top 5 products with the highest total sales value (sum of `price`) across all orders.",
-    ],
-}
-
-
+default_db_questions = {}
 default_dfs = load_data()
 selected_df = default_dfs
 use_default_schema = True
+llm_option = "gemini"
+
 
 st.markdown(
     """
@@ -139,14 +122,33 @@ st.caption(
 )
 
 
-col1, col2 = st.columns([2, 1], vertical_alignment="bottom")
+col1, col2, col3 = st.columns([1, 1, 1], vertical_alignment="top")
 with col1:
     # Button to refresh the conversation
     if st.button("Start New Conversation", type="primary"):
        st.session_state.chat_history = []
        st.session_state.conversation_turns = 0
        st.rerun()
+    if selected_df == default_dfs:
+        with st.popover("Default Database Queries 📚 - Trial"):
+            default_db_questions = load_defaultdb_queries()
+            st.markdown(default_db_questions)
 with col2:
+    llm_option_radio = st.radio(
+        "Choose LLM Model",
+        ["Gemini 1.5-Flash", "FineTuned Qwen2.5-Coder-3B for SQL"],
+        captions=[
+            "Used via API",
+            "Run Locally on this Server. Extremely Slow because of Free vCPUs",
+        ],
+        label_visibility="collapsed",
+    )
+    if llm_option_radio == "Gemini 1.5-Flash":
+        llm_option = "gemini"
+    else:
+        llm_option = "qwen"
+
+with col3:
     disabled_selection = True
     if (
         "uploaded_dataframes" in st.session_state
@@ -170,10 +172,6 @@ with col2:
         selected_df = default_dfs
         # print(selected_df)
         use_default_schema = True
-    if selected_df == default_dfs:
-        with st.popover("Default Database Queries 📚 - Trial"):
-            default_db_questions = load_defaultdb_queries()
-            st.markdown(default_db_questions)
 
 # Initialize chat history in session state
 if "chat_history" not in st.session_state:
@@ -234,11 +232,18 @@ if st.session_state.conversation_turns < MAX_TURNS:
         with st.chat_message("assistant"):
             message_placeholder = st.empty()
             full_response = ""
+            spinner_text = ""
+            if llm_option == "gemini":
+                spinner_text = (
+                    "Using Gemini-1.5-Flash to run your query. Please wait...😊"
+                )
+            else:
+                spinner_text = "I know it is taking a lot of time. To run the model I'm using `Free` small vCPUs provided by `HuggingFace Spaces` for deployment. Thank you so much for your patience😊"
             with st.spinner(
-                "I know it is taking a lot of time. To run the model I'm using `Free` small vCPUs provided by `HuggingFace Spaces` for deployment. Thank you so much for your patience😊"
+                spinner_text,
             ):
                 for response_so_far in generate_llm_response(
-                    prompt, use_default_schema
+                    prompt, llm_option, use_default_schema
                 ):
                     # Remove <sql> and </sql> tags for streaming display
                     streaming_response = response_so_far.replace("<sql>", "").replace(
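One deployment detail the diff leaves implicit: ChatGoogleGenerativeAI needs a Google API key, which langchain_google_genai reads from the GOOGLE_API_KEY environment variable when none is passed explicitly. Since llm_option now defaults to "gemini", the Space should have that secret configured; a minimal guard (a sketch, not part of this commit):

    # Hypothetical startup check for the Gemini path.
    import os

    if not os.environ.get("GOOGLE_API_KEY"):
        raise RuntimeError(
            "GOOGLE_API_KEY is not set; add it as a HuggingFace Spaces secret "
            "or select the local model instead."
        )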