import os
from collections.abc import Iterable, Mapping
from operator import itemgetter
from urllib.parse import quote_plus

from dotenv import load_dotenv

# Populate os.environ from a .env file before anything below reads it
# (the db_* settings and GOOGLE_API_KEY).
load_dotenv()

db_user = os.getenv("db_user")
db_password = os.getenv("db_password")
db_host = os.getenv("db_host")
db_name = os.getenv("db_name")

import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.memory import ChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from table_details import table_chain as select_table
from prompts import final_prompt, answer_prompt

import streamlit as st

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

llm = ChatGoogleGenerativeAI(
    model="gemini-pro",
    temperature=0,
    convert_system_message_to_human=True,
)


def _build_db_uri() -> str:
    """Return the SQLAlchemy URI for the MySQL database configured via env vars.

    The user name and password are percent-encoded so that characters such
    as ``@``, ``:``, ``/`` or ``#`` in a password cannot corrupt the URI.

    Raises:
        RuntimeError: if ``db_user``, ``db_host`` or ``db_name`` is unset or
            empty (previously these were interpolated as the text "None").
    """
    required = {"db_user": db_user, "db_host": db_host, "db_name": db_name}
    missing = [key for key, value in required.items() if not value]
    if missing:
        raise RuntimeError("Missing database settings: " + ", ".join(missing))
    user = quote_plus(db_user)
    # An unset password is treated as empty rather than the literal "None".
    password = quote_plus(db_password or "")
    return f"mysql+pymysql://{user}:{password}@{db_host}/{db_name}"


@st.cache_resource
def get_chain():
    """Build the question -> SQL -> answer chain (cached via st.cache_resource).

    Steps, each adding a key to the input dict before the final step:
      1. ``select_table`` -> ``table_names_to_use`` (presumably the tables
         relevant to the question; defined in table_details).
      2. ``generate_query`` (create_sql_query_chain with ``final_prompt``)
         -> ``query``.
      3. ``execute_query`` runs ``query`` against the database -> ``result``.
      4. ``rephrase_answer`` feeds the accumulated dict through
         ``answer_prompt`` and the LLM, returning a string.

    Raises:
        RuntimeError: if required database settings are missing.
    """
    db = SQLDatabase.from_uri(_build_db_uri())
    generate_query = create_sql_query_chain(llm, db, final_prompt)
    execute_query = QuerySQLDataBaseTool(db=db)
    rephrase_answer = answer_prompt | llm | StrOutputParser()

    chain = (
        RunnablePassthrough.assign(table_names_to_use=select_table)
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer
    )
    return chain


def create_history(messages: Iterable[Mapping[str, str]]) -> ChatMessageHistory:
    """Convert chat messages into a ChatMessageHistory.

    Args:
        messages: mappings with ``"role"`` and ``"content"`` keys. Role
            ``"user"`` becomes a user message; any other role becomes an AI
            message.

    Returns:
        A new in-memory ChatMessageHistory holding the messages in order.
    """
    history = ChatMessageHistory()
    for message in messages:
        if message["role"] == "user":
            history.add_user_message(message["content"])
        else:
            history.add_ai_message(message["content"])
    return history


def invoke_chain(
    question: str,
    messages: Iterable[Mapping[str, str]],
    top_k: int = 3,
) -> str:
    """Answer *question* against the database, using *messages* as context.

    Args:
        question: the user's natural-language question.
        messages: prior turns as mappings with "role" and "content" keys.
        top_k: forwarded in the chain input as ``top_k`` (default 3, the
            previously hard-coded value).

    Returns:
        The chain's string answer. *messages* is not modified; recording the
        new question/answer turn is the caller's responsibility (the previous
        code appended it to a local history object that was then discarded).
    """
    chain = get_chain()
    history = create_history(messages)
    return chain.invoke(
        {"question": question, "top_k": top_k, "messages": history.messages}
    )