import ast
import asyncio
import time
from typing import Any, Dict, List, Optional

from .base_chatbot import BaseChatbot
from ..memory import BaseMemory, ChatMemory
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever
from ..refiner import BaseRefiner, SimpleRefiner
from models import BaseModel, GPT4Model
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt
from utils.image_encoder import encode_image


class RetrievalChatbot(BaseChatbot):
    """Chatbot that answers a query using passages fetched by a retriever.

    For each query the retriever is asked for references, the references and
    the question are packed into a single context string, and the answerer
    refiner produces the reply. The exchange is then appended to the memory.
    """

    def __init__(self,
                 model: Optional[BaseModel] = None,
                 memory: Optional[BaseMemory] = None,
                 retriever: Optional[BaseRetriever] = None,
                 decomposer: Optional[BaseRefiner] = None,
                 answerer: Optional[BaseRefiner] = None,
                 summarizer: Optional[BaseRefiner] = None,
                 ) -> None:
        """Store the collaborators, building a default for each one omitted.

        Args:
            model: Chat model; defaults to ``GPT4Model()``.
            memory: Conversation memory; defaults to a ``ChatMemory`` seeded
                with ``SummaryPrompt.content``.
            retriever: Reference retriever; defaults to a ``ChromaRetriever``
                over the ``papers_all`` directory / ``pdfs`` collection.
            decomposer: Refiner configured with ``DecomposePrompt.content``.
            answerer: Refiner configured with ``QAPrompt.content``.
            summarizer: Refiner configured with ``SummaryPrompt.content``.
        """
        # Explicit ``is not None`` checks: a caller-supplied object that
        # happens to evaluate as falsy (for example a container-like object
        # that is currently empty) must not be silently swapped for a default.
        self.model = model if model is not None else GPT4Model()
        self.memory = (
            memory if memory is not None
            else ChatMemory(sys_prompt=SummaryPrompt.content)
        )
        self.retriever = (
            retriever if retriever is not None
            else ChromaRetriever(pdf_dir="papers_all",
                                 collection_name="pdfs",
                                 split_args={"size": 2048, "overlap": 10},
                                 embed_model=GPT4Model())
        )
        self.decomposer = (
            decomposer if decomposer is not None
            else SimpleRefiner(model=GPT4Model(),
                               sys_prompt=DecomposePrompt.content)
        )
        self.answerer = (
            answerer if answerer is not None
            else SimpleRefiner(model=GPT4Model(),
                               sys_prompt=QAPrompt.content)
        )
        self.summarizer = (
            summarizer if summarizer is not None
            else SimpleRefiner(model=GPT4Model(),
                               sys_prompt=SummaryPrompt.content)
        )

    @staticmethod
    def _build_user_content(message: str, image_paths=None) -> List[Dict[str, Any]]:
        """Return the user-turn content list: the text plus any images."""
        content: List[Dict[str, Any]] = [{"type": "text", "text": f"{message}"}]
        if image_paths is None:
            return content
        if not isinstance(image_paths, list):
            image_paths = [image_paths]
        # Each entry is read through its ``.name`` attribute.
        # NOTE(review): presumably these are uploaded-file handles whose
        # ``.name`` is a filesystem path - confirm against the caller.
        # The data URL always declares image/jpeg regardless of the file.
        content.extend(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"
                },
            }
            for image_path in image_paths
        )
        return content

    async def response(self, message: str, image_paths=None, return_logs=False) -> Dict[str, Any]:
        """Answer ``message`` with retrieved references and record the turn.

        Args:
            message: The user's query.
            image_paths: ``None``, a single image handle, or a list of them.
                The value is forwarded unchanged to the answerer and the
                images are embedded in the stored user turn.
            return_logs: Currently unused; kept for interface compatibility.

        Returns:
            A dict with ``"answer"`` (the answerer's output), ``"titles"``
            (the set of retrieved titles) and ``"logs"`` (always ``""``).
        """
        print(f"Query: {message}")

        # NOTE(review): ``retrieve`` and ``refine`` are called synchronously
        # inside this coroutine; if they block, the event loop is blocked
        # for their duration.
        refs, titles = self.retriever.retrieve(message)
        retrieved_reference = "".join(
            f"Related research: {ref}\n" for ref in refs
        )
        answerer_context = (
            f"Sub Question References: {retrieved_reference}\n"
            f"Question: {message}\n"
        )
        answer = self.answerer.refine(answerer_context, self.memory, image_paths)

        # TODO: memory management
        self.memory.append([
            {"role": "user",
             "content": self._build_user_content(message, image_paths)},
            {"role": "assistant", "content": answer},
        ])

        print("=" * 20)
        # The answer is interpolated exactly once. Re-running ``str.format``
        # on text that already contains the answer would raise or mangle the
        # output whenever the answer itself contains ``{`` or ``}``.
        print(f"Final answer: {answer}")
        return {
            "answer": answer,
            "titles": set(titles),
            "logs": "",
        }