Ari committed on
Commit
bb31796
·
verified ·
1 Parent(s): e14b81b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -109
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
- from langchain import OpenAI, LLMChain, PromptTemplate
6
  import sqlparse
7
  import logging
8
 
@@ -10,14 +10,8 @@ import logging
10
  if 'history' not in st.session_state:
11
  st.session_state.history = []
12
 
13
- # OpenAI API key (ensure it is securely stored)
14
- # You can set the API key in your environment variables or a .env file
15
- openai_api_key = os.getenv("OPENAI_API_KEY")
16
-
17
- # Check if the API key is set
18
- if not openai_api_key:
19
- st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
20
- st.stop()
21
 
22
  # Step 1: Upload CSV data file (or use default)
23
  st.title("Natural Language to SQL Query App with Enhanced Insights")
@@ -43,57 +37,58 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
43
  valid_columns = list(data.columns)
44
  st.write(f"Valid columns: {valid_columns}")
45
 
46
- # Step 3: Set up the LLM Chains
47
- # SQL Generation Chain
48
- sql_template = """
49
- You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
50
-
51
- Ensure that:
52
-
53
- - You only use the columns provided.
54
- - When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
55
- - Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
56
- - Do not apply 'COLLATE NOCASE' to numeric columns.
57
-
58
- If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.
59
-
60
- Question: {question}
61
-
62
- Table name: {table_name}
63
-
64
- Valid columns: {columns}
65
-
66
- SQL Query:
67
- """
68
- sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
69
- llm = OpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens = 180)
70
- sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
71
-
72
- # Insights Generation Chain
73
- insights_template = """
74
- You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
75
-
76
- User's Question: {question}
77
-
78
- SQL Query Result:
79
- {result}
80
-
81
- Concise Analysis (max 200 words):
82
- """
83
- insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
84
- insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
85
-
86
- # General Insights and Recommendations Chain
87
- general_insights_template = """
88
- You are an expert data scientist. Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.
89
 
90
- Dataset Summary:
91
- {dataset_summary}
 
 
 
 
 
 
 
 
 
92
 
93
- Concise Analysis and Recommendations (max 150 words):
94
- """
95
- general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
96
- general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
 
 
 
 
 
 
97
 
98
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
99
  def clean_sql_query(query):
@@ -117,42 +112,6 @@ def clean_sql_query(query):
117
  statements.append(''.join([str(t) for t in tokens]))
118
  return ' '.join(statements)
119
 
120
- # Function to classify user query
121
- def classify_query(question):
122
- """Classify the user query as either 'SQL' or 'INSIGHTS'."""
123
- classification_template = """
124
- You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
125
-
126
- Determine the appropriate category for the following user question.
127
-
128
- Question: "{question}"
129
-
130
- Category (SQL/INSIGHTS):
131
- """
132
- classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
133
- classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
134
- category = classification_chain.run({'question': question}).strip().upper()
135
- if category.startswith('SQL'):
136
- return 'SQL'
137
- else:
138
- return 'INSIGHTS'
139
-
140
- # Function to generate dataset summary
141
- def generate_dataset_summary(data):
142
- """Generate a summary of the dataset for general insights."""
143
- summary_template = """
144
- You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
145
-
146
- Dataset:
147
- {data}
148
-
149
- Dataset Summary:
150
- """
151
- summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
152
- summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
153
- summary = summary_chain.run({'data': data.head().to_string(index=False)})
154
- return summary
155
-
156
  # Define the callback function
157
  def process_input():
158
  user_prompt = st.session_state['user_input']
@@ -171,11 +130,7 @@ def process_input():
171
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
172
  elif category == 'SQL':
173
  columns = ', '.join(valid_columns)
174
- generated_sql = sql_generation_chain.run({
175
- 'question': user_prompt,
176
- 'table_name': table_name,
177
- 'columns': columns
178
- }).strip()
179
 
180
  if generated_sql.upper() == "NO_SQL":
181
  # Handle cases where no SQL should be generated
@@ -185,9 +140,7 @@ def process_input():
185
  dataset_summary = generate_dataset_summary(data)
186
 
187
  # Generate general insights and recommendations
188
- general_insights = general_insights_chain.run({
189
- 'dataset_summary': dataset_summary
190
- })
191
 
192
  # Append the assistant's insights to the history
193
  st.session_state.history.append({"role": "assistant", "content": general_insights})
@@ -208,10 +161,7 @@ def process_input():
208
  result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
209
 
210
  # Generate insights and recommendations based on the query result
211
- insights = insights_chain.run({
212
- 'question': user_prompt,
213
- 'result': result_str
214
- })
215
 
216
  # Append the assistant's insights to the history
217
  st.session_state.history.append({"role": "assistant", "content": insights})
@@ -226,9 +176,7 @@ def process_input():
226
  dataset_summary = generate_dataset_summary(data)
227
 
228
  # Generate general insights and recommendations
229
- general_insights = general_insights_chain.run({
230
- 'dataset_summary': dataset_summary
231
- })
232
 
233
  # Append the assistant's insights to the history
234
  st.session_state.history.append({"role": "assistant", "content": general_insights})
 
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
+ from transformers import pipeline
6
  import sqlparse
7
  import logging
8
 
 
10
  if 'history' not in st.session_state:
11
  st.session_state.history = []
12
 
13
+ # Load a pre-trained GPT-2 model from Hugging Face
14
+ llm = pipeline('text-generation', model='gpt2')
 
 
 
 
 
 
15
 
16
  # Step 1: Upload CSV data file (or use default)
17
  st.title("Natural Language to SQL Query App with Enhanced Insights")
 
37
  valid_columns = list(data.columns)
38
  st.write(f"Valid columns: {valid_columns}")
39
 
