|
from .base_chatbot import BaseChatbot |
|
from ..memory import BaseMemory, ChatMemory |
|
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever |
|
from ..refiner import BaseRefiner, SimpleRefiner |
|
from models import BaseModel, GPT4Model |
|
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt |
|
import ast |
|
from utils.image_encoder import encode_image |
|
import asyncio |
|
import time |
|
|
|
class RetrievalChatbot(BaseChatbot):
    """Chatbot that decomposes a query, answers each part from retrieved text, then summarizes.

    A query is handled in three stages:

    1. ``decomposer`` turns the user message into a list of sub-questions.
    2. Each sub-question is answered concurrently by ``answerer`` using
       references fetched from ``retriever``.
    3. ``summarizer`` produces the final answer from the sub-answers plus
       references retrieved for the original message.

    The finished exchange (user message, optional images, answer) is appended
    to ``memory``.
    """

    def __init__(self,
                 model: BaseModel = None,
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
                 ) -> None:
        """Store the collaborators, building a default for each one left as ``None``.

        Args:
            model: Model stored on the instance; defaults to ``GPT4Model()``.
            memory: Conversation memory; defaults to a ``ChatMemory`` created
                with ``SummaryPrompt.content`` as its system prompt.
            retriever: Reference retriever; defaults to a ``ChromaRetriever``
                over the ``papers_all`` directory.
            decomposer: Refiner that splits a query into sub-questions.
            answerer: Refiner that answers a single sub-question.
            summarizer: Refiner that writes the final answer.
        """
        # Compare against None rather than relying on truthiness: a supplied
        # collaborator that happens to be falsy (for instance an empty memory
        # whose class defines __len__) must not be silently replaced.
        # Defaults are still only constructed when they are actually needed.
        self.model = model if model is not None else GPT4Model()
        self.memory = memory if memory is not None \
            else ChatMemory(sys_prompt=SummaryPrompt.content)
        self.retriever = retriever if retriever is not None \
            else ChromaRetriever(pdf_dir="papers_all",
                                 collection_name="pdfs",
                                 split_args={"size": 2048, "overlap": 10},
                                 embed_model=GPT4Model())
        self.decomposer = decomposer if decomposer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
        self.answerer = answerer if answerer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
        self.summarizer = summarizer if summarizer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)

    @staticmethod
    def _parse_sub_questions(raw: str) -> list:
        """Turn the decomposer's textual output into a list of sub-question strings.

        The output is first parsed as a Python literal. If that yields a list
        or tuple of strings it is used as-is, so quotes and brackets inside a
        question survive intact. Anything else falls back to the original
        string-splitting heuristic. Blank entries are dropped so that no empty
        query is sent to the retriever.
        """
        try:
            parsed = ast.literal_eval(raw.strip())
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None

        if isinstance(parsed, (list, tuple)) and all(isinstance(item, str) for item in parsed):
            candidates = list(parsed)
        else:
            # Legacy heuristic, kept for output that is not a valid literal.
            # NOTE: lstrip/rstrip remove *character sets*, so this path can
            # also eat a legitimate leading/trailing quote or bracket.
            flattened = raw.replace('"', "'").replace("\n", "").replace("', '", "','")
            candidates = flattened.lstrip("['").rstrip("']").split("','")

        return [candidate.strip() for candidate in candidates if candidate.strip()]

    async def response(self, message: str, image_paths=None, return_logs=False) -> str:
        """Answer ``message`` and record the exchange in memory.

        Args:
            message: The user's query.
            image_paths: Optional image or list of images that accompany the
                query. Each item's ``.name`` attribute is passed to
                ``encode_image`` (presumably uploaded file objects -- confirm
                against the caller).
            return_logs: When true, also return the reference text that was
                given to the summarizer.

        Returns:
            The final answer, or an ``(answer, references)`` tuple when
            ``return_logs`` is true.
        """
        started_at = time.time()
        print(f"Query: {message}")

        # NOTE(review): refine() and retrieve() are synchronous calls made
        # from a coroutine; if they perform network I/O they will hold the
        # event loop while they run.
        raw_decomposition = self.decomposer.refine(message, None, image_paths)
        print(raw_decomposition)
        sub_questions = self._parse_sub_questions(raw_decomposition)
        print(f"Decomposed your query into subquestions: {sub_questions}")
        decomposed_at = time.time()

        # Answer all sub-questions concurrently; gather keeps input order.
        sub_answers = await asyncio.gather(
            *(self.subquestion_answerer(sub_question, image_paths)
              for sub_question in sub_questions)
        )
        references = "".join(sub_answers)
        answered_at = time.time()
        print("Sub references are ", references)

        # Add references retrieved for the original, undecomposed query.
        references += "".join(
            f"Related research for the user query: {ref}\n"
            for ref in self.retriever.retrieve(message)
        )

        summarizer_context = f"Question References: {references}\nQuestion: {message}\n"
        answer = self.summarizer.refine(summarizer_context, None, image_paths)
        summarized_at = time.time()

        # Record the exchange; images are stored as base64 data URLs.
        user_content = [{"type": "text", "text": f"{message}"}]
        if image_paths is not None:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            for image_path in image_paths:
                user_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"},
                })
        self.memory.append([
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": answer},
        ])

        print("=" * 20)
        # Plain f-string only: chaining .format() onto the interpolated text
        # raised whenever the answer itself contained '{' or '}'.
        print(f"Final answer: {answer}")

        print(f"Decompose: {decomposed_at - started_at}")
        print(f"Answer Subquestions: {answered_at - decomposed_at}")
        # This interval also covers the retrieval for the original query.
        print(f"Summarize: {summarized_at - answered_at}")

        if return_logs:
            return answer, references
        return answer

    async def subquestion_answerer(self, sub_question: str, image_paths=None, return_logs=False) -> str:
        """Answer one sub-question from references retrieved for it.

        Args:
            sub_question: The sub-question to answer.
            image_paths: Optional images forwarded to the answerer.
            return_logs: Accepted for signature compatibility; not used.

        Returns:
            A text block holding the sub-question and its answer, terminated
            by three newlines so blocks can be concatenated directly.
        """
        started_at = time.time()
        retrieved = "".join(
            f"Related research: {ref}\n"
            for ref in self.retriever.retrieve(sub_question)
        )
        sub_answerer_context = (
            f"Sub Question References: {retrieved}\nQuestion: {sub_question}\n"
        )
        # Awaiting the coroutine directly is equivalent to wrapping it in a
        # task, awaiting the task and reading its result.
        sub_answer = await self.answerer.refine_async(sub_answerer_context, self.memory, image_paths)
        finished_at = time.time()
        print(f"Time: {finished_at - started_at}")
        print(f"Subanswer: {sub_answer}")
        return f"Subquestion: {sub_question}\nSubanswer: {sub_answer}\n\n\n"