Spaces:
Build error
Build error
# Import necessary libraries and modules
import os  # To access environment variables
from urllib.parse import quote  # To percent-encode credentials placed in the database URI

import streamlit as st  # Streamlit for building the web app
from dotenv import load_dotenv  # For loading environment variables from .env
from langchain_community.utilities import SQLDatabase  # SQL database utility for LangChain
from langchain_core.messages import AIMessage, HumanMessage  # Message handling
from langchain_core.output_parsers import StrOutputParser  # To parse outputs as strings
from langchain_core.prompts import ChatPromptTemplate  # Prompt templates for generating responses
from langchain_core.runnables import RunnablePassthrough  # To chain operations
# OpenAI model for chat (if used)
from langchain_groq import ChatGroq  # Groq model for chat (currently used)

# Load environment variables from the .env file (like API keys, database credentials)
load_dotenv()
def _build_db_uri(user: str, password: str, host: str, port: str, database: str) -> str:
    """Return the mysql+mysqlconnector URI with user and password percent-encoded."""
    # A raw "@", ":" or "/" inside the credentials would be read as a URI
    # delimiter, so every reserved character is escaped (safe="").
    safe_user = quote(user, safe="")
    safe_password = quote(password, safe="")
    return f"mysql+mysqlconnector://{safe_user}:{safe_password}@{host}:{port}/{database}"


# Function to initialize a connection to a MySQL database
def init_database(
    *,
    user: "str | None" = None,
    password: "str | None" = None,
    host: "str | None" = None,
    port: "str | None" = None,
    database: "str | None" = None,
) -> "SQLDatabase | None":
    """Connect to MySQL and return a SQLDatabase, or None if the connection fails.

    Each setting may be passed explicitly; any setting left as None is read
    from the environment (DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME)
    with the same fallbacks as before: root / admin / localhost / 3306 / Chinook.
    Calling the function with no arguments therefore behaves as it always did.

    On failure the error is shown with st.error and None is returned instead
    of raising, so callers must check the result before using it.
    """
    # Explicit arguments win; otherwise fall back to the environment.
    user = user if user is not None else os.getenv("DB_USER", "root")
    password = password if password is not None else os.getenv("DB_PASSWORD", "admin")
    host = host if host is not None else os.getenv("DB_HOST", "localhost")
    port = port if port is not None else os.getenv("DB_PORT", "3306")
    database = database if database is not None else os.getenv("DB_NAME", "Chinook")

    # Construct the database URI (credentials are escaped by the helper)
    db_uri = _build_db_uri(user, password, host, port, database)

    try:
        # Initialize and return the SQLDatabase instance
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:
        # Best-effort by design: report in the UI and let the caller handle None.
        st.error(f"Failed to connect to database: {e}")
        return None
# Function to create a chain that generates SQL queries from user input and conversation history
def get_sql_chain(db):
    """Build the runnable that turns a question plus chat history into SQL text.

    The returned chain is invoked with a mapping that provides "question" and
    "chat_history"; it adds "schema" from db.get_table_info(), fills the
    prompt, calls the Groq chat model and returns the model output as a string.
    """
    # Prompt text sent to the model; {schema}, {chat_history} and {question}
    # are filled in when the chain runs.
    sql_prompt_text = """
    You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
    Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
    <SCHEMA>{schema}</SCHEMA>
    Conversation History: {chat_history}
    Write only the SQL query and nothing else.
    Question: {question}
    SQL Query:
    """
    sql_prompt = ChatPromptTemplate.from_template(sql_prompt_text)

    # Groq-hosted model; temperature 0 is requested for the query generation
    model = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # Step 1: attach the table info under "schema" (the incoming mapping is ignored)
    with_schema = RunnablePassthrough.assign(schema=lambda _inputs: db.get_table_info())

    # Steps 2-4: render the prompt, call the model, reduce the reply to a plain string
    return with_schema | sql_prompt | model | StrOutputParser()
