# NOTE(review): the original file began with non-Python artifact lines
# ("Spaces:", "Sleeping", "Sleeping") captured from a web page; they have been
# converted to comments so the module parses.
import os
import time
import pickle

import streamlit as st
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from utils import load_pickle, initialize_embedding_model
from langchain_community.retrievers import BM25Retriever
from langchain_pinecone import PineconeVectorStore
from langchain.retrievers import EnsembleRetriever
from langchain.tools.retriever import create_retriever_tool

# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Constants
INDEX_NAME = "veda-index-v2"
# Embedding model name; its 1024-dim output must match the index dimension below.
MODEL_NAME = "BAAI/bge-large-en-v1.5"

# Initialize Pinecone client.
# Fail fast with a clear message instead of the TypeError that
# `os.environ[...] = None` would raise when the key is missing.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
if not PINECONE_API_KEY:
    raise RuntimeError(
        "PINECONE_API_KEY is not set; add it to .env or the environment."
    )
# langchain-pinecone reads the key from the environment, so ensure it is exported.
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
pc = Pinecone(api_key=PINECONE_API_KEY)
#@st.cache_resource
def create_or_load_index():
    """Return a handle to the Pinecone index, creating it first if needed.

    The index is serverless (AWS us-east-1), 1024-dimensional with a
    dotproduct metric. After a create request, polls once per second
    until Pinecone reports the index as ready.
    """
    existing = pc.list_indexes().names()
    if INDEX_NAME not in existing:
        # Provision a new serverless index matching the embedding size.
        pc.create_index(
            INDEX_NAME,
            dimension=1024,
            metric='dotproduct',
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
    # Block until the index is ready to accept operations.
    while not pc.describe_index(INDEX_NAME).status['ready']:
        time.sleep(1)
    # Hand back a live connection to the index.
    return pc.Index(INDEX_NAME)
# Load documents | |
# Load pre-chunked documents; each entry carries a 'document' with
# page_content and metadata (assumed from usage below — TODO confirm schema).
docs = load_pickle("ramana_docs_ids.pkl")

# Embedding model for the dense retriever.
embedding = initialize_embedding_model(MODEL_NAME)

# Create the Pinecone index if needed and connect to it.
index = create_or_load_index()

# Sparse keyword retriever (BM25) built over the raw document texts.
bm25_retriever = BM25Retriever.from_texts(
    [entry['document'].page_content for entry in docs],
    metadatas=[entry['document'].metadata for entry in docs],
)
bm25_retriever.k = 2  # top-2 keyword matches

# Dense retriever backed by the Pinecone index, using MMR for result diversity.
vector_store = PineconeVectorStore(index, embedding)
retriever = vector_store.as_retriever(search_type="mmr")

# Hybrid retrieval: blend sparse (0.2) and dense (0.8) scores.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever],
    weights=[0.2, 0.8],
)

# Expose the hybrid retriever as an agent tool.
vector_tools = create_retriever_tool(
    retriever=ensemble_retriever,
    name="vector_retrieve",
    description="Search and return documents related user query from the vector index.",
)
from langchain import hub

# Pull the standard OpenAI tools-agent prompt from the LangChain hub.
prompt = hub.pull("hwchase17/openai-tools-agent")
# NOTE: removed a dead `prompt.messages` expression (notebook leftover with no effect)
# and a duplicate `import streamlit as st` (already imported at the top of the file).

from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_openai import ChatOpenAI

# The OpenAI key is read from Streamlit secrets and exported for langchain-openai.
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

# Deterministic (temperature=0) GPT-4 Turbo as the agent LLM.
llm_AI4 = ChatOpenAI(model="gpt-4-1106-preview", temperature=0)

# Wire the hybrid-retrieval tool into an OpenAI tools agent and its executor.
agent = create_openai_tools_agent(llm_AI4, [vector_tools], prompt)
agent_executor = AgentExecutor(agent=agent, tools=[vector_tools])