ComfyKnowledgeGraph / embed_handler.py
jibinmathew's picture
Upload 42 files
eb957df verified
raw
history blame
2.32 kB
# Standard library
import os
from datetime import datetime

# Third-party
from dotenv import load_dotenv
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.llms.openai import OpenAI

# Local
from retrieve import get_latest_dir

# Load environment variables (EMBEDDING_DIR, PROD_SPEC_DIR, OpenAI key) from .env.
load_dotenv()

# Global llama_index settings shared by all index/query operations in this module.
Settings.llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
Settings.chunk_size = 2048
Settings.chunk_overlap = 24
def create_embedding():
    """
    Create an embedding index from the product-spec documents directory.

    Reads documents from the directory named by the PROD_SPEC_DIR environment
    variable, builds a VectorStoreIndex, and persists it under a new
    timestamped subdirectory of EMBEDDING_DIR.

    Returns:
        VectorStoreIndex: The index built from the docs in the directory.

    Raises:
        ValueError: If EMBEDDING_DIR or PROD_SPEC_DIR is not set, instead of
            silently writing to a literal "None/<timestamp>" path.
    """
    output_dir = os.getenv("EMBEDDING_DIR")
    if not output_dir:
        raise ValueError("EMBEDDING_DIR environment variable is not set")
    spec_dir = os.getenv("PROD_SPEC_DIR")
    if not spec_dir:
        raise ValueError("PROD_SPEC_DIR environment variable is not set")

    # Timestamped subdirectory so each run persists a fresh, discoverable snapshot
    # (load_embedding picks the latest one via get_latest_dir).
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    embedding_path = os.path.join(output_dir, timestamp)

    documents = SimpleDirectoryReader(spec_dir).load_data()
    index = VectorStoreIndex.from_documents(documents, show_progress=True)
    index.storage_context.persist(persist_dir=embedding_path)
    return index
def load_embedding():
    """
    Load the most recently persisted embedding index.

    Locates the latest timestamped subdirectory of EMBEDDING_DIR (as written
    by create_embedding) and rehydrates the index from its storage context.

    Returns:
        VectorStoreIndex: The index loaded from the latest persist directory.

    Raises:
        ValueError: If the EMBEDDING_DIR environment variable is not set,
            instead of passing None into get_latest_dir.
    """
    embedding_dir = os.getenv("EMBEDDING_DIR")
    if not embedding_dir:
        raise ValueError("EMBEDDING_DIR environment variable is not set")

    persist_dir = get_latest_dir(embedding_dir)
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    return load_index_from_storage(storage_context)
def query_rag_qa(rag_index, query, search_level):
    """
    Run a query against the RAG index and collect supporting evidence.

    Args:
        rag_index (VectorStoreIndex): The RAG model index.
        query (str): The query to ask the RAG model.
        search_level (int): The max search level to use for the RAG model.

    Returns:
        tuple: The response, nodes, and reference text from the RAG model.
    """
    # Build a retriever capped at search_level hits, then a query engine
    # that routes through the same retriever.
    retriever = rag_index.as_retriever(include_text=True, similarity_top_k=search_level)
    engine = rag_index.as_query_engine(
        sub_retrievers=[retriever],
        include_text=True,
        similarity_top_k=search_level,
    )

    answer = engine.query(query)
    evidence_nodes = retriever.retrieve(query)
    evidence_text = [node.text for node in evidence_nodes]
    return answer, evidence_nodes, evidence_text