# NL_2_SQL_Data_Analysis_Chatbot / langchain_utils.py
# Last updated by ramhemanth580 in commit 6457675 (verified): "Update langchain_utils.py"
import os
from dotenv import load_dotenv
from operator import itemgetter
load_dotenv()
# MySQL connection settings read from the environment (after loading .env);
# each name is None when the corresponding variable is not set.
db_user, db_password, db_host, db_name = (
    os.getenv(var_name) for var_name in ("db_user", "db_password", "db_host", "db_name")
)
import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.memory import ChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from table_details import table_chain as select_table
from prompts import final_prompt, answer_prompt
import streamlit as st
# Authenticate the Gemini SDK; a missing GOOGLE_API_KEY fails fast with KeyError.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

# Shared chat model used by get_chain() for both SQL generation and answer
# rephrasing. temperature=0 keeps generated SQL as deterministic as possible.
llm = ChatGoogleGenerativeAI(
    temperature=0,
    model="gemini-pro",
    convert_system_message_to_human=True,
)
@st.cache_resource
def get_chain():
    """Build the question -> SQL -> answer chain (cached by Streamlit).

    The returned runnable takes the input dict given to ``invoke`` and:

    1. adds ``table_names_to_use`` by running the imported ``select_table``
       runnable on that input;
    2. adds ``query``, the SQL written by ``create_sql_query_chain`` using
       ``final_prompt``;
    3. adds ``result`` by executing ``query`` with ``QuerySQLDataBaseTool``;
    4. feeds everything to ``answer_prompt | llm | StrOutputParser()`` and
       returns the resulting plain-text answer.

    ``@st.cache_resource`` makes Streamlit return the same chain object on
    subsequent calls instead of rebuilding it.

    Returns:
        The composed LangChain runnable.

    Raises:
        RuntimeError: if ``db_user``, ``db_host`` or ``db_name`` is unset/empty.
    """
    from urllib.parse import quote  # only needed for this one-time setup

    missing = [
        key
        for key, value in (("db_user", db_user), ("db_host", db_host), ("db_name", db_name))
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing database setting(s) in the environment: " + ", ".join(missing)
        )

    # Percent-encode the credentials so characters such as '@', ':' or '/'
    # in the user name or password cannot break the connection URI
    # (safe="" makes quote() encode '/' too). An unset password becomes an
    # empty string instead of the literal text "None".
    user = quote(db_user, safe="")
    password = quote(db_password or "", safe="")
    db = SQLDatabase.from_uri(f"mysql+pymysql://{user}:{password}@{db_host}/{db_name}")

    generate_query = create_sql_query_chain(llm, db, final_prompt)
    execute_query = QuerySQLDataBaseTool(db=db)
    rephrase_answer = answer_prompt | llm | StrOutputParser()

    # Each .assign() keeps the incoming keys and adds one more:
    # table_names_to_use -> query (generated SQL) -> result (query output);
    # the answer prompt then turns the accumulated dict into plain text.
    chain = (
        RunnablePassthrough.assign(table_names_to_use=select_table)
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer
    )
    return chain
def create_history(messages):
    """Rebuild a ``ChatMessageHistory`` from a list of chat-message dicts.

    Each entry must provide ``"role"`` and ``"content"`` keys. Entries whose
    role is exactly ``"user"`` are recorded as user messages; every other
    role is recorded as an AI message. Input order is preserved.

    Args:
        messages: Iterable of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A new ``ChatMessageHistory`` containing those messages.
    """
    history = ChatMessageHistory()
    for entry in messages:
        is_user = entry["role"] == "user"
        append = history.add_user_message if is_user else history.add_ai_message
        append(entry["content"])
    return history
def invoke_chain(question, messages, top_k=3):
    """Answer ``question`` with the cached NL-to-SQL chain.

    Args:
        question: The user's natural-language question.
        messages: Prior conversation as a list of ``{"role", "content"}``
            dicts; it is converted with ``create_history`` and passed to the
            chain under the ``"messages"`` key. ``None`` means no history.
        top_k: Value passed in the chain input as ``"top_k"`` (default 3).
            Presumably a result-size hint consumed by the SQL prompt;
            confirm against ``final_prompt``.

    Returns:
        The chain's output, i.e. the text produced by its final
        ``StrOutputParser``.
    """
    chain = get_chain()
    history = create_history(messages or [])
    # The history object only shapes the prompt input. The caller owns the
    # persistent transcript, so nothing is appended back to it here (the
    # previous appends went to a local object that was immediately discarded).
    return chain.invoke(
        {"question": question, "top_k": top_k, "messages": history.messages}
    )