40
# Function to generate a SQL query from a natural-language question using the
# Hugging Face text-generation pipeline (module-level `llm`).
def generate_sql_query(question, table_name, columns):
    """Generate a SQL query (or the sentinel "NO_SQL") answering *question*.

    Parameters
    ----------
    question : str
        The user's natural-language question.
    table_name : str
        Name of the SQLite table to query.
    columns : str
        Comma-separated list of valid column names.

    Returns
    -------
    str
        The model's completion only (prompt stripped), expected to be a SQL
        query or the literal string "NO_SQL".
    """
    prompt = f"""
    You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
    Ensure that:
    - You only use the columns provided.
    - When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
    - Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
    - Do not apply 'COLLATE NOCASE' to numeric columns.
    If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.
    Question: {question}
    Table name: {table_name}
    Valid columns: {columns}
    SQL Query:
    """
    # Use max_new_tokens, not max_length: max_length counts the prompt tokens
    # too, and this prompt alone exceeds 180 tokens, which would truncate or
    # error out the generation before any SQL is produced.
    response = llm(prompt, max_new_tokens=180)
    generated = response[0]['generated_text']
    # text-generation pipelines echo the prompt at the start of the output;
    # strip it so the caller's "NO_SQL" check sees only the completion.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()
57
+
58
# Function to generate a concise analysis of a SQL result (or dataset summary)
# using the Hugging Face text-generation pipeline (module-level `llm`).
def generate_insights(question, result):
    """Generate key insights and recommendations for *question* from *result*.

    Parameters
    ----------
    question : str
        The user's original natural-language question.
    result : str
        Stringified SQL query result (or dataset summary) to analyze.

    Returns
    -------
    str
        The model's completion only (prompt stripped).
    """
    # NOTE: the word limit is stated once (150) — the original prompt
    # contradicted itself by asking for "maximum of 150 words" and then
    # labeling the answer "(max 200 words)".
    prompt = f"""
    You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
    User's Question: {question}
    SQL Query Result:
    {result}
    Concise Analysis (max 150 words):
    """
    # max_new_tokens so the long prompt does not consume the generation budget
    # (max_length counts prompt tokens as well).
    response = llm(prompt, max_new_tokens=150)
    generated = response[0]['generated_text']
    # Strip the echoed prompt so only the analysis is returned to the chat.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
# Function to classify a user query as data retrieval ('SQL') or general
# analysis ('INSIGHTS') using the Hugging Face text-generation pipeline.
def classify_query(question):
    """Classify *question* as 'SQL' or 'INSIGHTS'.

    Parameters
    ----------
    question : str
        The user's natural-language question.

    Returns
    -------
    str
        'SQL' if the model's completion mentions SQL, otherwise 'INSIGHTS'.
    """
    prompt = f"""
    You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
    Determine the appropriate category for the following user question.
    Question: "{question}"
    Category (SQL/INSIGHTS):
    """
    # max_new_tokens instead of max_length: max_length=10 is smaller than the
    # prompt itself, which would prevent any tokens from being generated.
    response = llm(prompt, max_new_tokens=10)
    generated = response[0]['generated_text']
    # BUG FIX: the pipeline echoes the prompt, and the prompt contains the
    # word 'SQL', so checking the raw output made every query classify as
    # 'SQL'. Strip the prompt before matching on the completion alone.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    category = generated.strip().upper()
    return 'SQL' if 'SQL' in category else 'INSIGHTS'
81
 
82
# Function to summarize the uploaded dataset using the Hugging Face
# text-generation pipeline (module-level `llm`).
def generate_dataset_summary(data):
    """Generate a concise natural-language summary of *data*.

    Parameters
    ----------
    data : pandas.DataFrame
        The dataset loaded from the uploaded CSV; only the first rows
        (``data.head()``) are shown to the model.

    Returns
    -------
    str
        The model's completion only (prompt stripped).
    """
    prompt = f"""
    You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
    Dataset:
    {data.head().to_string(index=False)}
    Dataset Summary:
    """
    # max_new_tokens so the prompt (which embeds the data preview) does not
    # eat the generation budget the way max_length would.
    response = llm(prompt, max_new_tokens=150)
    generated = response[0]['generated_text']
    # Strip the echoed prompt so only the summary text is returned.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()
92
 
93
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
94
  def clean_sql_query(query):
 
112
  statements.append(''.join([str(t) for t in tokens]))
113
  return ' '.join(statements)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  # Define the callback function
116
  def process_input():
117
  user_prompt = st.session_state['user_input']
 
130
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
131
  elif category == 'SQL':
132
  columns = ', '.join(valid_columns)
133
+ generated_sql = generate_sql_query(user_prompt, table_name, columns)
 
 
 
 
134
 
135
  if generated_sql.upper() == "NO_SQL":
136
  # Handle cases where no SQL should be generated
 
140
  dataset_summary = generate_dataset_summary(data)
141
 
142
  # Generate general insights and recommendations
143
+ general_insights = generate_insights(user_prompt, dataset_summary)
 
 
144
 
145
  # Append the assistant's insights to the history
146
  st.session_state.history.append({"role": "assistant", "content": general_insights})
 
161
  result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
162
 
163
  # Generate insights and recommendations based on the query result
164
+ insights = generate_insights(user_prompt, result_str)
 
 
 
165
 
166
  # Append the assistant's insights to the history
167
  st.session_state.history.append({"role": "assistant", "content": insights})
 
176
  dataset_summary = generate_dataset_summary(data)
177
 
178
  # Generate general insights and recommendations
179
+ general_insights = generate_insights(user_prompt, dataset_summary)
 
 
180
 
181
  # Append the assistant's insights to the history
182
  st.session_state.history.append({"role": "assistant", "content": general_insights})