File size: 2,137 Bytes
e2e8616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from src.model.block import Block
from src.model.doc import Doc
from src.tools.llm import LlmAgent
import gradio as gr

class Retriever:
    """Index the blocks of a document into a collection and query it.

    When a ``doc`` is supplied, every block of the document is summarised by
    the LLM agent and the summary is stored in ``collection`` together with
    the block's metadata. When no ``doc`` is supplied, the retriever simply
    wraps an already-populated ``collection`` for querying.

    NOTE(review): ``collection`` is used through ``add(documents=, ids=,
    metadatas=)``, ``query(query_texts=, n_results=)`` and ``.name`` --
    presumably a ChromaDB collection; confirm against the caller.
    """

    # Blocks whose ``len(content)`` exceeds this value are split into smaller
    # blocks (via ``Block.separate_1_block_in_n``) before being summarised.
    MAX_BLOCK_SIZE = 4500

    # Marker that may precede the actual summary in the LLM output.
    SUMMARY_TAG = "<summary>"

    def __init__(self, doc: "Doc | None" = None, collection=None, llmagent: "LlmAgent | None" = None):
        """Wrap ``collection`` and, if ``doc`` is given, fill it from ``doc``.

        Args:
            doc: Document whose ``blocks`` are summarised and stored. When
                ``None`` nothing is indexed.
            collection: Target collection; stored as ``self.collection``.
            llmagent: Agent exposing ``summarize_paragraph_v2``; only used
                when ``doc`` is given.
        """
        self.collection = collection
        # Identity check: ``!= None`` would dispatch to a user-defined
        # ``__ne__`` on the document object.
        if doc is None:
            return

        blocks = doc.blocks
        gr.Info("Please wait while the database is being created")
        for block in blocks:
            if len(block.content) > self.MAX_BLOCK_SIZE:
                pieces = block.separate_1_block_in_n(max_size=self.MAX_BLOCK_SIZE)
            else:
                pieces = [block]
            for piece in pieces:
                # Sub-blocks are summarised under the title of the block they
                # were split from.
                self._index_block(
                    piece,
                    doc_title=doc.title,
                    para_title=block.title,
                    llmagent=llmagent,
                )
        gr.Info(f"The collection {collection.name} has been added to the database")

    @classmethod
    def _extract_summary(cls, raw_summary: str) -> str:
        """Return the text following the first summary tag, if one is present.

        The text is cut at the next occurrence of the tag, if any. Output
        without the tag is returned unchanged.
        """
        if cls.SUMMARY_TAG in raw_summary:
            return raw_summary.split(cls.SUMMARY_TAG)[1]
        return raw_summary

    def _index_block(self, block, doc_title, para_title, llmagent) -> None:
        """Summarise one block and add the summary to the collection."""
        raw_summary = llmagent.summarize_paragraph_v2(
            prompt=block.content,
            title_doc=doc_title,
            title_para=para_title,
        )
        self.collection.add(
            documents=[self._extract_summary(raw_summary)],
            ids=[block.index],
            metadatas=[block.to_dict()],
        )

    def similarity_search(self, queries: str, n_results: int = 6) -> "list[Block]":
        """Query the collection and return the matching blocks.

        Args:
            queries: Query text forwarded as ``query_texts``.
            n_results: Number of results requested from the collection
                (defaults to 6).

        Returns:
            A list of ``Block`` objects rebuilt from the stored metadata, each
            with a ``distance`` attribute set from the query result. Only the
            first entry of the result lists is read.
        """
        res = self.collection.query(query_texts=queries, n_results=n_results)
        block_dict_sources = res['metadatas'][0]
        distances = res['distances'][0]
        blocks = []
        for block_dict, distance in zip(block_dict_sources, distances):
            # NOTE(review): relies on ``from_dict`` returning the block
            # instance -- confirm in ``Block``.
            found = Block().from_dict(block_dict)
            found.distance = distance
            blocks.append(found)
        return blocks