# veda_bot_2.0 / vector_tool.py
import os
import time

import streamlit as st
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec

from langchain import hub
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.retrievers import EnsembleRetriever
from langchain.tools.retriever import create_retriever_tool
from langchain_community.retrievers import BM25Retriever
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore

from utils import load_pickle, initialize_embedding_model
# Load .env file
load_dotenv()

# Constants
INDEX_NAME = "veda-index-v2"
MODEL_NAME = "BAAI/bge-large-en-v1.5"

# Initialize Pinecone client
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
pc = Pinecone(api_key=PINECONE_API_KEY)
#@st.cache_resource
def create_or_load_index():
    # Check if the index already exists
    if INDEX_NAME not in pc.list_indexes().names():
        # Create the index if it does not exist; dimension 1024 matches
        # the BAAI/bge-large-en-v1.5 embedding size
        pc.create_index(
            INDEX_NAME,
            dimension=1024,
            metric='dotproduct',
            spec=ServerlessSpec(
                cloud="aws",
                region="us-east-1"
            )
        )
        # Wait for the index to be initialized
        while not pc.describe_index(INDEX_NAME).status['ready']:
            time.sleep(1)
    # Connect to the index
    return pc.Index(INDEX_NAME)
# Load pre-chunked documents (with ids) from disk
docs = load_pickle("ramana_docs_ids.pkl")

# Initialize the embedding model
embedding = initialize_embedding_model(MODEL_NAME)

# Create or load the Pinecone index
index = create_or_load_index()

# Initialize the BM25 (sparse, keyword-based) retriever over the raw texts
bm25_retriever = BM25Retriever.from_texts(
    [item['document'].page_content for item in docs],
    metadatas=[item['document'].metadata for item in docs],
)
bm25_retriever.k = 2
# Build the dense retriever on top of the Pinecone index, using MMR
# (maximal marginal relevance) search for more diverse results
vector_store = PineconeVectorStore(index=index, embedding=embedding)
retriever = vector_store.as_retriever(search_type="mmr")
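
# Note: this script assumes the documents were already upserted into the
# Pinecone index elsewhere. If not, a one-off load could look like the
# following (hypothetical, left commented out):
# vector_store.add_documents([item['document'] for item in docs])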
# Combine the sparse and dense retrievers into a weighted ensemble
# (LangChain reranks their results with Reciprocal Rank Fusion)
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever], weights=[0.2, 0.8]
)
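
# A quick sanity check of the ensemble retriever might look like this
# (hypothetical query, left commented out so the module stays import-safe):
# for doc in ensemble_retriever.get_relevant_documents("Who was Ramana Maharshi?"):
#     print(doc.metadata, doc.page_content[:80])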
# Wrap the ensemble retriever as an agent tool
vector_tools = create_retriever_tool(
    retriever=ensemble_retriever,
    name="vector_retrieve",
    description="Search and return documents related to the user query from the vector index.",
)
# Pull the standard OpenAI tools agent prompt from the LangChain hub
prompt = hub.pull("hwchase17/openai-tools-agent")

# Read the OpenAI key from Streamlit secrets
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
# Load the LLM
llm_AI4 = ChatOpenAI(model="gpt-4-1106-preview", temperature=0)

# Build the tools agent and its executor
agent = create_openai_tools_agent(llm_AI4, [vector_tools], prompt)
agent_executor = AgentExecutor(agent=agent, tools=[vector_tools])
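
# Example invocation (hypothetical input; the Streamlit app presumably calls
# the executor elsewhere, so this is kept under a __main__ guard):
if __name__ == "__main__":
    response = agent_executor.invoke(
        {"input": "What did Ramana Maharshi teach about self-enquiry?"}
    )
    print(response["output"])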