# maya-persistence / src/chromaIntf.py
# Commit 9a1d7f1 (anubhav77): "move to gemini-1.5-flash"
import sys

# Chroma needs a newer SQLite than some hosts ship; when the pysqlite3
# package is available, alias it in as the stdlib "sqlite3" module BEFORE
# chromadb is imported below.
try:
    import pysqlite3
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
except ImportError:
    # pysqlite3 not installed -- fall back to the stdlib sqlite3.
    pass
import asyncio
import logging
import os
from datetime import datetime
from uuid import UUID

import chromadb
from langchain.vectorstores import Chroma
# from chromadb.api.fastapi import requests
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.text_splitter import RecursiveCharacterTextSplitter

import baseInfra.dropbox_handler as dbh
from baseInfra.dbInterface import DbInterface
from llm.llmFactory import LLMFactory
# Module-wide logger; bound to the "root"-named logger configured by the app.
logger = logging.getLogger("root")
class myChromaTranslator(ChromaTranslator):
    """ChromaTranslator restricted to the operators Chroma actually accepts.

    Narrows the base translator's vocabulary so the self-query chain cannot
    emit filters Chroma would reject.
    """

    # Subset of allowed logical operators.
    allowed_operators = ["$and", "$or"]
    # Comparators understood by Chroma's metadata filter syntax.
    allowed_comparators = [
        "$eq",
        "$ne",
        "$gt",
        "$gte",
        "$lt",
        "$lte",
        "$contains",
        "$not_contains",
        "$in",
        "$nin",
    ]
