Spaces:
Sleeping
Sleeping
from .base_chatbot import BaseChatbot | |
from ..memory import BaseMemory, ChatMemory | |
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever | |
from ..refiner import BaseRefiner, SimpleRefiner | |
from models import BaseModel, GPT4Model | |
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt | |
import ast | |
from utils.image_encoder import encode_image | |
import asyncio | |
import time | |
class RetrievalChatbot(BaseChatbot):
    """Chatbot that answers a query from passages fetched by a retriever.

    Each call to :meth:`response` retrieves references for the message, hands
    them to the answerer together with the question, and appends the resulting
    user/assistant exchange to the memory.
    """

    def __init__(self,
                 model: BaseModel = None,
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
                 ) -> None:
        """Store the collaborators, building a default for each one omitted.

        Args:
            model: Stored as ``self.model``; defaults to ``GPT4Model()``.
            memory: Conversation memory; defaults to a ``ChatMemory`` whose
                system prompt is ``SummaryPrompt.content``.
            retriever: Reference retriever; defaults to a ``ChromaRetriever``
                over the ``papers_all`` directory.
            decomposer: Refiner built with ``DecomposePrompt.content`` by
                default. Stored only; not used by :meth:`response`.
            answerer: Refiner built with ``QAPrompt.content`` by default; it
                produces the answer in :meth:`response`.
            summarizer: Refiner built with ``SummaryPrompt.content`` by
                default. Stored only; not used by :meth:`response`.
        """
        # Compare against None instead of relying on truthiness, so that a
        # supplied collaborator that happens to be falsy (for example an
        # object defining __len__ that is currently empty) is not silently
        # replaced by a freshly built default.
        self.model = model if model is not None else GPT4Model()
        self.memory = memory if memory is not None \
            else ChatMemory(sys_prompt=SummaryPrompt.content)
        self.retriever = retriever if retriever is not None \
            else ChromaRetriever(pdf_dir="papers_all",
                                 collection_name="pdfs",
                                 split_args={"size": 2048, "overlap": 10},
                                 embed_model=GPT4Model())
        self.decomposer = decomposer if decomposer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
        self.answerer = answerer if answerer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
        self.summarizer = summarizer if summarizer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)

    async def response(self, message: str, image_paths=None, return_logs=False) -> dict:
        """Answer ``message`` using retrieved references and record the turn.

        The method contains no ``await``: retrieval and answering are invoked
        as ordinary calls.

        Args:
            message: The user's question.
            image_paths: ``None``, a single image, or a list of images. Each
                item is either a ``str`` path or an object whose ``.name``
                attribute is the path handed to ``encode_image``. The value is
                forwarded to the answerer exactly as received.
            return_logs: Currently ignored; ``"logs"`` is always ``""``.

        Returns:
            A dict with keys ``"answer"`` (the answerer's output),
            ``"titles"`` (the retrieved titles as a ``set``) and ``"logs"``.
        """
        print("Query: {message}".format(message=message))

        # The retriever is expected to return a (references, titles) pair.
        refs, titles = self.retriever.retrieve(message)
        retrieved_reference = "".join(
            "Related research: {ref}\n".format(ref=ref) for ref in refs
        )
        answerer_context = "Sub Question References: {retrieved_reference}\nQuestion: {message}\n".format(
            retrieved_reference=retrieved_reference, message=message)
        answer = self.answerer.refine(answerer_context, self.memory, image_paths)

        # TODO: memory management
        user_content = [{"type": "text", "text": f"{message}"}]
        if image_paths is not None:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            for image_path in image_paths:
                # A plain string is already a path; anything else is treated
                # as a file-like object carrying its path in `.name`.
                path = image_path if isinstance(image_path, str) else image_path.name
                user_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"},
                })
        self.memory.append([
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": answer},
        ])

        print("=" * 20)
        # Interpolate exactly once. Calling .format() on the rendered f-string
        # would re-parse any "{" / "}" inside the answer (code, JSON, LaTeX)
        # as replacement fields and raise.
        print(f"Final answer: {answer}")
        return {
            "answer": answer,
            "titles": set(titles),
            "logs": "",
        }