|
import os |
|
import time |
|
import openai |
|
import gradio as gr |
|
import polars as pl |
|
from sentence_transformers import SentenceTransformer |
|
from langchain.vectorstores.azuresearch import AzureSearch |
|
|
|
|
|
|
|
from dotenv import load_dotenv |
|
|
|
# Pull secrets and endpoints from a local .env file into os.environ.
load_dotenv()

# Azure OpenAI settings, applied through the module-level globals of the
# legacy (pre-1.0) ``openai`` package.
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.api_base = os.environ.get("OPENAI_API_BASE")

# Azure Cognitive Search connection details for the AzureSearch vector store.
vector_store_address = os.environ.get("VECTOR_STORE_URL")
vector_store_password = os.environ.get("VECTOR_STORE_KEY")
index_name = "motor-gm-search"
|
|
|
# Vehicle catalogue backing the Year -> Make -> Model dropdowns.
df = pl.read_csv("year-make-model.csv")

# Distinct values, sorted so the dropdowns list them in a stable order:
# ``unique()`` alone does not guarantee any particular order. Series.sort
# (rather than ``sorted``) also tolerates null entries in the column.
years = df["year"].unique().sort().to_list()
makes = df["make"].unique().sort().to_list()
models = df["model"].unique().sort().to_list()

# System prompt; ``respond`` appends the retrieved documents to it.
# Read as UTF-8 explicitly instead of relying on the platform locale.
with open("sys_prompt.txt", "r", encoding="utf-8") as f:
    prompt = f.read()
|
|
|
|
|
def embed(message):
    """Embed one piece of text with the module-level SentenceTransformer.

    Used as the ``embedding_function`` of the AzureSearch store. The text is
    encoded as a one-element batch and that batch's only vector is returned.
    """
    batch = [message]
    vectors = embedder.encode(batch)
    return vectors[0]
|
|
|
|
|
|
|
# Local embedding model read by ``embed``. It is created before the vector
# store in case AzureSearch calls ``embed`` while it is being constructed.
embedder = SentenceTransformer("BAAI/bge-small-en")

# LangChain wrapper around the Azure Cognitive Search index queried by search_db.
search = AzureSearch(
    index_name=index_name,
    azure_search_endpoint=vector_store_address,
    azure_search_key=vector_store_password,
    embedding_function=embed,
)
|
|
|
|
|
def filter_makes(year):
    """Repopulate the Make dropdown with the makes catalogued for ``year``.

    Returns a ``gr.Dropdown.update`` carrying the sorted distinct makes from
    ``df`` for that year, and enables the dropdown.
    """
    makes_for_year = df.filter(pl.col("year") == int(year))["make"].unique().to_list()
    makes_for_year = sorted(makes_for_year)
    return gr.Dropdown.update(choices=makes_for_year, interactive=True)
|
|
|
|
|
def filter_models(year, make):
    """Repopulate the Model dropdown with the models for ``year`` and ``make``.

    Returns a ``gr.Dropdown.update`` carrying the sorted distinct models from
    ``df`` that match both selections, and enables the dropdown.
    """
    selection = (pl.col("year") == int(year)) & (pl.col("make") == make)
    model_names = df.filter(selection)["model"].unique().to_list()
    return gr.Dropdown.update(choices=sorted(model_names), interactive=True)
|
|
|
|
|
def build_search_filter(year, make, model):
    """Return the OData ``$filter`` expression restricting results to one vehicle.

    Single quotes inside ``make``/``model`` are doubled, which is how OData
    escapes them within string literals. A name containing an apostrophe
    therefore cannot break (or inject into) the filter expression.
    """
    make_literal = str(make).replace("'", "''")
    model_literal = str(model).replace("'", "''")
    return (
        f"year eq {int(year)} and make eq '{make_literal}' "
        f"and model eq '{model_literal}'"
    )


