arithescientist committed on
Commit d9d0b05 · verified · 1 Parent(s): 6dd2b20

Update app.py

Files changed (1)
  1. app.py +10 -87
app.py CHANGED
@@ -1,15 +1,14 @@
- import os
+ import os
  import streamlit as st
  import pandas as pd
  import sqlite3
  import logging
- from langchain.agents import create_sql_agent
- from langchain.agents.agent_toolkits import SQLDatabaseToolkit
- from langchain.agents.agent_types import AgentType
  from langchain.llms import OpenAI
- from langchain.sql_database import SQLDatabase
  from langchain.chat_models import ChatOpenAI
- from langchain.evaluation import load_evaluator
+ from langchain.chains import SQLDatabaseChain
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import LLMChain
+ from langchain.sql_database import SQLDatabase

  # Initialize logging
  logging.basicConfig(level=logging.INFO)
@@ -52,96 +51,20 @@ engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name
  # Initialize the LLM
  llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)

- # Step 3: Create the agent
- toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
-
- sql_agent = create_sql_agent(
-     llm=llm,
-     toolkit=toolkit,
-     verbose=True,
-     agent_type=AgentType.OPENAI_FUNCTIONS,
-     max_iterations=5
- )
+ # Initialize the SQLDatabaseChain
+ sql_chain = SQLDatabaseChain(llm=llm, database=engine, verbose=True)

  # Step 4: Define the callback function
  def process_input():
-     user_prompt = st.session_state['user_input']
-
-     if user_prompt:
-         try:
-             # Append user message to history
-             st.session_state.history.append({"role": "user", "content": user_prompt})
-
-             # Use the agent to generate the SQL query and get the response
-             with st.spinner("Processing..."):
-                 response = sql_agent.run(user_prompt)
-
-             # Check if the response contains a SQL query
-             if "```sql" in response:
-                 # Extract the SQL query
-                 start_index = response.find("```sql") + len("```sql")
-                 end_index = response.find("```", start_index)
-                 sql_query = response[start_index:end_index].strip()
-             else:
-                 # If no SQL code is found, assume the entire response is the SQL query
-                 sql_query = response.strip()
-
-             logging.info(f"Generated SQL Query: {sql_query}")
-
-             # Attempt to execute SQL query and handle exceptions
-             try:
-                 result = pd.read_sql_query(sql_query, conn)
-
-                 if result.empty:
-                     assistant_response = "The query returned no results. Please try a different question."
-                     st.session_state.history.append({"role": "assistant", "content": assistant_response})
-                 else:
-                     # Limit the result to first 10 rows for display
-                     result_display = result.head(10)
-                     st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
-                     st.session_state.history.append({"role": "assistant", "content": result_display})
-
-                     # Generate insights based on the query result
-                     insights_template = """
-                     You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
-
-                     User's Question: {question}
-
-                     SQL Query Result:
-                     {result}
-
-                     Concise Analysis:
-                     """
-                     insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
-                     insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
-
-                     result_str = result_display.to_string(index=False)
-                     insights = insights_chain.run({'question': user_prompt, 'result': result_str})
-
-                     # Append the assistant's insights to the history
-                     st.session_state.history.append({"role": "assistant", "content": insights})
-             except Exception as e:
-                 logging.error(f"An error occurred during SQL execution: {e}")
-                 assistant_response = f"Error executing SQL query: {e}"
-                 st.session_state.history.append({"role": "assistant", "content": assistant_response})
-         except Exception as e:
-             logging.error(f"An error occurred: {e}")
-             assistant_response = f"Error: {e}"
-             st.session_state.history.append({"role": "assistant", "content": assistant_response})
-
-     # Reset user input
-     st.session_state['user_input'] = ''
+     # (Use the updated process_input function provided earlier)
+     # ...

  # Step 5: Display conversation history
  for message in st.session_state.history:
      if message['role'] == 'user':
          st.markdown(f"**User:** {message['content']}")
      elif message['role'] == 'assistant':
-         if isinstance(message['content'], pd.DataFrame):
-             st.markdown("**Assistant:** Query Results:")
-             st.dataframe(message['content'])
-         else:
-             st.markdown(f"**Assistant:** {message['content']}")
+         st.markdown(f"**Assistant:** {message['content']}")

  # Input field
  st.text_input("Enter your message:", key='user_input', on_change=process_input)