Spaces:
Runtime error
Runtime error
File size: 4,272 Bytes
cbdf795 c58b4cd cbdf795 4f6811a cbdf795 4f6811a c58b4cd 4f6811a cbdf795 c58b4cd cbdf795 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
#!/usr/bin/env python
import json
import logging
import os
import sys
import s3fs
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (ServiceContext, StorageContext,
load_index_from_storage, set_global_service_context)
from llama_index.agent import ContextRetrieverOpenAIAgent, OpenAIAgent
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llms import ChatMessage, MessageRole, OpenAI
from llama_index.prompts import ChatPromptTemplate, PromptTemplate
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.retrievers import RecursiveRetriever
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.vector_stores import PGVectorStore
from sqlalchemy import make_url
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
def get_embed_model(model_name="thenlper/gte-small"):
    """Load the Hugging Face sentence-embedding model used for retrieval.

    The torch device is chosen as CUDA if available, otherwise Apple MPS if
    available, otherwise CPU. Embeddings are L2-normalised so that a dot
    product between two of them equals their cosine similarity.

    Args:
        model_name: Hugging Face model id to load. Defaults to
            ``thenlper/gte-small``.

    Returns:
        A LangChain ``HuggingFaceEmbeddings`` instance.

    Raises:
        SystemExit: with status 1 if the model cannot be loaded; the error is
            reported on stderr.
    """
    device = "cpu"
    # ``torch.backends.mps`` does not exist on older torch builds, so probe it
    # defensively instead of assuming the attribute is present.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"

    model_kwargs = {"device": device}
    # Normalised vectors make dot-product scoring equivalent to cosine similarity.
    encode_kwargs = {"normalize_embeddings": True}

    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        # Nothing can run without embeddings: report the failure and exit with a
        # non-zero status so callers and shells can detect it.
        print(
            f"Failed to load embedding model {model_name!r}: {exception}",
            file=sys.stderr,
        )
        raise SystemExit(1) from exception
    print("Model loaded.")
    return model_norm
# Prompt asking the model to answer F150 questions strictly from retrieved
# manual excerpts. It contains ``{context_str}`` and ``{query_str}``
# placeholders to be filled in before sending.
# NOTE(review): not referenced by any code visible in this file (main() prints
# raw retrieval results) — confirm whether it is still needed.
QA_TEMPLATE = """
You are a chatbot, able to have normal interactions as well as respond to questions about my Ford F150.
Below are excerpts from my F150's user manual. You must only use the information in the context below to formulate your response.
If there is not enough information to formulate a response, you must respond with: "I'm sorry, I can't find the answer to your question."
{context_str}
{query_str}
"""
# S3 bucket/prefix holding one persisted index per manual section, plus a
# top-level index directory of the same name as the per-section index dirs.
_INDEX_ROOT = "f150-user-manual/recursive-agent"
_INDEX_DIR = "vector_index"


def _make_s3_filesystem():
    """Create the S3 filesystem without embedding credentials in source.

    Credentials are read from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY when
    set. When they are unset, ``None`` is passed and s3fs resolves credentials
    through its usual provider chain.
    """
    # NOTE: never hard-code AWS keys here; any key previously committed to this
    # file must be treated as compromised and rotated.
    return s3fs.S3FileSystem(
        key=os.environ.get("AWS_ACCESS_KEY_ID"),
        secret=os.environ.get("AWS_SECRET_ACCESS_KEY"),
    )


def _load_section_query_engines(s3, max_sections):
    """Load a query engine for each per-section index stored under _INDEX_ROOT.

    The top-level ``_INDEX_DIR`` entry is skipped *before* applying the
    ``max_sections`` cap, so it never consumes a slot. ``None`` loads every
    section.

    Returns:
        dict mapping section title -> query engine (top-2 similarity).
    """
    entries = s3.ls(f"{_INDEX_ROOT}/")
    titles = [path.rstrip("/").split("/")[-1] for path in entries]
    section_titles = [title for title in titles if title != _INDEX_DIR][:max_sections]

    engines = {}
    for title in section_titles:
        print(title)
        # Load (not build) the section's persisted vector index from S3.
        storage_context = StorageContext.from_defaults(
            persist_dir=f"{_INDEX_ROOT}/{title}/{_INDEX_DIR}", fs=s3
        )
        section_index = load_index_from_storage(storage_context)
        engines[title] = section_index.as_query_engine(
            similarity_top_k=2,
            verbose=True,
        )
    return engines


def _build_recursive_retriever(s3, query_engines):
    """Build a RecursiveRetriever over the top-level index and section engines.

    The top-level index is queried with top-1 similarity. ``query_engines``
    is passed as ``query_engine_dict`` (presumably the top-level nodes
    reference sections by title — confirm against the indexing script).
    With ``query_response_tmpl="{response}"``, a delegated engine's answer
    text is used as-is.
    """
    storage_context = StorageContext.from_defaults(
        persist_dir=f"{_INDEX_ROOT}/{_INDEX_DIR}", fs=s3
    )
    top_level_index = load_index_from_storage(storage_context)
    # NOTE: results are printed straight from retrieval; the imported
    # RetrieverQueryEngine / get_response_synthesizer are not used here.
    return RecursiveRetriever(
        "vector",
        retriever_dict={"vector": top_level_index.as_retriever(similarity_top_k=1)},
        query_engine_dict=query_engines,
        verbose=True,
        query_response_tmpl="{response}",
    )


def _run_repl(retriever):
    """Read questions from stdin and print the text of the top retrieved node.

    Typing ``exit`` or reaching end of input (Ctrl-D / piped stdin) quits.
    Blank lines are ignored.
    """
    while True:
        try:
            user_input = input(">>> ")
        except EOFError:
            # input() raises EOFError once stdin is exhausted; stop instead of
            # looping forever on the same error.
            print()
            break

        query = user_input.strip()
        if query == "exit":
            break
        if not query:
            continue

        try:
            results = retriever.retrieve(query)
            if results:
                print(results[0].get_text())
            else:
                print("No results found.")
        except Exception as e:
            # Interactive top-level boundary: report the failure and keep the
            # session alive for the next question.
            print("Error:", e)


def main(max_sections=5):
    """Load the F150 manual indexes from S3 and run an interactive retrieval loop.

    Args:
        max_sections: maximum number of per-section indexes to load; the
            top-level index directory does not count toward it. ``None``
            loads every section. Defaults to 5.
    """
    embed_model = get_embed_model()
    llm = OpenAI("gpt-4")
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    set_global_service_context(service_context)

    s3 = _make_s3_filesystem()
    query_engines = _load_section_query_engines(s3, max_sections)
    print(f"Agents: {len(query_engines)}")

    retriever = _build_recursive_retriever(s3, query_engines)
    _run_repl(retriever)
# Run the interactive script only when executed directly, not on import.
if __name__ == "__main__":
    main()