Spaces:
Build error
Build error
import os | |
import dotenv | |
import gradio as gr | |
# LLMs | |
import openai | |
from langchain.chat_models import AzureChatOpenAI | |
from langchain.llms import AzureOpenAI | |
from langchain.schema import AIMessage, HumanMessage | |
from langchain_experimental.sql import SQLDatabaseChain | |
from langchain.agents import create_sql_agent | |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit | |
from langchain.agents.agent_types import AgentType | |
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS | |
# Databases | |
import adal | |
import struct | |
from sqlalchemy.engine import URL | |
from sqlalchemy import create_engine, event | |
from langchain import SQLDatabase | |
# Load environment variables (python-dotenv) before any os.getenv() call below
dotenv.load_dotenv()
# Configure the module-level openai client for Azure; the endpoint, key and
# API version are read from the environment. get_llm() forwards these same
# module attributes to the chat model.
openai.api_type = "azure"
openai.api_base = os.getenv("OPENAI_API_BASE")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_version = os.getenv('OPENAI_API_VERSION')
def get_token():
    """Acquire an Azure AD access token and pack it for the SQL Server ODBC driver.

    The service principal is read from the ``DB_TENANT_ID``, ``DB_CLIENT_ID``
    and ``DB_CLIENT_SECRET`` environment variables and a token for
    ``https://database.windows.net/`` is requested through ``adal``.

    Returns:
        bytes: a length prefix packed with ``struct`` format ``"=i"`` followed
        by the access token encoded as UTF-16-LE. This is the value
        ``get_database`` passes as the ``SQL_COPT_SS_ACCESS_TOKEN`` attribute.

    Raises:
        ValueError: if any of the three environment variables is unset or empty.
    """
    tenant_id = os.getenv("DB_TENANT_ID")
    client_id = os.getenv("DB_CLIENT_ID")
    client_secret = os.getenv("DB_CLIENT_SECRET")

    # Fail with a named error instead of an opaque TypeError on concatenation.
    required = (
        ("DB_TENANT_ID", tenant_id),
        ("DB_CLIENT_ID", client_id),
        ("DB_CLIENT_SECRET", client_secret),
    )
    missing = [name for name, value in required if not value]
    if missing:
        raise ValueError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    authority_host_url = "https://login.microsoftonline.com"
    authority_url = authority_host_url + "/" + tenant_id
    context = adal.AuthenticationContext(authority_url, api_version=None)
    token = context.acquire_token_with_client_credentials(
        "https://database.windows.net/", client_id, client_secret
    )

    # UTF-16-LE emits every ASCII character as its byte followed by a zero
    # byte, i.e. the expanded form the token is sent in, in a single pass.
    exptoken = token["accessToken"].encode("utf-16-le")
    return struct.pack("=i", len(exptoken)) + exptoken
def get_conn_url(
    server="sql-ae-dsdev-dna-hack01-pnpdebzuohv7y.database.windows.net",
    database="dna-hack-db01",
):
    """Build the SQLAlchemy ``mssql+pyodbc`` connection URL.

    Args:
        server: host name of the SQL server; defaults to the project's
            development server.
        database: database name; defaults to the project's database.

    Returns:
        str: a URL of the form
        ``mssql+pyodbc://@<server>/<database>?driver=ODBC+Driver+17+for+SQL+Server``.
        No credentials are embedded in the URL.
    """
    return f"mssql+pyodbc://@{server}/{database}?driver=ODBC+Driver+17+for+SQL+Server"
def get_database():
    """Create the LangChain ``SQLDatabase`` backed by a token-authenticated engine.

    The engine is built from ``get_conn_url()`` (the URL is printed for
    diagnostics). A ``do_connect`` listener is registered on the engine so
    that each new DBAPI connection is given a fresh access token from
    ``get_token()``.

    Returns:
        SQLDatabase: wrapper around the engine using schema ``"[dntmatrix]"``.
    """
    # URL shape: "mssql+pyodbc://@<server>/<db>?driver=ODBC+Driver+17+for+SQL+Server"
    conn_url = get_conn_url()
    print(conn_url)
    engine = create_engine(conn_url)

    # Without this registration the hook is never invoked and connections
    # are attempted without the access token.
    @event.listens_for(engine, "do_connect")
    def provide_token(dialect, conn_rec, cargs, cparams):
        # remove the "Trusted_Connection" parameter that SQLAlchemy adds
        cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
        # create token credential
        token_struct = get_token()
        # apply it to keyword arguments
        SQL_COPT_SS_ACCESS_TOKEN = 1256
        cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}

    return SQLDatabase(engine=engine, schema="[dntmatrix]")
def get_llm():
    """Construct the ``AzureChatOpenAI`` model shared by the SQL chain and agent.

    The Azure connection details are taken from the module-level ``openai``
    settings and forwarded through ``model_kwargs``.
    """
    azure_settings = {
        "engine": "gpt-4",
        "api_key": openai.api_key,
        "api_base": openai.api_base,
        "api_type": openai.api_type,
        "api_version": openai.api_version,
    }
    return AzureChatOpenAI(
        temperature=1.0,
        model_name="gpt-4",
        deployment_name="gpt-4",
        model_kwargs=azure_settings,
    )
def get_sql_chain():
    """Return a verbose ``SQLDatabaseChain`` over the project database and LLM."""
    database = get_database()
    chat_model = get_llm()
    chain = SQLDatabaseChain.from_llm(chat_model, database, verbose=True)
    return chain
def get_sql_agent():
    """Return a verbose zero-shot ReAct SQL agent over the project database.

    The agent's format instructions are the stock ``FORMAT_INSTRUCTIONS``
    prefixed with a single newline.
    """
    database = get_database()
    chat_model = get_llm()
    instructions = "\n" + FORMAT_INSTRUCTIONS
    return create_sql_agent(
        llm=chat_model,
        toolkit=SQLDatabaseToolkit(db=database, llm=chat_model),
        verbose=True,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        format_instructions=instructions,
    )
# Build the SQL agent once at import time; predict() reuses this module-level instance.
sql_agent = get_sql_agent()
def predict(message, history):
    """Answer one chat message by running it through the module-level SQL agent.

    ``history`` is part of the callback signature used with
    ``gr.ChatInterface`` but is not consulted here.
    """
    return sql_agent.run(message)
# Assemble the Gradio UI: a chat interface whose callback is predict().
with gr.Blocks() as demo:
    gr.ChatInterface(predict)

# When executed as a script, serve the UI with server_name "0.0.0.0".
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")