class ChromaIntf:
    """Chroma vector-store wrapper with metadata-aware self-query retrieval.

    On construction this restores any previously persisted store from the
    Dropbox handler, builds BGE embeddings (CPU), seeds the collection with
    a placeholder document so it always exists, and wires up a
    SelfQueryRetriever driven by an LLM obtained from LLMFactory.
    """

    def __init__(self):
        self.db_interface = DbInterface()

        model_name = "BAAI/bge-large-en-v1.5"
        # normalize_embeddings=True -> unit-length vectors, so dot product
        # equals cosine similarity.
        encode_kwargs = {"normalize_embeddings": True}
        self.embedding = HuggingFaceBgeEmbeddings(
            model_name=model_name,
            model_kwargs={"device": "cpu"},
            encode_kwargs=encode_kwargs,
        )
        self.persist_db_directory = "db"
        self.persist_docs_directory = "persistence-docs"
        self.logger_file = "persistence.log"

        # Best-effort restore of previously persisted state; a brand-new
        # deployment has nothing to restore, which is expected.
        loop = asyncio.get_event_loop()
        try:
            loop.run_until_complete(dbh.restoreFolder(self.persist_db_directory))
            loop.run_until_complete(dbh.restoreFolder(self.persist_docs_directory))
        except Exception:
            print("Probably folder doesn't exist as it is brand new setup")

        # Seed document so the collection exists before any real data is added.
        docs = [
            Document(
                page_content="this is test doc",
                metadata={
                    "timestamp": 1696743148.474055,
                    "ID": "2000-01-01 15:57:11::664165-test",
                    "source": "test",
                },
                id="2000-01-01 15:57:11::664165-test",
            ),
        ]
        self.vectorstore = Chroma.from_documents(
            documents=docs,
            embedding=self.embedding,
            persist_directory=self.persist_db_directory,
        )

        # Metadata schema the self-query LLM may filter on:
        #   timestamp --> time when added
        #   source    --> notes/references/web/youtube/book/conversation
        #   title     --> of document; "conversation" when source is conversation
        #   author    --> defaults to blank
        self.metadata_field_info = [
            AttributeInfo(
                name="timestamp",
                description="Python datetime.timestamp of the document in isoformat, should not be used for query",
                type="str",
            ),
            AttributeInfo(
                name="Year",
                description="Year from the date when the entry was added in YYYY format",
                type="int",
            ),
            AttributeInfo(
                name="Month",
                description="Month from the date when the entry was added it is from 1-12",
                type="int",
            ),
            AttributeInfo(
                name="Day",
                description="Day of month from the date-time stamp when the entry was added, it is from 1-31",
                type="int",
            ),
            AttributeInfo(
                name="Hour",
                description="Hour from the timestamp when the entry was added",
                type="int",
            ),
            AttributeInfo(
                name="Minute",
                description="Minute from the timestamp when the entry was added",
                type="int",
            ),
            AttributeInfo(
                name="source",
                description="Type of entry",
                type="string or list[string]",
            ),
            AttributeInfo(
                name="title",
                description="Title or Subject of the entry",
                type="string",
            ),
            AttributeInfo(
                name="author",
                description="Author of the entry",
                type="string",
            ),
        ]
        self.document_content_description = (
            "Information to store for retrival from LLM based chatbot"
        )
        lf = LLMFactory()
        # self.llm=lf.get_llm("executor2")
        self.llm = lf.get_llm("executor3")
        self.retriever = SelfQueryRetriever.from_llm(
            self.llm,
            self.vectorstore,
            self.document_content_description,
            self.metadata_field_info,
            structured_query_translator=ChromaTranslator(),
            verbose=True,
        )

    async def getRelevantDocs(self, query: str, kwargs: dict):
        """Retrieve documents relevant to *query* and cache the flattened text.

        Keys in *kwargs* containing "search_type" set the retriever's search
        type; all other keys are passed through as search kwargs. Returns the
        raw retriever result (possibly empty on failure).
        """
        print("retriver state", self.retriever.search_kwargs)
        print("retriver state", self.retriever.search_type)
        try:
            for key in kwargs.keys():
                if "search_type" in key:
                    self.retriever.search_type = kwargs[key]
                else:
                    self.retriever.search_kwargs[key] = kwargs[key]
        except Exception:
            print("setting search args failed")

        # BUGFIX: retVal was unbound when retrieval raised, causing a
        # NameError below; default to an empty result instead.
        retVal = []
        try:
            # TODO(review): consider aget_relevant_documents since we are in
            # an async method.
            retVal = self.retriever.get_relevant_documents(query)
        except Exception:
            logger.exception("Exception occured:", exc_info=True)

        value = []
        excludeMeta = True
        print(str(len(retVal)))
        try:
            for item in retVal:
                if excludeMeta:
                    v = item.page_content + " \n"
                else:
                    v = "Info:" + item.page_content + " "
                    for key in item.metadata.keys():
                        if key != "ID":
                            v += key + ":" + str(item.metadata[key]) + " "
                value.append(v)
            self.db_interface.add_to_cache(input=query, value=value)
        except Exception:
            # Fallback: some configurations yield plain dicts instead of
            # Document objects; retry with key access.
            # BUGFIX: rebuild value from scratch so entries flattened before
            # the exception are not duplicated.
            value = []
            for item in retVal:
                if excludeMeta:
                    v = item["page_content"] + " \n"
                else:
                    v = "Info:" + item["page_content"] + " "
                    for key in item["metadata"].keys():
                        if key != "ID":
                            v += key + ":" + str(item["metadata"][key]) + " "
                value.append(v)
            self.db_interface.add_to_cache(input=query, value=value)
        return retVal

    async def addText(self, inStr: str, metadata):
        """Split *inStr* into chunks and add them to the vector store.

        *metadata* is a pydantic model that may carry timestamp/source/title/
        author; missing fields get defaults. Each chunk after the first gets a
        "__<n>" suffix on its ID. Returns the ids reported by the store.
        """
        # TODO: Preprocess inStr to remove any html, markdown tags etc.
        metadata = metadata.dict()
        if "timestamp" not in metadata:
            # BUGFIX: keep a datetime object here (was .isoformat(), a str,
            # which broke the .strftime()/.year accesses below).
            metadata["timestamp"] = datetime.now()
        else:
            metadata["timestamp"] = datetime.fromisoformat(metadata["timestamp"])
        if "source" not in metadata:
            metadata["source"] = "conversation"
        if "title" not in metadata:
            metadata["title"] = ""
        if metadata["source"] == "conversation":
            # BUGFIX: was '==', a no-op comparison instead of an assignment.
            metadata["title"] = "conversation"
        if "author" not in metadata:
            metadata["author"] = ""
        # TODO: If url is present in input or when the splitting need to be
        # done, then we'll need to change how we formulate the ID and may be
        # filename to store information.
        metadata["ID"] = (
            metadata["timestamp"].strftime("%Y-%m-%d %H-%M-%S")
            + "-"
            + metadata["title"]
        )
        metadata["Year"] = metadata["timestamp"].year
        metadata["Month"] = metadata["timestamp"].month
        metadata["Day"] = int(metadata["timestamp"].strftime("%d"))
        metadata["Hour"] = metadata["timestamp"].hour
        metadata["Minute"] = metadata["timestamp"].minute
        metadata["timestamp"] = metadata["timestamp"].isoformat()
        print("Metadata is:")
        print(metadata)

        # Keep a plain-text copy of the full entry on disk; ensure the target
        # directory exists (it may not on a fresh deployment).
        os.makedirs("./docs", exist_ok=True)
        with open("./docs/" + metadata["ID"] + ".txt", "w") as fd:
            fd.write(inStr)
        print("written to file", inStr)

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=50,
            length_function=len,
            is_separator_regex=False,
        )
        docs = text_splitter.create_documents([inStr], [metadata])
        # Disambiguate chunk IDs: the first chunk keeps the base ID, later
        # chunks get a numeric suffix.
        for partNumber, doc in enumerate(docs):
            if partNumber > 0:
                doc.metadata["ID"] += f"__{partNumber}"
            print(f"{partNumber + 1} follows:")
            print(doc)
        try:
            ids = [doc.metadata["ID"] for doc in docs]
            print("ids are:")
            print(ids)
            return await self.vectorstore.aadd_documents(docs, ids=ids)
        except Exception:
            logger.exception("exception in adding", exc_info=True)
            # BUGFIX: metadata is a dict here, so metadata.ID raised
            # AttributeError; use key access for the single-ID fallback.
            return await self.vectorstore.aadd_documents(docs, ids=[metadata["ID"]])

    async def listDocs(self):
        """Return the raw contents of the default langchain collection."""
        collection = self.vectorstore._client.get_collection(
            self.vectorstore._LANGCHAIN_DEFAULT_COLLECTION_NAME,
            embedding_function=self.embedding,
        )
        return collection.get()

    async def persist(self):
        """Flush the vector store to disk and back everything up to Dropbox."""
        self.vectorstore.persist()
        await dbh.backupFile(self.logger_file)
        await dbh.backupFolder(self.persist_db_directory)
        return await dbh.backupFolder(self.persist_docs_directory)

    def _uuid(self, uuid_str: str) -> UUID:
        """Parse *uuid_str* into a UUID; raise ValueError if malformed."""
        try:
            return UUID(uuid_str)
        except ValueError:
            print("Error generating uuid")
            raise ValueError(f"Could not parse {uuid_str} as a UUID")