Spaces:
Runtime error
Runtime error
File size: 4,783 Bytes
e2e8616 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import os
import chromadb
from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block
from src.model.doc import Doc
from chromadb.utils import embedding_functions
import gradio as gr
class Chatbot:
    """Retrieval-augmented chatbot answering questions about an uploaded document.

    Combines an LLM agent (``detect_language_v2`` / ``translate_v2`` /
    ``generate_paragraph_v2``), a retriever (``similarity_search``) and a Chroma
    client holding one collection per uploaded document.
    """

    # Labels the LLM may prefix its answer with. They are removed as whole
    # leading prefixes; ``str.strip`` would treat them as character sets and
    # eat letters from both ends of the answer.
    _ANSWER_LABELS = ("bot:", "Answer:", "Réponse:")
    # Character budget shared by the retrieved context and the conversation history.
    _MAX_PROMPT_CHARS = 15000

    def __init__(self, llm_agent: "LlmAgent" = None, retriever: "Retriever" = None, client_db=None):
        self.retriever = retriever
        self.llm = llm_agent
        self.client_db = client_db

    def get_response(self, query, histo):
        """Answer ``query`` using retrieved document context and recent history.

        :param query: the user's current question.
        :param histo: chat history as a sequence of (query, answer) pairs; only
            the last five turns are used.
        :return: ``(answer, block_sources)``. ``block_sources`` holds every block
            kept by ``_select_best_sources``, including any dropped from the
            prompt to respect the size budget.
        """
        histo_conversation, histo_queries = self._get_histo(histo)
        language_of_query = self._normalize_language(self.llm.detect_language_v2(query))
        # NOTE(review): the search text is built from the history's queries only;
        # whether ``histo`` already contains the current ``query`` depends on the
        # caller - confirm.
        queries = self.llm.translate_v2(histo_queries)
        block_sources = self._select_best_sources(self.retriever.similarity_search(queries=queries))
        sources_contents = [self._format_source(s) for s in block_sources]
        context = self._fit_context(sources_contents, len(histo_conversation))
        answer = self.llm.generate_paragraph_v2(
            query=query, histo=histo_conversation, context=context, language=language_of_query)
        return self._clean_chatgpt_answer(answer), block_sources

    @staticmethod
    def _normalize_language(detected: str) -> str:
        """Map the language detector's output to the "en"/"fr" code used for generation.

        French is tested first because a label such as "french" also contains
        the substring "en". Anything that is neither English nor French falls
        back to "fr".
        """
        lang = detected.lower()
        if "fr" in lang:
            return "fr"
        if "en" in lang:
            return "en"
        return "fr"

    @staticmethod
    def _format_source(block) -> str:
        """Render one retrieved block as a headed paragraph for the prompt.

        The block's title is used as the heading when it is truthy; otherwise
        its index is used.
        """
        if block.title:
            return f"Paragraph title : {block.title}\n-----\n{block.content}"
        return f"Paragraph {block.index}\n-----\n{block.content}"

    @classmethod
    def _fit_context(cls, sources_contents, histo_len: int) -> str:
        """Join sources with newlines, dropping entries from the end until the prompt fits.

        The budget is ``_MAX_PROMPT_CHARS`` characters for context plus history.
        At least one source is always kept, even if it alone exceeds the budget.
        """
        kept = len(sources_contents)
        context = "\n".join(sources_contents)
        while len(context) + histo_len > cls._MAX_PROMPT_CHARS and kept > 1:
            kept -= 1
            context = "\n".join(sources_contents[:kept])
        return context

    @staticmethod
    def _select_best_sources(sources: "list[Block]", delta_1_2=0.15, delta_1_n=0.3, absolute=1.2, alpha=0.9) -> "list[Block]":
        """Keep the leading run of sources that are close enough to be useful.

        The first source is always kept (``sources`` is presumably sorted by
        increasing ``distance`` - confirm with the retriever). Each following
        source is kept if it is within ``delta_1_2`` of the previous one and
        within ``delta_1_n`` of the first, or if its distance is below
        ``absolute``. After each kept source, all three thresholds are
        multiplied by ``alpha``. Selection stops at the first rejected source.
        """
        best_sources = []
        for idx, s in enumerate(sources):
            if idx == 0 \
                    or (s.distance - sources[idx - 1].distance < delta_1_2
                        and s.distance - sources[0].distance < delta_1_n) \
                    or s.distance < absolute:
                best_sources.append(s)
                delta_1_2 *= alpha
                delta_1_n *= alpha
                absolute *= alpha
            else:
                break
        return best_sources

    @staticmethod
    def _get_histo(histo: "list[tuple[str, str]]") -> "tuple[str, str]":
        """Flatten the last five turns of ``histo``.

        :return: ``(conversation, queries)``. ``conversation`` is a transcript of
            "user: ..." / "bot: ..." lines without a trailing newline. ``queries``
            is every user query, each followed by a newline.
        """
        recent = histo[-5:]
        conversation = "".join(f"user: {query} \n bot: {answer}\n" for query, answer in recent)
        queries = "".join(f"{query}\n" for query, _ in recent)
        return conversation[:-1], queries

    @staticmethod
    def _strip_leading_labels(text: str, labels) -> str:
        """Remove leading whitespace and any leading ``labels`` (case-insensitive), repeatedly."""
        text = text.lstrip()
        stripped = True
        while stripped:
            stripped = False
            for label in labels:
                if text[:len(label)].lower() == label.lower():
                    text = text[len(label):].lstrip()
                    stripped = True
        return text

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """Remove a leading "bot:" label and surrounding quotes/backticks/spaces, and end with a period."""
        answer = Chatbot._strip_leading_labels(answer, ("bot:",))
        answer = answer.strip("'\" `")
        # A label may only become visible once surrounding quotes are removed.
        answer = Chatbot._strip_leading_labels(answer, ("bot:",))
        if answer and not answer.endswith("."):
            answer += "."
        return answer

    def _clean_chatgpt_answer(self, answer: str) -> str:
        """Remove leading role labels ("bot:", "Answer:", "Réponse:") and trailing quotes/backticks/spaces."""
        answer = self._strip_leading_labels(answer, self._ANSWER_LABELS)
        return answer.rstrip("'\" `")

    def upload_doc(self, input_doc, include_images_, actual_page_start):
        """Index an uploaded .docx/.pdf/.html file and point the retriever at it.

        A Chroma collection named after the document title is fetched or
        created with OpenAI "text-embedding-ada-002" embeddings. The document is
        parsed only when that collection is empty; otherwise the existing
        entries are reused.

        :param input_doc: uploaded file object; its ``name`` is passed to ``Doc`` as the path.
        :param include_images_: forwarded to ``Doc`` as ``include_images``.
        :param actual_page_start: forwarded to ``Doc`` as ``actual_first_page``.
        :return: False for an unsupported extension, True otherwise.
        :raises KeyError: if ``OPENAI_API_KEY`` is not set in the environment.
        """
        title = Doc.get_title(Doc, input_doc.name)
        extension = title.split('.')[-1]
        if extension not in ("docx", "pdf", "html"):
            return False
        open_ai_embedding = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.environ['OPENAI_API_KEY'], model_name="text-embedding-ada-002")
        # Presumably because Chroma restricts collection names: map non-alphanumerics to "_".
        coll_name = "".join(c if c.isalnum() else "_" for c in title)
        collection = self.client_db.get_or_create_collection(name=coll_name, embedding_function=open_ai_embedding)
        if collection.count() == 0:
            gr.Info("Please wait while your document is being analysed")
            print("Database is empty")
            doc = Doc(path=input_doc.name, include_images=include_images_, actual_first_page=actual_page_start)
            self.retriever = Retriever(doc.container, collection=collection, llmagent=self.llm)
        else:
            print("Database is not empty")
            self.retriever = Retriever(collection=collection, llmagent=self.llm)
        return True