import os
from dotenv import load_dotenv
from operator import itemgetter

# Load database credentials from the .env file
load_dotenv()
db_user = os.getenv("db_user")
db_password = os.getenv("db_password")
db_host = os.getenv("db_host")
db_name = os.getenv("db_name")

import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.memory import ChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from table_details import table_chain as select_table
from prompts import final_prompt, answer_prompt
import streamlit as st

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
llm = ChatGoogleGenerativeAI(
    model="gemini-pro",
    temperature=0,
    convert_system_message_to_human=True,
)
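# The .env file loaded above is assumed to define the variables read with
# os.getenv(); a minimal sketch (values are placeholders, not from this repo):
#
#   db_user=your_mysql_user
#   db_password=your_mysql_password
#   db_host=localhost
#   db_name=your_database
#   GOOGLE_API_KEY=your_google_api_key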

def get_chain():
    """Build the SQL question-answering chain."""
    # print("Creating chain")
    db = SQLDatabase.from_uri(
        f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}"
    )
    # Chain that writes a SQL query for the user's question
    generate_query = create_sql_query_chain(llm, db, final_prompt)
    # Tool that executes the generated query against the database
    execute_query = QuerySQLDataBaseTool(db=db)
    # Chain that rephrases the raw SQL result as a natural-language answer
    rephrase_answer = answer_prompt | llm | StrOutputParser()

    # chain = generate_query | execute_query
    chain = (
        RunnablePassthrough.assign(table_names_to_use=select_table)
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer
    )
    return chain

def create_history(messages):
    history = ChatMessageHistory()
    for message in messages:
        if message["role"] == "user":
            history.add_user_message(message["content"])
        else:
            history.add_ai_message(message["content"])
    return history

def invoke_chain(question, messages):
    chain = get_chain()
    history = create_history(messages)
    # top_k caps the number of rows the generated SQL should return;
    # the chat history is forwarded to the prompt under the "messages" key
    response = chain.invoke(
        {"question": question, "top_k": 3, "messages": history.messages}
    )
    history.add_user_message(question)
    history.add_ai_message(response)
    return response
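
# Usage sketch: how a Streamlit front end might call invoke_chain. This is an
# assumption for illustration only; the actual app file and widget layout are
# not shown here, and the module name "langchain_utils" is hypothetical.
#
#   import streamlit as st
#   from langchain_utils import invoke_chain
#
#   if "messages" not in st.session_state:
#       st.session_state.messages = []
#   if prompt := st.chat_input("Ask a question about the database"):
#       st.session_state.messages.append({"role": "user", "content": prompt})
#       answer = invoke_chain(prompt, st.session_state.messages)
#       st.session_state.messages.append({"role": "assistant", "content": answer})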