# sql_chatbot / main.py
# Uploaded by rolzy using huggingface_hub (commit 756f5ab).
import os
import dotenv
import gradio as gr
# LLMs
import openai
from langchain.chat_models import AzureChatOpenAI
from langchain.llms import AzureOpenAI
from langchain.schema import AIMessage, HumanMessage
from langchain_experimental.sql import SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
# Databases
import adal
import struct
from sqlalchemy.engine import URL
from sqlalchemy import create_engine, event
from langchain import SQLDatabase
# Pull settings from a local .env file into the process environment.
dotenv.load_dotenv()

# Point the module-global openai client at Azure OpenAI. A variable that is
# absent from the environment leaves the corresponding setting as None.
openai.api_type = "azure"
openai.api_base, openai.api_key, openai.api_version = (
    os.getenv(env_name)
    for env_name in ("OPENAI_API_BASE", "OPENAI_API_KEY", "OPENAI_API_VERSION")
)
def get_token():
    """Acquire an Azure AD access token for Azure SQL and pack it for pyodbc.

    Uses the service-principal credentials in DB_TENANT_ID, DB_CLIENT_ID and
    DB_CLIENT_SECRET with adal's client-credentials flow to request a token
    for the ``https://database.windows.net/`` resource.

    Returns:
        bytes: a 4-byte length prefix (``struct`` format ``"=i"``) followed by
        the token encoded as UTF-16-LE. get_database passes this value to
        pyodbc under connection attribute 1256 (SQL_COPT_SS_ACCESS_TOKEN).

    Raises:
        RuntimeError: if any of the three DB_* environment variables is unset
            or empty. Previously this surfaced as an opaque TypeError.
    """
    env_names = ("DB_TENANT_ID", "DB_CLIENT_ID", "DB_CLIENT_SECRET")
    settings = {name: os.getenv(name) for name in env_names}
    missing = [name for name, value in settings.items() if not value]
    if missing:
        raise RuntimeError(
            "Missing Azure AD settings for the database: " + ", ".join(missing)
        )

    authority_url = "https://login.microsoftonline.com/" + settings["DB_TENANT_ID"]
    context = adal.AuthenticationContext(authority_url, api_version=None)
    token = context.acquire_token_with_client_credentials(
        "https://database.windows.net/",
        settings["DB_CLIENT_ID"],
        settings["DB_CLIENT_SECRET"],
    )

    # The ODBC driver expects the token as UTF-16-LE. For an ASCII token this
    # is identical to the old "each byte followed by a zero byte" expansion.
    # Azure AD access tokens are ASCII; confirm if the token source changes.
    # Encoding in one call avoids the old quadratic bytes concatenation.
    token_bytes = token["accessToken"].encode("utf-16-le")
    return struct.pack("=i", len(token_bytes)) + token_bytes
def get_conn_url(
    *,
    server="sql-ae-dsdev-dna-hack01-pnpdebzuohv7y.database.windows.net",
    database="dna-hack-db01",
    driver="ODBC Driver 17 for SQL Server",
):
    """Build the SQLAlchemy ``mssql+pyodbc`` URL for the Azure SQL database.

    All arguments are keyword-only and default to the project's server,
    database and ODBC driver, so a bare ``get_conn_url()`` call returns the
    same URL as before.

    Args:
        server: hostname of the SQL Server instance.
        database: database name on that server.
        driver: installed ODBC driver name; it is URL-encoded here.

    Returns:
        str: a URL with an empty username (``//@``) and no credentials.
        get_database supplies an access token when each connection is made.
    """
    from urllib.parse import quote_plus

    return f"mssql+pyodbc://@{server}/{database}?driver={quote_plus(driver)}"
def get_database():
    """Return a LangChain SQLDatabase over the ``[dntmatrix]`` schema.

    The engine's URL carries no credentials. Instead, a ``do_connect``
    listener calls get_token() for each connection SQLAlchemy opens and hands
    the packed token to pyodbc through ``attrs_before``.
    """
    url = get_conn_url()
    print(url)
    sql_engine = create_engine(url)

    # pyodbc/ODBC connection attribute id for an access token.
    SQL_COPT_SS_ACCESS_TOKEN = 1256

    def _attach_access_token(dialect, conn_rec, cargs, cparams):
        # Strip the Trusted_Connection parameter that SQLAlchemy adds, so the
        # driver uses the token instead.
        cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
        cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: get_token()}

    event.listen(sql_engine, "do_connect", _attach_access_token)
    return SQLDatabase(engine=sql_engine, schema="[dntmatrix]")
def get_llm():
    """Create the Azure OpenAI GPT-4 chat model used by the SQL chain and agent.

    The connection settings are read from the module-global ``openai``
    configuration at call time and forwarded via ``model_kwargs``.
    """
    # NOTE(review): temperature=1.0 makes the generated SQL non-deterministic;
    # consider a lower value for query generation.
    azure_settings = dict(
        engine="gpt-4",
        api_key=openai.api_key,
        api_base=openai.api_base,
        api_type=openai.api_type,
        api_version=openai.api_version,
    )
    return AzureChatOpenAI(
        temperature=1.0,
        model_name="gpt-4",
        deployment_name="gpt-4",
        model_kwargs=azure_settings,
    )
def get_sql_chain():
    """Build a verbose SQLDatabaseChain over the project database and GPT-4 model."""
    # Database first, then model, matching the construction order used elsewhere.
    database = get_database()
    chat_model = get_llm()
    chain = SQLDatabaseChain.from_llm(chat_model, database, verbose=True)
    return chain
def get_sql_agent():
    """Build a verbose zero-shot ReAct SQL agent over the project database.

    The agent gets a SQLDatabaseToolkit and LangChain's stock MRKL
    FORMAT_INSTRUCTIONS, prefixed by the literal below.
    """
    database = get_database()
    chat_model = get_llm()
    instructions = """
""" + FORMAT_INSTRUCTIONS
    sql_toolkit = SQLDatabaseToolkit(db=database, llm=chat_model)
    return create_sql_agent(
        llm=chat_model,
        toolkit=sql_toolkit,
        verbose=True,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        format_instructions=instructions,
    )
# Build the agent (engine, LLM client, toolkit) once at import time; every
# chat turn below reuses it.
sql_agent = get_sql_agent()


def predict(message, history):
    """Chat callback: answer *message* by running the SQL agent on it.

    ``history`` (the earlier turns from the chat UI) is part of the callback
    signature but is not used; each message is answered on its own.
    """
    return sql_agent.run(message)
# Chat UI: a single ChatInterface that sends each user message to predict().
with gr.Blocks() as demo:
    gr.ChatInterface(fn=predict)

if __name__ == "__main__":
    # Listen on all network interfaces, not just localhost.
    # NOTE(review): no auth is configured here, so anyone who can reach the
    # port can query the database through the agent.
    demo.launch(server_name="0.0.0.0")