def search_db(query, year, make, model, k=5, s_type="similarity"):
    """Retrieve up to ``k`` documents for one vehicle and return them as text.

    Args:
        query: Free-text user question to search for.
        year, make, model: Vehicle selection used to filter the index.
        k: Maximum number of documents to return.
        s_type: Search type passed to AzureSearch, e.g. "similarity" or "hybrid".

    Returns:
        ``str()`` of a list of ``{"title": ..., "content": ...}`` dicts. The
        caller appends it verbatim to the system prompt.
    """
    filters = build_search_filter(year, make, model)

    # Bug fix: this previously compared the global ``search_type`` (the Gradio
    # dropdown component) to "hybrid", so the hybrid branch never ran.
    if s_type == "hybrid":
        res = search.similarity_search(query, k, search_type=s_type, filters=filters)
    else:
        # NOTE(review): the growing fetch size suggests the filter is applied
        # after the nearest-neighbour step. So request 100, 200, ..., 1600
        # candidates until at least k survive or the cap is reached. This
        # used to be ``or``, which always ran all 5 queries and never
        # terminated when fewer than k documents matched.
        res = []
        mult = 1
        while len(res) < k and mult <= 16:
            res = search.similarity_search(
                query, 100 * mult, search_type=s_type, filters=filters
            )
            mult *= 2
        res = res[:k]

    results = [{"title": r.metadata["title"], "content": r.page_content} for r in res]
    return str(results)
|
|
|
|
|
def respond(message, history, year, make, model, search_type):
    """Stream an answer to ``message`` grounded in documents for the chosen vehicle.

    This is the ``gr.ChatInterface`` callback. ``year``, ``make``, ``model`` and
    ``search_type`` arrive through the interface's ``additional_inputs``.
    ``history`` is accepted to satisfy the ChatInterface signature but is not
    sent to the model: every turn is answered from a fresh retrieval.

    Yields:
        The progressively longer reply text, so the UI renders it as it streams.
    """
    if not year or not make or not model:
        msg = "Please select a year, make, and model."
        # Reveal the reminder one character at a time to mimic streaming.
        for end in range(1, len(msg) + 1):
            time.sleep(0.02)
            yield msg[:end]
        return

    results = search_db(message, year, make, model, k=5, s_type=search_type)

    hist = [
        {"role": "system", "content": prompt + results},
        {
            "role": "user",
            "content": f"Year: {year}\nMake: {make}\nModel: {model}\n\n{message}",
        },
    ]

    # Azure OpenAI deployment name. It has its own variable so it no longer
    # shadows the ``model`` argument (the vehicle model).
    deployment = "chatserver35turbo16k"
    res = openai.ChatCompletion.create(
        deployment_id=deployment, messages=hist, temperature=0.0, stream=True
    )

    msg = ""
    for chunk in res:
        # Azure may send chunks with an empty ``choices`` list (e.g. prompt
        # filter results). Skip them instead of raising IndexError.
        if not chunk["choices"]:
            continue
        # The final delta has no content, and it may also be None; only
        # append real text.
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            msg += content
            yield msg
|
|
|
|
|
# UI: vehicle pickers in a row, with the chat interface below.
# NOTE(review): #component-8 / #component-9 depend on Gradio's auto-assigned
# component ids and will silently stop matching if the layout changes.
# Giving the target components an elem_id would be more robust.
with gr.Blocks(
    css="footer {visibility: hidden} #component-8 {height: 75vh !important} #component-9 {height: 70vh !important}"
) as app:
    with gr.Row():
        year = gr.Dropdown(years, label="Year")
        make = gr.Dropdown([], label="Make", interactive=False)
        model = gr.Dropdown([], label="Model", interactive=False)
        types = ["similarity", "hybrid"]
        search_type = gr.Dropdown(types, label="Search Type", value="hybrid")

    # Cascade: picking a year narrows the makes; picking a make narrows the models.
    year.change(filter_makes, year, make)
    make.change(filter_models, [year, make], model)

    row = [year, make, model, search_type]
    gr.ChatInterface(respond, additional_inputs=row).queue()

# Basic-auth credentials are read from the environment so they need not live
# in source control. The previous hard-coded values remain as fallbacks for
# backward compatibility; set APP_USERNAME / APP_PASSWORD in deployment.
app.queue().launch(
    auth=(
        os.getenv("APP_USERNAME", "motor"),
        os.getenv("APP_PASSWORD", "vectorsearch"),
    )
)
|
|