# f150/src/app.py
import os
import sys
from threading import Lock

import gradio as gr
import s3fs
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (ServiceContext, StorageContext,
                         load_index_from_storage, set_global_service_context)
from llama_index.llms import OpenAI
from llama_index.retrievers import RecursiveRetriever

def get_embed_model():
    # prefer CUDA, then Apple MPS, then fall back to CPU
    model_kwargs = {'device': 'cpu'}
    if torch.cuda.is_available():
        model_kwargs['device'] = 'cuda'
    elif torch.backends.mps.is_available():
        model_kwargs['device'] = 'mps'
    encode_kwargs = {'normalize_embeddings': True}  # normalized embeddings make dot product equal cosine similarity
    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name="thenlper/gte-small",
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        print(f"Failed to load embedding model: {exception}")
        sys.exit(1)
    print("Model loaded.")
    return model_norm
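
# Module-level setup: local embeddings plus GPT-4, registered as the global
# ServiceContext so every index loaded below uses the same models.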
embed_model = get_embed_model()
llm = OpenAI(model="gpt-4")
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
set_global_service_context(service_context)
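
# The persisted indices live in S3; credentials come from environment variables.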
s3 = s3fs.S3FileSystem(
    key=os.environ["AWS_CANONICAL_KEY"],
    secret=os.environ["AWS_CANONICAL_SECRET"],
)
titles = s3.ls("f150-user-manual/recursive-agent/")
titles = [path.split("/")[-1] for path in titles]
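
# Build one query engine per manual section; these are the leaf engines the
# recursive retriever forwards questions to.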
agents = {}
for title in titles:
    if title == "vector_index":
        # the top-level index is loaded separately below
        continue
    print(title)
    # load this section's persisted vector index from S3
    storage_context = StorageContext.from_defaults(
        persist_dir=f"f150-user-manual/recursive-agent/{title}/vector_index", fs=s3
    )
    vector_index = load_index_from_storage(storage_context)
    # query engine returning the two most similar chunks from this section
    vector_query_engine = vector_index.as_query_engine(similarity_top_k=2, verbose=True)
    agents[title] = vector_query_engine
print(f"Agents: {len(agents)}")
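
# Top-level index: routes a question to the single most relevant manual section.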
storage_context = StorageContext.from_defaults(
    persist_dir="f150-user-manual/recursive-agent/vector_index", fs=s3
)
top_level_vector_index = load_index_from_storage(storage_context)
vector_retriever = top_level_vector_index.as_retriever(similarity_top_k=1)
recursive_retriever = RecursiveRetriever(
    "vector",
    retriever_dict={"vector": vector_retriever},
    query_engine_dict=agents,
    verbose=True,
    query_response_tmpl="{response}",
)
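
# One lock serializes retrieval: the retriever is shared across Gradio's
# request threads and is not guaranteed to be thread-safe.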
lock = Lock()

def predict(message):
    # return the raw answer text for a single question (used by the hidden Predict button)
    print(message)
    with lock:
        try:
            output = recursive_retriever.retrieve(message)[0]
            output = output.get_text()
        except Exception as e:
            print(e)
            raise
    return output
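
# Chat handler: answers the question, appends the (question, answer) pair to
# the history, and clears the input textbox.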
def getanswer(question, history):
    print("getting answer")
    # unwrap component objects in case Gradio passes them instead of raw values
    if hasattr(history, "value"):
        history = history.value
    if hasattr(question, "value"):
        question = question.value
    history = history or []
    with lock:
        output = recursive_retriever.retrieve(question)[0]
        history.append((question, output.get_text()))
    # update both the chatbot and the state, and clear the textbox
    return history, history, gr.update(value="")
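
# Gradio UI: a single chat column plus a hidden Predict button that exposes
# predict() programmatically.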
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=0.75):
            with gr.Row():
                gr.Markdown("<h1>F150 User Manual</h1>")
            chatbot = gr.Chatbot(elem_id="chatbot").style(height=600)
            with gr.Row():
                message = gr.Textbox(
                    label="",
                    placeholder="F150 User Manual",
                    lines=1,
                )
            with gr.Row():
                submit = gr.Button(value="Send", variant="primary", scale=1)
    state = gr.State()
    submit.click(getanswer, inputs=[message, state], outputs=[chatbot, state, message])
    message.submit(getanswer, inputs=[message, state], outputs=[chatbot, state, message])
    # hidden button: writes predict()'s answer back into the textbox
    predictBtn = gr.Button(value="Predict", visible=False)
    predictBtn.click(predict, inputs=[message], outputs=[message])

demo.launch(debug=True)