def _strip_sql_fences(text: str) -> str:
    """Return the SQL text with surrounding whitespace and Markdown code fences removed."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        # Drop a language tag ("sql" / "mysql") that directly follows the opening fence.
        for tag in ("mysql", "sql"):
            if cleaned.lower().startswith(tag) and cleaned[len(tag):len(tag) + 1].isspace():
                cleaned = cleaned[len(tag):]
                break
    return cleaned.strip()


# Function to generate a response in natural language based on the SQL query result
def get_response(user_query: str, db: "SQLDatabase", chat_history: list):
    """Answer user_query in natural language using the database.

    Steps: generate a SQL query with get_sql_chain, strip any Markdown code
    fence the model wrapped around it, run the query with db.run, then ask
    the model to phrase the result as an answer.

    Args:
        user_query: The question typed by the user.
        db: Connected database; must provide get_table_info() and run().
        chat_history: Earlier messages, interpolated into both prompts.

    Returns:
        The model's reply as a string.

    Exceptions raised while generating or running the query are not caught
    here; they propagate to the caller.
    """
    # Chain that produces the SQL text for the question
    sql_chain = get_sql_chain(db)

    # Prompt template for natural language response based on SQL query and result
    template = """
    You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
    <SCHEMA>{schema}</SCHEMA>
    Conversation History: {chat_history}
    SQL Query: <SQL>{query}</SQL>
    User question: {question}
    SQL Response: {response}
    """
    prompt = ChatPromptTemplate.from_template(template)

    # Initialize Groq model (alternative: OpenAI)
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # NOTE(review): the SQL below is written by the model and executed as-is.
    # Consider connecting with a read-only database user so a generated
    # statement cannot modify or drop data.
    chain = (
        # The model may wrap the query in ```sql fences despite the prompt;
        # that text is not valid SQL, so it is cleaned before use.
        RunnablePassthrough.assign(query=sql_chain | _strip_sql_fences).assign(
            schema=lambda _: db.get_table_info(),  # Get schema info
            response=lambda inputs: db.run(inputs["query"]),  # Run SQL query on the database
        )
        | prompt  # Use prompt to generate a natural language response
        | llm  # Process prompt with Groq model
        | StrOutputParser()  # Parse the final result as a string
    )

    # Execute the chain and return the response
    return chain.invoke({
        "question": user_query,
        "chat_history": chat_history,
    })
# Seed the conversation the first time the script runs in a session; later
# reruns find "chat_history" already stored and leave it untouched.
if "chat_history" not in st.session_state:
    greeting = AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database.")
    st.session_state["chat_history"] = [greeting]

# Page configuration and visible title share the same text
app_title = "Chat with MySQL"
st.set_page_config(page_title=app_title, page_icon=":speech_balloon:")
st.title(app_title)
# Sidebar for database connection settings
with st.sidebar:
    st.subheader("Settings")
    st.write("Connect to your database and start chatting.")

    # Database connection input fields, pre-filled from the environment
    host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
    port = st.text_input("Port", value=os.getenv("DB_PORT", "3306"))
    user = st.text_input("User", value=os.getenv("DB_USER", "root"))
    password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "admin"))
    database = st.text_input("Database", value=os.getenv("DB_NAME", "Chinook"))

    # Button to connect to the database
    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            # Connect with the values typed above. Previously these inputs
            # were ignored and only the environment was used, so editing a
            # field had no effect on the connection.
            # Credentials are percent-encoded so characters such as "@" or
            # "/" in a password do not break the URI.
            safe_user = quote(user, safe="")
            safe_password = quote(password, safe="")
            db_uri = f"mysql+mysqlconnector://{safe_user}:{safe_password}@{host}:{port}/{database}"
            try:
                db = SQLDatabase.from_uri(db_uri)
            except Exception as exc:
                # Top-level UI boundary: report the failure instead of crashing the app.
                st.error(f"Failed to connect to database: {exc}")
                st.error("Connection failed. Please check your settings.")
            else:
                # Keep the connection for later reruns of the script
                st.session_state.db = db
                st.success("Connected to the database!")
# Replay the stored conversation so it stays visible on every rerun
for past_message in st.session_state.chat_history:
    # Choose the chat bubble by message type; any other type is not shown
    if isinstance(past_message, AIMessage):
        speaker = "AI"
    elif isinstance(past_message, HumanMessage):
        speaker = "Human"
    else:
        continue
    with st.chat_message(speaker):
        st.markdown(past_message.content)
# Input field for user's message
user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    # Add user's query to the chat history
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    # Display user's message in the chat
    with st.chat_message("Human"):
        st.markdown(user_query)

    # Generate and display AI's response based on the query
    with st.chat_message("AI"):
        # "db" is only stored after a successful Connect; reading the
        # attribute directly would raise if the user chats before connecting.
        active_db = st.session_state.get("db")
        if active_db is None:
            response = "Please connect to a database from the sidebar before asking questions."
        else:
            try:
                response = get_response(user_query, active_db, st.session_state.chat_history)
            except Exception as exc:
                # Top-level UI boundary: a failed query or model call is
                # reported in the chat instead of aborting the script run.
                response = f"Sorry, I couldn't answer that question: {exc}"
        st.markdown(response)

    # Add AI's response to the chat history
    st.session_state.chat_history.append(AIMessage(content=response))