Document_QnA/src/tools/retriever.py
Quent1Fvr's picture
v2.
e2e8616
raw
history blame
No virus
2.14 kB
from src.model.block import Block
from src.model.doc import Doc
from src.tools.llm import LlmAgent
import gradio as gr
class Retriever:
    """Retriever over a vector-store collection of document blocks.

    When constructed with a ``Doc``, every block of the document is summarised
    by the given ``LlmAgent`` and the summary is added to ``collection``
    together with the block's metadata (``block.to_dict()``) under the id
    ``block.index``. When constructed without a ``Doc``, the given
    ``collection`` is simply wrapped for querying.
    """

    # Blocks whose content is longer than this many characters are split into
    # several smaller blocks before being summarised and indexed.
    MAX_BLOCK_SIZE = 4500

    def __init__(self, doc: Doc = None, collection=None, llmagent: LlmAgent = None):
        """Create a retriever, optionally populating ``collection`` from ``doc``.

        Args:
            doc: document whose ``blocks`` are summarised and indexed. If
                ``None``, ``collection`` is used as-is (assumed already
                populated).
            collection: collection object exposing ``add``, ``query`` and
                ``name`` (presumably a Chroma collection — TODO confirm).
            llmagent: agent whose ``summarize_paragraph_v2`` produces the text
                stored for each block; needed whenever ``doc`` has blocks.
        """
        self.collection = collection
        if doc is None:
            return

        gr.Info("Please wait while the database is being created")
        for block in doc.blocks:
            if len(block.content) > self.MAX_BLOCK_SIZE:
                # Oversized blocks are indexed piecewise; every piece keeps the
                # parent block's title as its paragraph title.
                pieces = block.separate_1_block_in_n(max_size=self.MAX_BLOCK_SIZE)
            else:
                pieces = [block]
            for piece in pieces:
                self._summarize_and_store(piece, llmagent, doc.title, block.title)
        gr.Info(f"The collection {collection.name} has been added to the database")

    def _summarize_and_store(self, block: Block, llmagent: LlmAgent, title_doc, title_para):
        """Summarise ``block`` with ``llmagent`` and add the summary to the collection."""
        summary = llmagent.summarize_paragraph_v2(
            prompt=block.content, title_doc=title_doc, title_para=title_para
        )
        # The agent may prefix its answer with a <summary> tag: keep the text
        # after the first opening tag (up to any second opening tag).
        # NOTE(review): a closing </summary> tag, if emitted, is not stripped.
        if "<summary>" in summary:
            summary = summary.split("<summary>")[1]
        self.collection.add(
            documents=[summary],
            ids=[block.index],
            metadatas=[block.to_dict()],
        )

    def similarity_search(self, queries: str, n_results: int = 6) -> list:
        """Return the stored blocks closest to ``queries``.

        Args:
            queries: query text passed as ``query_texts`` to ``collection.query``.
            n_results: number of hits requested from the collection (default 6).

        Returns:
            A list of ``Block`` objects rebuilt from the stored metadata, each
            with a ``distance`` attribute taken from the query result. Only the
            result set of the first query is used.
        """
        res = self.collection.query(query_texts=queries, n_results=n_results)
        blocks = []
        for metadata, distance in zip(res['metadatas'][0], res['distances'][0]):
            b = Block().from_dict(metadata)
            b.distance = distance
            blocks.append(b)
        return blocks