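# --- Hugging Face Hub page header captured along with this file (not program code) ---
# annavar's picture
# Upload 4 files
# 1b143c7
# raw
# history blame contribute delete
# No virus
# 4.4 kB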
import html
import os
import re
from contextlib import redirect_stdout
from io import StringIO

import gradio as gr
from dotenv import load_dotenv
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms import AzureOpenAI
# Load settings (OPENAI_MODEL_NAME, OPENAI_DEPLOYMENT_NAME, and presumably the
# Azure credentials AzureOpenAI reads -- confirm) from ./.env in the working dir.
load_dotenv(os.path.join(os.getcwd(), ".env"))

# Completion model shared by the SQL chain and the SQL agent below.
# temperature=0 keeps the generated SQL as deterministic as the model allows.
llm = AzureOpenAI(
    model_name=os.environ["OPENAI_MODEL_NAME"],
    deployment_name=os.environ["OPENAI_DEPLOYMENT_NAME"],
    temperature=0,
)

# Relative to the current working directory (as is the sqlite:/// URI below).
sqlite_db_path = "data/Chinook.db"
# Connecting SQLite to a path that does not exist silently creates a new, empty
# database file, which would leave the chain and agent querying a schema with no
# tables. Fail fast with an actionable message instead.
if not os.path.isfile(sqlite_db_path):
    raise FileNotFoundError(
        f"SQLite database not found at {os.path.abspath(sqlite_db_path)!r}; "
        "start the app from the directory that contains data/Chinook.db."
    )
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")

# Approach 1: LangChain's SQLDatabaseChain. verbose=True makes it print its run
# trace, which the UI handlers capture via redirect_stdout and display.
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

# Approach 2: a zero-shot ReAct agent equipped with the SQL database toolkit.
# verbose=True for the same reason as above.
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
def clear_input():
    """Return the reset values for an (input textbox, output panel) pair.

    Used by both "Clear" buttons: the first element empties the textbox and
    the second restores the output panel's placeholder text.
    """
    placeholder_text = "Hit 'Submit' to see output here"
    return ("", placeholder_text)
def generate_output_of_db_chain(user_message):
    """Run *user_message* through the SQLDatabaseChain and yield its trace as HTML.

    The chain is built with ``verbose=True``, so it prints a trace of its run
    to stdout. That output is captured and turned into an HTML fragment for
    the ``gr.HTML`` output component.

    Args:
        user_message: Natural-language question typed by the user.

    Yields:
        Exactly one string. For blank input it is a prompt to enter a message.
        Otherwise it is the captured trace with ANSI colour codes removed,
        HTML-escaped, and with newlines turned into ``<br/>``.
    """
    print(user_message)
    if not user_message or not user_message.strip():
        print("Empty input")
        yield "Please enter a message before hitting Submit!"
        # Stop here; otherwise the chain would still be run on the empty prompt.
        return
    # NOTE(review): redirect_stdout swaps the process-wide sys.stdout, so this is
    # only reliable when requests are not handled concurrently.
    with redirect_stdout(StringIO()) as captured:
        db_chain.run(user_message)
    trace = captured.getvalue()
    # [6:]: skip the first two "\n" and the special tag emitted by LangChain.
    trace = trace[6:]
    # Strip ANSI colour sequences such as "\x1b[1m" or "\x1b[32;1m".
    trace = re.sub(r"(\x1b)?\[(\d+[m;])+", "", trace)
    # Escape before adding markup. The trace can contain user-supplied text and
    # LLM-generated SQL, whose "<" / "&" would otherwise be parsed as HTML.
    trace = html.escape(trace, quote=False)
    yield trace.replace("\n", "<br/>")
def generate_output_of_db_agent(user_message):
    """Run *user_message* through the SQL agent and yield its trace as HTML.

    The agent is built with ``verbose=True``, so it prints its reasoning trace
    to stdout. That output is captured and turned into an HTML fragment for
    the ``gr.HTML`` output component.

    Args:
        user_message: Natural-language question typed by the user.

    Yields:
        Exactly one string. For blank input it is a prompt to enter a message.
        Otherwise it is the captured trace with ANSI colour codes removed,
        HTML-escaped, and with newlines turned into ``<br/>``.
    """
    if not user_message or not user_message.strip():
        print("Empty input")
        yield "Please enter a message before hitting Submit!"
        return
    # NOTE(review): redirect_stdout swaps the process-wide sys.stdout, so this is
    # only reliable when requests are not handled concurrently.
    with redirect_stdout(StringIO()) as captured:
        agent_executor.run(user_message)
    trace = captured.getvalue()
    # [6:]: skip the first two "\n" and the special tag emitted by LangChain.
    trace = trace[6:]
    # Strip ANSI colour sequences such as "\x1b[1m" or "\x1b[32;1m".
    trace = re.sub(r"(\x1b)?\[(\d+[m;])+", "", trace)
    # Escape before adding markup. The trace can contain user-supplied text and
    # LLM-generated SQL, whose "<" / "&" would otherwise be parsed as HTML.
    trace = html.escape(trace, quote=False)
    yield trace.replace("\n", "<br/>")
# Page-level CSS passed to gr.Blocks(css=...).
# NOTE(review): no component in this file uses the #banner-image or #chat-message
# elem_ids; confirm these rules are still needed.
custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""

with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML("""<h1 align="center">LLM Mini-Series #4 💬</h1>""")
    with gr.Row():
        with gr.Column():
            gr.Markdown("💻 TODO Add some nice description text")

    # Section 1: LangChain's SQLDatabaseChain.
    gr.HTML("""<h2 align="left">Using LangChain's SQLDatabaseChain</h2>""")
    with gr.Row():
        with gr.Column():
            user_message = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-input",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", elem_id="clear-btn", visible=True)
                submit_btn = gr.Button("Submit", elem_id="submit-btn", visible=True)
            with gr.Box():
                output_field = gr.HTML(
                    value="Hit 'Submit' to see output here",
                    label="Output of model",
                    interactive=False,
                )

    # Section 2: agent-based approach (create_sql_agent).
    gr.HTML("""<h2 align="left">Using an agent-based approach with LangChain</h2>""")
    with gr.Row():
        with gr.Column():
            user_message_agent = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-agent-input",
            )
            with gr.Row():
                clear_agent_btn = gr.Button(
                    "Clear", elem_id="clear-agent-btn", visible=True
                )
                submit_agent_btn = gr.Button(
                    "Submit", elem_id="submit-agent-btn", visible=True
                )
            with gr.Box():
                output_agent_field = gr.HTML(
                    value="Hit 'Submit' to see output here",
                    label="Output of model",
                    interactive=False,
                )

    # Clearing is instant and touches no shared state, so bypass the queue.
    # It would otherwise wait behind a running LLM request.
    clear_btn.click(clear_input, outputs=[user_message, output_field], queue=False)
    submit_btn.click(
        generate_output_of_db_chain, inputs=[user_message], outputs=[output_field]
    )
    submit_agent_btn.click(
        generate_output_of_db_agent,
        inputs=[user_message_agent],
        outputs=[output_agent_field],
    )
    clear_agent_btn.click(
        clear_input, outputs=[user_message_agent, output_agent_field], queue=False
    )

# Both output generators capture LangChain's verbose trace with redirect_stdout,
# which swaps the process-wide sys.stdout. With several queue workers, concurrent
# requests would capture each other's output, and on exit could restore stdout to
# another request's buffer. A single worker serializes the handlers. Raise this
# only after output capture is made per-request. Pass server_port=... to launch()
# to pin the port.
demo.queue(concurrency_count=1).launch(debug=True)