Document_QnA / src /control /control.py
Quent1Fvr's picture
v2.
e2e8616
raw
history blame contribute delete
No virus
4.78 kB
import os
import chromadb
from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block
from src.model.doc import Doc
from chromadb.utils import embedding_functions
import gradio as gr
class Chatbot:
    """Retrieval-augmented chatbot over an indexed document.

    Combines an LLM agent (language detection, translation, answer
    generation) with a similarity-search retriever backed by a ChromaDB
    collection of document blocks.
    """

    def __init__(self, llm_agent: "LlmAgent" = None, retriever: "Retriever" = None, client_db=None):
        self.retriever = retriever   # similarity search over the indexed document blocks
        self.llm = llm_agent         # LLM agent used for detection/translation/generation
        self.client_db = client_db   # ChromaDB client holding per-document collections

    def get_response(self, query, histo):
        """Answer *query* using retrieved context plus the conversation history.

        :param query: the user's current question.
        :param histo: list of (query, answer) pairs from previous turns.
        :return: (answer, block_sources) — the cleaned answer and the blocks used as context.
        """
        histo_conversation, histo_queries = self._get_histo(histo)
        language_of_query = self.llm.detect_language_v2(query).lower()
        queries = self.llm.translate_v2(histo_queries)
        # Only English and French are supported; anything non-English falls back to French.
        language_of_query = "en" if "en" in language_of_query else "fr"
        block_sources = self.retriever.similarity_search(queries=queries)
        block_sources = self._select_best_sources(block_sources)
        sources_contents = [
            f"Paragraph title : {s.title}\n-----\n{s.content}" if s.title
            else f"Paragraph {s.index}\n-----\n{s.content}"
            for s in block_sources
        ]
        context = '\n'.join(sources_contents)
        # Drop the least relevant sources until context + history fits the prompt budget.
        i = 1
        while (len(context) + len(histo_conversation) > 15000) and i < len(sources_contents):
            context = "\n".join(sources_contents[:-i])
            i += 1
        answer = self.llm.generate_paragraph_v2(query=query, histo=histo_conversation,
                                                context=context, language=language_of_query)
        answer = self._clean_chatgpt_answer(answer)
        return answer, block_sources

    @staticmethod
    def _select_best_sources(sources: list["Block"], delta_1_2=0.15, delta_1_n=0.3,
                             absolute=1.2, alpha=0.9) -> list["Block"]:
        """Select the best sources: not far from the very best, not far from the
        last selected, and not too bad per se.

        Sources are assumed sorted by ascending distance; selection stops at the
        first source that fails every criterion.
        """
        best_sources = []
        for idx, s in enumerate(sources):
            keep = (
                idx == 0
                or (s.distance - sources[idx - 1].distance < delta_1_2
                    and s.distance - sources[0].distance < delta_1_n)
                or s.distance < absolute
            )
            if not keep:
                break
            best_sources.append(s)
            # Tighten every threshold geometrically so each further source must be better.
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    @staticmethod
    def _get_histo(histo: list[tuple[str, str]]) -> tuple[str, str]:
        """Format the last 5 turns as a conversation transcript and a newline-joined
        list of past queries. Returns (conversation, queries)."""
        histo_conversation = ""
        histo_queries = ""
        for (query, answer) in histo[-5:]:
            histo_conversation += f'user: {query} \n bot: {answer}\n'
            histo_queries += query + '\n'
        # Trim the transcript's trailing newline; keep it on the query list.
        return histo_conversation[:-1], histo_queries

    @staticmethod
    def _strip_prefix(text: str, prefix: str) -> str:
        """Remove the literal *prefix* from the start of *text* if present
        (equivalent to str.removeprefix, kept explicit for older Pythons)."""
        return text[len(prefix):] if text.startswith(prefix) else text

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """Drop a leading 'bot:' tag and surrounding quote/space/backtick
        characters, and ensure the answer ends with a period."""
        # BUG FIX: the original used answer.strip('bot:'), which removes any of
        # the characters {'b','o','t',':'} from BOTH ends (e.g. it eats the
        # final 'o' of "hello"); the intent is removal of the literal prefix.
        answer = Chatbot._strip_prefix(answer, 'bot:')
        while answer and answer[-1] in {"'", '"', " ", "`"}:
            answer = answer[:-1]
        while answer and answer[0] in {"'", '"', " ", "`"}:
            answer = answer[1:]
        answer = Chatbot._strip_prefix(answer, 'bot:')
        if answer and answer[-1] != ".":
            answer += "."
        return answer

    def _clean_chatgpt_answer(self, answer: str) -> str:
        """Drop leading answer tags ('bot:', 'Answer:', 'Réponse:') and trailing
        quote/space/backtick characters from a model answer."""
        # BUG FIX: strip('Answer:') removed any of those characters from both
        # ends rather than the literal tag; use literal prefix removal instead.
        for prefix in ('bot:', 'Answer:', 'Réponse:'):
            answer = self._strip_prefix(answer, prefix)
        while answer and answer[-1] in {"'", '"', " ", "`"}:
            answer = answer[:-1]
        return answer

    def upload_doc(self, input_doc, include_images_, actual_page_start):
        """Index *input_doc* into a ChromaDB collection and wire up the retriever.

        :param input_doc: uploaded file object (must expose .name, the file path).
        :param include_images_: forwarded to Doc when the document is first parsed.
        :param actual_page_start: first real page number, forwarded to Doc.
        :return: True when the document was handled, False for an unsupported extension.
        """
        # NOTE(review): get_title looks like it is called unbound with the class
        # passed as 'self' — confirm against Doc's API.
        title = Doc.get_title(Doc, input_doc.name)
        extension = title.split('.')[-1]
        if extension not in ('docx', 'pdf', 'html'):
            return False
        open_ai_embedding = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.environ['OPENAI_API_KEY'],
            model_name="text-embedding-ada-002",
        )
        # Chroma collection names are restricted; replace anything non-alphanumeric.
        coll_name = "".join(c if c.isalnum() else "_" for c in title)
        collection = self.client_db.get_or_create_collection(
            name=coll_name, embedding_function=open_ai_embedding)
        if collection.count() == 0:
            gr.Info("Please wait while your document is being analysed")
            print("Database is empty")
            # First upload of this document: parse it and index its blocks.
            doc = Doc(path=input_doc.name, include_images=include_images_,
                      actual_first_page=actual_page_start)
            retriever = Retriever(doc.container, collection=collection, llmagent=self.llm)
        else:
            print("Database is not empty")
            # Document already indexed: reuse the existing collection.
            retriever = Retriever(collection=collection, llmagent=self.llm)
        self.retriever = retriever
        return True