lindsay-qu's picture
Upload 92 files
58974f8
raw
history blame
No virus
5.57 kB
from .base_chatbot import BaseChatbot
from ..memory import BaseMemory, ChatMemory
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever
from ..refiner import BaseRefiner, SimpleRefiner
from models import BaseModel, GPT4Model
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt
import ast
from utils.image_encoder import encode_image
# QA_PROMPT = "\
# You are a Question-Answering Chatbot. \
# Given some references and a question, please answer the question according to the references. \
# If you find the references insufficient, you can answer the question according to your own knowledge. \
# ONLY output the answer. \
# "
# QUESTION_PROMPT = "\
# You are a Question Refiner. \
# Given a question, you need to break it down to several subquestions and output a list of string: [\"<subquestion1>\", \"<subquestion2>\", ...]. \
# MAKE SURE there are no vague concepts in each subquestion that require reference to other subquestions, such as determiners, pronouns and so on. \
# If the question cannot be broken down, you need to rephrase it in 3 ways and output a list of string: [\"<rephrase1>\", \"<rephrase2>\", \"<rephrase3>\"]. \
# ONLY output the list of subquestions or rephrases. \
# "
# SUMMARY_PROMPT = "\
# You are a Summary Refiner. \
# Given a question and several answers to it, you need to organize and summarize the answers to form one coherent answer to the question. \
# ONLY output the summarized answer. \
# "
# REFERENCE_PROMPT = "\
# You are a Reference Refiner. \
# Given paragraphs extracted from a paper, you need to remove the unnecessary and messy symbols to make it more readable. \
# But keep the original expression and sentences as much as possible. \
# ONLY output the refined paragraphs. \
# "
class RetrievalChatbot(BaseChatbot):
    """Retrieval-augmented chatbot that answers a query via sub-question decomposition.

    Pipeline used by :meth:`response`:

    1. ``decomposer`` turns the user query into a list of sub-questions
       (its output is parsed with ``ast.literal_eval``).
    2. For each sub-question, ``retriever`` fetches references and
       ``answerer`` produces a sub-answer. Earlier sub-questions and their
       sub-answers are carried forward into the context of later ones.
    3. References retrieved for the original query are appended, and
       ``summarizer`` composes the final answer from everything gathered.
    4. The user/assistant exchange is appended to ``memory``.

    Any component passed as ``None`` is replaced by a default built on
    ``GPT4Model``.
    """

    def __init__(self,
                 model: BaseModel = None,
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
                 ) -> None:
        """Wire up the pipeline components, building defaults for missing ones.

        Args:
            model: Chat model stored as ``self.model``; defaults to ``GPT4Model()``.
                NOTE: :meth:`response` in this class does not use it.
            memory: Conversation memory; defaults to a ``ChatMemory`` whose
                system prompt is ``SummaryPrompt.content``.
            retriever: Reference retriever; defaults to a ``ChromaRetriever``
                over the ``papers_all`` directory (collection ``pdfs``).
            decomposer: Refiner that splits a query into sub-questions;
                defaults to a ``SimpleRefiner`` with ``DecomposePrompt``.
            answerer: Refiner that answers a single sub-question;
                defaults to a ``SimpleRefiner`` with ``QAPrompt``.
            summarizer: Refiner that merges sub-answers into the final answer;
                defaults to a ``SimpleRefiner`` with ``SummaryPrompt``.
        """
        # Compare against None (not truthiness) so a caller-supplied component
        # that happens to be falsy -- e.g. an empty memory defining __len__ --
        # is kept instead of being silently replaced by a default.
        self.model = model if model is not None else GPT4Model()
        self.memory = memory if memory is not None \
            else ChatMemory(sys_prompt=SummaryPrompt.content)
        self.retriever = retriever if retriever is not None \
            else ChromaRetriever(pdf_dir="papers_all",
                                 collection_name="pdfs",
                                 split_args={"size": 2048, "overlap": 10},
                                 embed_model=GPT4Model())
        self.decomposer = decomposer if decomposer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
        self.answerer = answerer if answerer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
        self.summarizer = summarizer if summarizer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)

    @staticmethod
    def _parse_sub_questions(raw, fallback: str) -> list:
        """Parse the decomposer's raw output into a list of sub-questions.

        The output is expected to be a Python list literal, but model output
        is not guaranteed to comply. So:

        - a surrounding Markdown code fence is stripped first;
        - a bare string literal is wrapped in a one-element list instead of
          being iterated character by character;
        - anything unparsable, or not a list/tuple/str, falls back to
          ``[fallback]`` rather than aborting the whole response.
        """
        text = str(raw).strip()
        if text.startswith("```"):
            # Drop the opening fence line (``` or ```python) and the closing fence.
            text = text.split("\n", 1)[1] if "\n" in text else ""
            text = text.rsplit("```", 1)[0].strip()
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, RecursionError):
            print("Could not parse decomposed subquestions; using the original query.")
            return [fallback]
        if isinstance(parsed, str):
            return [parsed]
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        return [fallback]

    @staticmethod
    def _build_user_message(message: str, image_path=None) -> dict:
        """Build the user turn stored in memory, attaching the image only if one is given."""
        content = [{"type": "text", "text": f"{message}"}]
        # Only encode when a path was actually provided; the default is None.
        if image_path:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"},
            })
        return {"role": "user", "content": content}

    def response(self, message: str, image_path=None) -> str:
        """Answer ``message`` using decomposition, retrieval and summarization.

        Args:
            message: The user's query.
            image_path: Optional image path. It is passed to every refiner call
                and, when given, stored base64-encoded in the memory entry.

        Returns:
            The summarizer's final answer.
        """
        print("Query: {message}".format(message=message))
        question = self.decomposer.refine(message, image_path)
        print(question)
        sub_questions = self._parse_sub_questions(question, message)
        print("Decomposed your query into subquestions: {sub_questions}".format(sub_questions=sub_questions))

        # Accumulates "Subquestion/Subanswer" pairs; each sub-question's context
        # starts from everything answered so far.
        references = ""
        for sub_question in sub_questions:
            print("=" * 20)
            print(f"Subquestion: {sub_question}")
            print("Retrieving pdf papers for references...\n")
            sub_retrieve_reference = references
            for ref in self.retriever.retrieve(sub_question):
                sub_retrieve_reference += "Related research: {ref}\n".format(ref=ref)
            sub_answerer_context = "Sub Question References: {sub_retrieve_reference}\nQuestion: {question}\n".format(
                sub_retrieve_reference=sub_retrieve_reference, question=sub_question)
            sub_answer = self.answerer.refine(sub_answerer_context, image_path)
            print(f"Subanswer: {sub_answer}")
            references += "Subquestion: {sub_question}\nSubanswer: {sub_answer}\n".format(
                sub_question=sub_question, sub_answer=sub_answer)

        # Add references retrieved for the original, undecomposed query.
        for ref in self.retriever.retrieve(message):
            references += "Related research for the user query: {ref}\n".format(ref=ref)
        summarizer_context = "Question References: {references}\nQuestion: {message}\n".format(
            references=references, message=message)
        answer = self.summarizer.refine(summarizer_context, image_path)

        # TODO: memory management
        self.memory.append([self._build_user_message(message, image_path),
                            {"role": "assistant", "content": answer}])
        print("=" * 20)
        # Plain f-string only: calling .format() on the result would re-parse the
        # model's answer as a format string and raise on any literal '{' or '}'.
        print(f"Final answer: {answer}")
        return answer