arithescientist committed
Commit d0ab6a9 · verified · 1 Parent(s): db88275

Update app.py

Files changed (1):
  1. app.py (+170 −23)

app.py CHANGED
@@ -11,35 +11,51 @@ if 'history' not in st.session_state:
     st.session_state.history = []
 
 # OpenAI API key (ensure it is securely stored)
+# You can set the API key in your environment variables or a .env file
 openai_api_key = os.getenv("OPENAI_API_KEY")
 
+# Check if the API key is set
+if not openai_api_key:
+    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
+    st.stop()
+
 # Step 1: Upload CSV data file (or use default)
+st.title("Natural Language to SQL Query App with Enhanced Insights")
+st.write("Upload a CSV file to get started, or use the default dataset.")
+
 csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
 if csv_file is None:
-    data = pd.read_csv("default_data.csv")  # Use default CSV if no file is uploaded
+    data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
     st.write("Using default_data.csv file.")
+    table_name = "default_table"
 else:
     data = pd.read_csv(csv_file)
+    table_name = csv_file.name.split('.')[0]
     st.write(f"Data Preview ({csv_file.name}):")
     st.dataframe(data.head())
 
 # Step 2: Load CSV data into a persistent SQLite database
 db_file = 'my_database.db'
 conn = sqlite3.connect(db_file)
-table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
 data.to_sql(table_name, conn, index=False, if_exists='replace')
 
 # SQL table metadata (for validation and schema)
 valid_columns = list(data.columns)
 st.write(f"Valid columns: {valid_columns}")
 
-# Step 3: Set up the LLM Chain to generate SQL queries
-template = """
+# Step 3: Set up the LLM Chains
+# SQL Generation Chain
+sql_template = """
 You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
 
 Ensure that:
+
 - You only use the columns provided.
-- String comparisons in the WHERE clause are case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
+- When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
+- Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
+- Do not apply 'COLLATE NOCASE' to numeric columns.
+
+If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.
 
 Question: {question}
 
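Note: the new comment above says the key can come from a .env file, but the hunk itself only ever calls os.getenv. A minimal sketch of wiring that up with python-dotenv; the package is an assumption, nothing in this commit references it:

    # Hypothetical addition near the top of app.py, before os.getenv runs.
    # Requires `pip install python-dotenv`; not part of this commit.
    from dotenv import load_dotenv

    load_dotenv()  # populates OPENAI_API_KEY (and friends) from a local .env file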
@@ -49,8 +65,93 @@ Valid columns: {columns}
 
 SQL Query:
 """
-prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
-sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
+sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
+llm = OpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens=180)
+sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
+
+# Insights Generation Chain
+insights_template = """
+You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
+
+User's Question: {question}
+
+SQL Query Result:
+{result}
+
+Concise Analysis (max 150 words):
+"""
+insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
+insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
+
+# General Insights and Recommendations Chain
+general_insights_template = """
+You are an expert data scientist. Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.
+
+Dataset Summary:
+{dataset_summary}
+
+Concise Analysis and Recommendations (max 150 words):
+"""
+general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
+general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
+
+# Optional: clean-up function to remove incorrect COLLATE NOCASE usage
+def clean_sql_query(query):
+    """Removes incorrect usage of COLLATE NOCASE from the SQL query."""
+    parsed = sqlparse.parse(query)
+    statements = []
+    for stmt in parsed:
+        tokens = []
+        idx = 0
+        while idx < len(stmt.tokens):
+            token = stmt.tokens[idx]
+            if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'COLLATE':
+                # Check if the token after the whitespace is 'NOCASE'
+                next_token = stmt.tokens[idx + 2] if idx + 2 < len(stmt.tokens) else None
+                if next_token and next_token.value.upper() == 'NOCASE':
+                    # Skip the 'COLLATE' and 'NOCASE' tokens
+                    idx += 3  # skip 'COLLATE', the whitespace, and 'NOCASE'
+                    continue
+            tokens.append(token)
+            idx += 1
+        statements.append(''.join([str(t) for t in tokens]))
+    return ' '.join(statements)
+
+# Function to classify the user query
+def classify_query(question):
+    """Classify the user query as either 'SQL' or 'INSIGHTS'."""
+    classification_template = """
+You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
+
+Determine the appropriate category for the following user question.
+
+Question: "{question}"
+
+Category (SQL/INSIGHTS):
+"""
+    classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
+    classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
+    category = classification_chain.run({'question': question}).strip().upper()
+    if category.startswith('SQL'):
+        return 'SQL'
+    else:
+        return 'INSIGHTS'
+
+# Function to generate a dataset summary
+def generate_dataset_summary(data):
+    """Generate a summary of the dataset for general insights."""
+    summary_template = """
+You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
+
+Dataset:
+{data}
+
+Dataset Summary:
+"""
+    summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
+    summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
+    summary = summary_chain.run({'data': data.head().to_string(index=False)})
+    return summary
 
 # Define the callback function
 def process_input():
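A caveat on clean_sql_query: it only walks stmt.tokens, the top level of sqlparse's parse tree, and sqlparse often nests COLLATE NOCASE inside a grouped identifier (an ORDER BY expression, for instance), where this loop never sees it. A sketch of a variant that scans the flattened token stream instead; an alternative for illustration, not the committed code:

    import sqlparse

    def clean_sql_query_flattened(query):
        """Like clean_sql_query, but scans leaf tokens so COLLATE NOCASE
        nested inside grouped identifiers is stripped as well."""
        tokens = list(sqlparse.parse(query)[0].flatten())
        out = []
        i = 0
        while i < len(tokens):
            tok = tokens[i]
            if tok.ttype is sqlparse.tokens.Keyword and tok.value.upper() == 'COLLATE':
                j = i + 1
                while j < len(tokens) and tokens[j].is_whitespace:
                    j += 1  # step over whitespace after COLLATE
                if j < len(tokens) and tokens[j].value.upper() == 'NOCASE':
                    i = j + 1  # drop COLLATE, the whitespace, and NOCASE
                    continue
            out.append(str(tok))
            i += 1
        return ''.join(out)

On "SELECT name FROM t ORDER BY age COLLATE NOCASE" this should leave "SELECT name FROM t ORDER BY age" (plus a stray trailing space), which is exactly the misuse the new prompt rules try to prevent.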
@@ -61,31 +162,77 @@ def process_input():
         # Append user message to history
         st.session_state.history.append({"role": "user", "content": user_prompt})
 
-        if "columns" in user_prompt.lower():
+        # Classify the user query
+        category = classify_query(user_prompt)
+        logging.info(f"User query classified as: {category}")
+
+        if "COLUMNS" in user_prompt.upper():
             assistant_response = f"The columns are: {', '.join(valid_columns)}"
             st.session_state.history.append({"role": "assistant", "content": assistant_response})
-        else:
+        elif category == 'SQL':
             columns = ', '.join(valid_columns)
             generated_sql = sql_generation_chain.run({
                 'question': user_prompt,
                 'table_name': table_name,
                 'columns': columns
-            })
-
-            # Debug: Display generated SQL query for inspection
-            st.write(f"Generated SQL Query:\n{generated_sql}")
-
-            # Attempt to execute SQL query and handle exceptions
-            try:
-                result = pd.read_sql_query(generated_sql, conn)
-                assistant_response = f"Generated SQL Query:\n{generated_sql}"
-                st.session_state.history.append({"role": "assistant", "content": assistant_response})
-                st.session_state.history.append({"role": "assistant", "content": result})
-            except Exception as e:
-                logging.error(f"An error occurred during SQL execution: {e}")
-                assistant_response = f"Error executing SQL query: {e}"
-                st.session_state.history.append({"role": "assistant", "content": assistant_response})
-
+            }).strip()
+
+            if generated_sql.upper() == "NO_SQL":
+                # Handle cases where no SQL should be generated
+                assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
+
+                # Generate dataset summary
+                dataset_summary = generate_dataset_summary(data)
+
+                # Generate general insights and recommendations
+                general_insights = general_insights_chain.run({
+                    'dataset_summary': dataset_summary
+                })
+
+                # Append the assistant's insights to the history
+                st.session_state.history.append({"role": "assistant", "content": general_insights})
+            else:
+                # Clean the SQL query
+                cleaned_sql = clean_sql_query(generated_sql)
+                logging.info(f"Generated SQL Query: {cleaned_sql}")
+
+                # Attempt to execute SQL query and handle exceptions
+                try:
+                    result = pd.read_sql_query(cleaned_sql, conn)
+
+                    if result.empty:
+                        assistant_response = "The query returned no results. Please try a different question."
+                        st.session_state.history.append({"role": "assistant", "content": assistant_response})
+                    else:
+                        # Convert the result to a string for the insights prompt
+                        result_str = result.head(10).to_string(index=False)  # Limit to first 10 rows
+
+                        # Generate insights and recommendations based on the query result
+                        insights = insights_chain.run({
+                            'question': user_prompt,
+                            'result': result_str
+                        })
+
+                        # Append the assistant's insights to the history
+                        st.session_state.history.append({"role": "assistant", "content": insights})
+                        # Append the result DataFrame to the history
+                        st.session_state.history.append({"role": "assistant", "content": result})
+                except Exception as e:
+                    logging.error(f"An error occurred during SQL execution: {e}")
+                    assistant_response = f"Error executing SQL query: {e}"
+                    st.session_state.history.append({"role": "assistant", "content": assistant_response})
+        else:  # INSIGHTS category
+            # Generate dataset summary
+            dataset_summary = generate_dataset_summary(data)
+
+            # Generate general insights and recommendations
+            general_insights = general_insights_chain.run({
+                'dataset_summary': dataset_summary
+            })
+
+            # Append the assistant's insights to the history
+            st.session_state.history.append({"role": "assistant", "content": general_insights})
+
     except Exception as e:
         logging.error(f"An error occurred: {e}")
         assistant_response = f"Error: {e}"
@@ -106,4 +253,4 @@ for message in st.session_state.history:
         st.markdown(f"**Assistant:** {message['content']}")
 
 # Place the input field at the bottom with the callback
-st.text_input("Enter your message:", key='user_input', on_change=process_input)
+st.text_input("Enter your message:", key='user_input', on_change=process_input)
 