# audit_assistant/auditqa/retriever.py
# Author: ppsingh
# Commit: "Create retriever.py" (71aaf00, verified)
# File size: 2.48 kB
from qdrant_client.http import models as rest
from auditqa.process_chunks import getconfig
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import logging
# Settings loaded once at import time from "model_params.cfg". Code in this
# module reads them with model_config.get(section, option), e.g. the
# 'retriever'/'TOP_K', 'ranker'/'MODEL' and 'ranker'/'TOP_K' entries.
model_config = getconfig("model_params.cfg")
def create_filter(reports: list = None, sources: str = None,
                  subtype: str = None, year: list = None):
    """Build the Qdrant metadata filter used to restrict retrieval.

    Two mutually exclusive modes are supported:

    * ``reports`` is non-empty: only points whose ``metadata.filename`` is one
      of ``reports`` match; ``sources``, ``subtype`` and ``year`` are ignored.
    * ``reports`` is empty or ``None``: points must match ``sources`` on
      ``metadata.source`` AND ``subtype`` on ``metadata.subtype`` AND any of
      ``year`` on ``metadata.year``.

    Args:
        reports: Filenames to restrict the search to. Defaults to ``None``
            (treated like an empty list) rather than a shared mutable ``[]``.
        sources: Value compared for equality against ``metadata.source``.
        subtype: Value compared for equality against ``metadata.subtype``.
        year: Collection of accepted ``metadata.year`` values. A single string
            is also accepted and is treated as a one-element list.

    Returns:
        A ``qdrant_client.http.models.Filter`` with the conditions in ``must``.
    """
    if not reports:
        print("defining filter for:{}:{}:{}".format(sources,subtype,year))
        # MatchAny is documented to take a list of candidate values; wrap a
        # bare string so it is matched as one value instead of being rejected.
        accepted_years = [year] if isinstance(year, str) else year
        return rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.source",
                    match=rest.MatchValue(value=sources),
                ),
                rest.FieldCondition(
                    key="metadata.subtype",
                    match=rest.MatchValue(value=subtype),
                ),
                rest.FieldCondition(
                    key="metadata.year",
                    match=rest.MatchAny(any=accepted_years),
                ),
            ]
        )

    print("defining filter for allreports:",reports)
    return rest.Filter(
        must=[
            rest.FieldCondition(
                key="metadata.filename",
                match=rest.MatchAny(any=reports),
            )
        ]
    )
# Cross-encoder instances keyed by model name, so the reranking model is
# constructed once per process instead of once per query.
_CROSS_ENCODERS = {}


def _get_cross_encoder(model_name):
    """Return the cached HuggingFaceCrossEncoder for ``model_name``, creating it on first use."""
    if model_name not in _CROSS_ENCODERS:
        _CROSS_ENCODERS[model_name] = HuggingFaceCrossEncoder(model_name=model_name)
    return _CROSS_ENCODERS[model_name]


def get_context(vectorstore,query,reports,sources,subtype,year):
    """Retrieve and re-rank the context paragraphs for ``query``.

    Args:
        vectorstore: Object exposing ``as_retriever(search_type=..., search_kwargs=...)``
            (presumably a LangChain vector store backed by Qdrant, since the
            filter passed to it is a Qdrant filter -- confirm against caller).
        query: The query handed to the retriever's ``invoke``.
        reports: Filenames to restrict the search to; see ``create_filter``.
        sources: ``metadata.source`` value used when ``reports`` is empty.
        subtype: ``metadata.subtype`` value used when ``reports`` is empty.
        year: Accepted ``metadata.year`` values used when ``reports`` is empty.

    Returns:
        Whatever the compression retriever's ``invoke`` returns: the retrieved
        results after cross-encoder re-ranking (a sized collection).
    """
    # create metadata filter
    metadata_filter = create_filter(reports=reports,sources=sources,subtype=subtype,year=year)

    # getting context: similarity search that drops hits scoring below 0.6 and
    # returns at most the configured retriever TOP_K candidates
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={
            "score_threshold": 0.6,
            "k": int(model_config.get('retriever','TOP_K')),
            "filter": metadata_filter,
        },
    )

    # re-ranking the retrieved results; the cross-encoder is reused across
    # calls because constructing it for every query is expensive
    model = _get_cross_encoder(model_config.get('ranker','MODEL'))
    compressor = CrossEncoderReranker(model=model, top_n=int(model_config.get('ranker','TOP_K')))
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    context_retrieved = compression_retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")
    return context_retrieved