File size: 5,973 Bytes
58974f8 a1dacf0 58974f8 ea54126 58974f8 6c8a93e ea54126 58974f8 ea54126 58974f8 a1dacf0 ea54126 a1dacf0 c3fe30b ea54126 6c8a93e ea54126 6c8a93e 58974f8 ea54126 58974f8 ea54126 8ae747a ea54126 58974f8 ea54126 6c8a93e ea54126 c3fe30b e5adec2 ea54126 e5adec2 c3fe30b 3ed86fa c3fe30b ea54126 6c8a93e e5adec2 6c8a93e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from .base_chatbot import BaseChatbot
from ..memory import BaseMemory, ChatMemory
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever
from ..refiner import BaseRefiner, SimpleRefiner
from models import BaseModel, GPT4Model
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt
from utils import convert_str_to_list
import ast
from utils.image_encoder import encode_image
import asyncio
import time
class RetrievalChatbot(BaseChatbot):
    """Chatbot that answers a query by decomposition, retrieval and summarization.

    Pipeline (see ``response``):
      1. ``decomposer`` turns the user message into a list of sub-questions.
      2. Each sub-question is answered by ``subquestion_answerer`` from
         references returned by ``retriever``; the sub-questions are
         scheduled together with ``asyncio.gather``.
      3. References retrieved for the original message are appended and
         ``summarizer`` produces the final answer.
      4. The user turn and the answer are appended to ``memory``.
    """

    def __init__(self,
                 model: BaseModel = None,
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
                 ) -> None:
        """Create the chatbot, building a default for every component left as None.

        Args:
            model: Chat model; defaults to ``GPT4Model()``.
            memory: Conversation memory; defaults to a ``ChatMemory`` whose
                system prompt is ``SummaryPrompt.content``.
            retriever: Document retriever; defaults to a ``ChromaRetriever``
                over the ``papers_all`` PDF directory.
            decomposer: Refiner that splits a query into sub-questions
                (default system prompt: ``DecomposePrompt``).
            answerer: Refiner that answers one sub-question from its
                references (default system prompt: ``QAPrompt``).
            summarizer: Refiner that writes the final answer from all
                references (default system prompt: ``SummaryPrompt``).
        """
        # Compare against None instead of relying on truthiness: a supplied
        # component that happens to be falsy (e.g. an empty memory object that
        # defines __len__) must not be silently replaced by a fresh default.
        self.model = model if model is not None else GPT4Model()
        self.memory = memory if memory is not None \
            else ChatMemory(sys_prompt=SummaryPrompt.content)
        self.retriever = retriever if retriever is not None \
            else ChromaRetriever(pdf_dir="papers_all",
                                 collection_name="pdfs",
                                 split_args={"size": 2048, "overlap": 10},
                                 embed_model=GPT4Model())
        self.decomposer = decomposer if decomposer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
        self.answerer = answerer if answerer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
        self.summarizer = summarizer if summarizer is not None \
            else SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)

    async def response(self, message: str, image_paths=None) -> dict:
        """Answer ``message`` and record the exchange in memory.

        Args:
            message: The user's query.
            image_paths: Optional image attachment(s), a single item or a list.
                They are forwarded to the decomposer, answerer and summarizer,
                and embedded as base64 data URLs in the stored user turn.

        Returns:
            A dict with keys:
              - ``"answer"``: the summarizer's final answer.
              - ``"titles"``: set of titles of every retrieved reference.
              - ``"logs"``: the concatenated "Subquestion/Subanswer" text.
        """
        time_start = time.perf_counter()
        print("Query: {message}".format(message=message))

        # Decompose exactly once; the decomposer is a model call, so repeating
        # it only adds cost and latency.
        sub_questions_str = self.decomposer.refine(message, None, image_paths)
        print(sub_questions_str)
        sub_questions_list = convert_str_to_list(sub_questions_str)
        print("Decomposed your query into subquestions: {sub_questions}".format(sub_questions=sub_questions_list))
        time_decomposed = time.perf_counter()

        # gather() wraps each coroutine in a task and returns results in
        # input order, so the sub-answers stay aligned with sub_questions_list.
        results = await asyncio.gather(
            *(self.subquestion_answerer(sub_question, image_paths)
              for sub_question in sub_questions_list)
        )
        all_titles = set()
        for result in results:
            all_titles.update(result["titles"])
        sub_references = "".join(result["answer"] for result in results)
        logs = sub_references
        time_answered = time.perf_counter()
        print("Sub references are ", sub_references)

        # Add references retrieved for the original, undecomposed query.
        refs, titles = self.retriever.retrieve(message)
        all_titles.update(titles)
        references = sub_references + "".join(
            f"Related research for the user query: {ref}\n" for ref in refs
        )
        summarizer_context = "Question References: {references}\nQuestion: {message}\n".format(
            references=references, message=message)
        answer = self.summarizer.refine(summarizer_context, None, image_paths)
        time_summarized = time.perf_counter()

        # TODO: memory management.
        self._remember(message, answer, image_paths)

        print("=" * 20)
        # Plain f-string only: calling .format() on text that already contains
        # the model's answer would crash on any literal braces in that answer.
        print(f"Final answer: {answer}")
        print(f"Decompose: {time_decomposed - time_start}")
        print(f"Answer Subquestions: {time_answered - time_decomposed}")
        print(f"Summarize: {time_summarized - time_answered}")
        return {
            "answer": answer,
            "titles": all_titles,
            "logs": logs,
        }

    def _remember(self, message: str, answer, image_paths=None) -> None:
        """Append the user turn (text plus any images) and the answer to memory."""
        import os

        user_content = [{"type": "text", "text": f"{message}"}]
        if image_paths is not None:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            for image_path in image_paths:
                # Path strings / os.PathLike are used directly; other objects are
                # read through their ``.name`` attribute, as the original code did
                # for every item. NOTE(review): confirm encode_image accepts
                # os.PathLike as well as str.
                path = image_path if isinstance(image_path, (str, os.PathLike)) else image_path.name
                user_content.append(
                    {"type": "image_url",
                     "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}}
                )
        self.memory.append([{"role": "user", "content": user_content},
                            {"role": "assistant", "content": answer}])

    async def subquestion_answerer(self, sub_question: str, image_paths=None, return_logs=False) -> dict:
        """Answer one sub-question from references retrieved for it.

        Args:
            sub_question: The sub-question text.
            image_paths: Image attachment(s) forwarded to the answerer.
            return_logs: Unused; kept for interface compatibility.

        Returns:
            A dict with ``"answer"`` (the formatted "Subquestion/Subanswer"
            text) and ``"titles"`` (titles of the retrieved references).
        """
        time_start = time.perf_counter()
        # NOTE(review): retrieve() is a plain (non-awaited) call, so when several
        # sub-questions run under asyncio.gather their retrievals execute one
        # after another; only the awaited answerer calls can overlap.
        sub_retrieve, titles = self.retriever.retrieve(sub_question)
        sub_retrieve_reference = "".join(f"Related research: {ref}\n" for ref in sub_retrieve)
        sub_answerer_context = "Sub Question References: {sub_retrieve_reference}\nQuestion: {question}\n".format(
            sub_retrieve_reference=sub_retrieve_reference, question=sub_question)
        # Unlike the decomposer and summarizer (which receive None), the
        # answerer is given the conversation memory.
        sub_answer = await self.answerer.refine_async(sub_answerer_context, self.memory, image_paths)
        print(f"Time: {time.perf_counter() - time_start}")
        print(f"Subanswer: {sub_answer}")
        return {
            "answer": "Subquestion: {sub_question}\nSubanswer: {sub_answer}\n\n\n".format(
                sub_question=sub_question, sub_answer=sub_answer),
            "titles": titles,
        }