# protein-retrieval-base / core / chatbot / retrieval_chatbot.py
# Hugging Face file-viewer metadata: uploaded by lindsay-qu ("Upload 86 files"),
# commit e0f406c (verified), virus scan clean, 3.48 kB.
from .base_chatbot import BaseChatbot
from ..memory import BaseMemory, ChatMemory
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever
from ..refiner import BaseRefiner, SimpleRefiner
from models import BaseModel, GPT4Model
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt
import ast
from utils.image_encoder import encode_image
import asyncio
import time
class RetrievalChatbot(BaseChatbot):
    """Chatbot that answers a query from passages returned by a retriever.

    For every query the retriever is asked for references, the references
    and the question are packed into one context string, and the
    ``answerer`` refiner produces the answer. Each exchange (including any
    attached images) is appended to ``memory``.

    ``model``, ``decomposer`` and ``summarizer`` are stored on the instance
    but are not used by :meth:`response`.
    """

    def __init__(self,
                 model: BaseModel = None,
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
                 ) -> None:
        """Store the supplied components, building a default for each omitted one.

        Args:
            model: Chat model; defaults to ``GPT4Model()``.
            memory: Conversation memory; defaults to a ``ChatMemory`` whose
                system prompt is ``SummaryPrompt.content``.
            retriever: Reference retriever; defaults to a ``ChromaRetriever``
                over the ``papers_all`` directory (collection ``pdfs``).
            decomposer: Refiner using ``DecomposePrompt``; defaults to a
                ``SimpleRefiner`` backed by a new ``GPT4Model``.
            answerer: Refiner using ``QAPrompt``; same default construction.
            summarizer: Refiner using ``SummaryPrompt``; same default
                construction.
        """
        # Test against None instead of truthiness so that a caller-supplied
        # component which happens to be falsy (e.g. an empty memory object
        # that defines __len__) is kept rather than silently replaced.
        # Defaults are still only constructed when actually needed.
        self.model = GPT4Model() if model is None else model
        self.memory = (
            ChatMemory(sys_prompt=SummaryPrompt.content)
            if memory is None else memory
        )
        self.retriever = (
            ChromaRetriever(pdf_dir="papers_all",
                            collection_name="pdfs",
                            split_args={"size": 2048, "overlap": 10},
                            embed_model=GPT4Model())
            if retriever is None else retriever
        )
        self.decomposer = (
            SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
            if decomposer is None else decomposer
        )
        self.answerer = (
            SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
            if answerer is None else answerer
        )
        self.summarizer = (
            SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)
            if summarizer is None else summarizer
        )

    async def response(self, message: str, image_paths=None, return_logs=False) -> dict:
        """Answer ``message`` using retrieved references and record the exchange.

        Although declared ``async``, the method awaits nothing: the retriever
        and the answerer are called synchronously.

        Args:
            message: The user's query.
            image_paths: Optional image, or list of images, attached to the
                query. Each item must expose a ``.name`` attribute, which is
                passed to ``encode_image``.
            return_logs: Accepted for interface compatibility; currently
                unused (``"logs"`` is always an empty string).

        Returns:
            A dict with keys ``"answer"`` (the answerer's output),
            ``"titles"`` (the set of titles returned by the retriever) and
            ``"logs"`` (always ``""``).
        """
        print("Query: {message}".format(message=message))

        # The retriever yields a pair: reference texts and their titles.
        refs, titles = self.retriever.retrieve(message)
        retrieved_reference = "".join(
            "Related research: {ref}\n".format(ref=ref) for ref in refs
        )
        answerer_context = (
            "Sub Question References: {retrieved_reference}\n"
            "Question: {message}\n"
        ).format(retrieved_reference=retrieved_reference, message=message)

        # The answerer receives image_paths exactly as the caller passed it
        # (normalisation to a list below only affects what is stored in memory).
        answer = self.answerer.refine(answerer_context, self.memory, image_paths)

        # TODO: memory management
        user_content = [{"type": "text", "text": f"{message}"}]
        if image_paths is not None:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            for image_path in image_paths:
                encoded = encode_image(image_path.name)
                user_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                })
        self.memory.append([
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": answer},
        ])

        print("=" * 20)
        # Print the answer verbatim. Running str.format over the interpolated
        # text would re-parse any '{' / '}' inside the answer and raise.
        print(f"Final answer: {answer}")
        return {
            "answer": answer,
            "titles": set(titles),
            "logs": "",
        }