|
from .base_chatbot import BaseChatbot |
|
from ..memory import BaseMemory, ChatMemory |
|
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever |
|
from ..refiner import BaseRefiner, SimpleRefiner |
|
from models import BaseModel, GPT4Model |
|
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt |
|
import ast |
|
from utils.image_encoder import encode_image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RetrievalChatbot(BaseChatbot):
    """Retrieval-augmented chatbot.

    Answers a user query by (1) decomposing it into sub-questions with an LLM
    refiner, (2) retrieving related documents for each sub-question and
    answering it, then (3) retrieving documents for the original query and
    summarizing everything into a final answer. The exchange is appended to
    chat memory.
    """

    def __init__(self,
                 model: BaseModel = None,
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
                 ) -> None:
        """Initialize the pipeline components, defaulting each to a GPT-4-backed
        implementation when not supplied.

        Args:
            model: Backbone chat model (default: ``GPT4Model``).
            memory: Conversation memory (default: ``ChatMemory`` with the
                summary system prompt).
            retriever: Document retriever (default: ``ChromaRetriever`` over the
                ``papers_all`` PDF directory).
            decomposer: Refiner that splits a query into sub-questions.
            answerer: Refiner that answers a single sub-question.
            summarizer: Refiner that produces the final summarized answer.
        """
        self.model = model if model else GPT4Model()
        self.memory = memory if memory else ChatMemory(sys_prompt=SummaryPrompt.content)
        self.retriever = retriever if retriever \
            else ChromaRetriever(pdf_dir="papers_all",
                                 collection_name="pdfs",
                                 split_args={"size": 2048, "overlap": 10},
                                 embed_model=GPT4Model())
        self.decomposer = decomposer if decomposer \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
        self.answerer = answerer if answerer \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
        self.summarizer = summarizer if summarizer \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)

    def response(self, message: str, image_path=None) -> str:
        """Answer *message* via decompose → retrieve → answer → summarize.

        Args:
            message: The user query.
            image_path: Optional path to an image attached to the query; when
                given it is forwarded to each refiner and stored (base64) in
                memory.

        Returns:
            The final summarized answer string.
        """
        print("Query: {message}".format(message=message))

        # Ask the decomposer LLM to split the query into sub-questions.
        raw_decomposition = self.decomposer.refine(message, image_path)
        print(raw_decomposition)
        try:
            sub_questions = ast.literal_eval(raw_decomposition)
        except (ValueError, SyntaxError):
            # LLM output was not a valid Python literal (a routine failure
            # mode); fall back to treating the whole query as one question.
            sub_questions = [message]
        if isinstance(sub_questions, str):
            # Decomposer returned a bare quoted string; normalize to a list.
            sub_questions = [sub_questions]
        print("Decomposed your query into subquestions: {sub_questions}".format(sub_questions=sub_questions))

        references = ""
        for sub_question in sub_questions:
            print("=" * 20)
            print(f"Subquestion: {sub_question}")

            print(f"Retrieving pdf papers for references...\n")
            # NOTE(review): each sub-question's context is seeded with the
            # accumulated answers to earlier sub-questions — presumably
            # intentional (later answers can build on earlier ones); confirm.
            sub_retrieve_reference = references
            sub_retrieve = self.retriever.retrieve(sub_question)
            for ref in sub_retrieve:
                sub_retrieve_reference += "Related research: {ref}\n".format(ref=ref)

            sub_answerer_context = "Sub Question References: {sub_retrieve_reference}\nQuestion: {question}\n".format(sub_retrieve_reference=sub_retrieve_reference, question=sub_question)
            sub_answer = self.answerer.refine(sub_answerer_context, image_path)

            print(f"Subanswer: {sub_answer}")

            references += "Subquestion: {sub_question}\nSubanswer: {sub_answer}\n".format(sub_question=sub_question, sub_answer=sub_answer)

        # Also retrieve documents for the original, undivided query.
        refs = self.retriever.retrieve(message)
        for ref in refs:
            references += "Related research for the user query: {ref}\n".format(ref=ref)

        summarizer_context = "Question References: {references}\nQuestion: {message}\n".format(references=references, message=message)
        answer = self.summarizer.refine(summarizer_context, image_path)

        # Only attach the image part when an image was actually provided;
        # the original code called encode_image(None) on text-only queries.
        user_content = [{"type": "text", "text": f"{message}"}]
        if image_path:
            user_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"}})
        self.memory.append([{"role": "user", "content": user_content},
                            {"role": "assistant", "content": answer}])

        print("=" * 20)
        print(f"Final answer: {answer}")
        return answer