File size: 5,973 Bytes
58974f8
 
 
 
 
 
a1dacf0
58974f8
 
ea54126
 
58974f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c8a93e
ea54126
58974f8
ea54126
58974f8
a1dacf0
 
 
 
 
 
ea54126
 
a1dacf0
c3fe30b
 
 
ea54126
 
 
6c8a93e
 
 
 
 
 
 
 
 
ea54126
 
6c8a93e
 
 
58974f8
 
 
 
ea54126
 
 
58974f8
ea54126
8ae747a
 
 
 
ea54126
 
 
 
 
 
58974f8
 
ea54126
 
 
 
6c8a93e
 
 
 
 
 
ea54126
 
 
c3fe30b
e5adec2
ea54126
 
e5adec2
c3fe30b
 
 
3ed86fa
c3fe30b
ea54126
6c8a93e
 
e5adec2
6c8a93e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from .base_chatbot import BaseChatbot
from ..memory import BaseMemory, ChatMemory
from ..retriever import BaseRetriever, ChromaRetriever, FaissRetriever
from ..refiner import BaseRefiner, SimpleRefiner
from models import BaseModel, GPT4Model
from prompts import DecomposePrompt, QAPrompt, SummaryPrompt, ReferencePrompt
from utils import convert_str_to_list
import ast
from utils.image_encoder import encode_image
import asyncio
import time

class RetrievalChatbot(BaseChatbot):
    """Chatbot that answers a query via decomposition, retrieval and summarization.

    A query is first split into sub-questions by ``decomposer``. Each
    sub-question is answered by ``answerer`` using references fetched from
    ``retriever``. The sub-answers, together with references retrieved for
    the original query, are then passed to ``summarizer`` to produce the
    final answer. Every completed turn is appended to ``memory``.
    """

    def __init__(self, 
                 model: BaseModel = None, 
                 memory: BaseMemory = None,
                 retriever: BaseRetriever = None,
                 decomposer: BaseRefiner = None,
                 answerer: BaseRefiner = None,
                 summarizer: BaseRefiner = None,
        ) -> None:
        """Store the collaborating components, building defaults for omitted ones.

        Args:
            model: Base model; defaults to ``GPT4Model()``.
            memory: Conversation memory; defaults to a ``ChatMemory`` seeded
                with ``SummaryPrompt.content``.
            retriever: Reference retriever; defaults to a ``ChromaRetriever``
                over the ``papers_all`` directory.
            decomposer: Refiner that splits a query into sub-questions.
            answerer: Refiner that answers a single sub-question.
            summarizer: Refiner that produces the final answer.
        """
        # Defaults are selected with ``is None`` rather than truthiness so a
        # caller-supplied component that happens to evaluate falsy (for
        # example an empty container-like memory) is not silently replaced.
        self.model = GPT4Model() if model is None else model
        self.memory = (
            ChatMemory(sys_prompt=SummaryPrompt.content) if memory is None else memory
        )
        self.retriever = (
            ChromaRetriever(
                pdf_dir="papers_all",
                collection_name="pdfs",
                split_args={"size": 2048, "overlap": 10},
                embed_model=GPT4Model(),
            )
            if retriever is None
            else retriever
        )
        self.decomposer = (
            SimpleRefiner(model=GPT4Model(), sys_prompt=DecomposePrompt.content)
            if decomposer is None
            else decomposer
        )
        self.answerer = (
            SimpleRefiner(model=GPT4Model(), sys_prompt=QAPrompt.content)
            if answerer is None
            else answerer
        )
        self.summarizer = (
            SimpleRefiner(model=GPT4Model(), sys_prompt=SummaryPrompt.content)
            if summarizer is None
            else summarizer
        )

    async def response(self, message: str, image_paths=None) -> dict:
        """Answer ``message`` and record the turn in memory.

        Args:
            message: The user query.
            image_paths: Optional image or list of images accompanying the
                query. Each item's ``.name`` attribute is passed to
                ``encode_image`` when the turn is stored in memory.

        Returns:
            A dict with keys ``"answer"`` (the summarizer output),
            ``"titles"`` (set of titles returned by the retriever for the
            query and all sub-questions) and ``"logs"`` (the concatenated
            sub-question/sub-answer text).
        """
        started_at = time.time()
        print(f"Query: {message}")

        # Decompose exactly once: each refine() call is a model invocation,
        # so the raw output is reused for both logging and parsing.
        sub_questions_str = self.decomposer.refine(message, None, image_paths)
        print(sub_questions_str)
        sub_questions_list = convert_str_to_list(sub_questions_str)
        print("Decomposed your query into subquestions: {sub_questions}".format(sub_questions=sub_questions_list))
        decomposed_at = time.time()

        # Answer all sub-questions concurrently; gather preserves input order.
        results = await asyncio.gather(
            *(self.subquestion_answerer(sub_question, image_paths)
              for sub_question in sub_questions_list)
        )

        all_titles = set()
        for result in results:
            all_titles.update(result["titles"])
        logs = "".join(result["answer"] for result in results)
        answered_at = time.time()
        print("Sub references are ", logs)

        # Add references retrieved for the original (undecomposed) query.
        refs, titles = self.retriever.retrieve(message)
        all_titles.update(titles)
        references = logs + "".join(
            f"Related research for the user query: {ref}\n" for ref in refs
        )

        summarizer_context = "Question References: {references}\nQuestion: {message}\n".format(references=references, message=message)
        answer = self.summarizer.refine(summarizer_context, None, image_paths)
        summarized_at = time.time()

        # TODO: memory management
        self._remember_turn(message, answer, image_paths)

        print("="*20)
        # Plain f-string only: chaining .format() onto the interpolated text
        # would raise whenever the answer itself contains '{' or '}'.
        print(f"Final answer: {answer}")

        print(f"Decompose: {decomposed_at-started_at}")
        print(f"Answer Subquestions: {answered_at-decomposed_at}")
        print(f"Summarize: {summarized_at-answered_at}")

        return {
            "answer": answer,
            "titles": all_titles,
            "logs": logs
        }

    def _remember_turn(self, message: str, answer, image_paths=None) -> None:
        """Append one user/assistant exchange to ``self.memory``."""
        user_content = [{"type": "text", "text": f"{message}"}]
        if image_paths is not None:
            # Accept a single image as well as a list of images.
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            for image_path in image_paths:
                user_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"}})
        self.memory.append([{"role": "user", "content": user_content}, {"role": "assistant", "content": answer}])

    async def subquestion_answerer(self, sub_question: str, image_paths=None, return_logs=False) -> dict:
        """Retrieve references for one sub-question and answer it.

        Args:
            sub_question: The sub-question to answer.
            image_paths: Optional images forwarded to the answerer.
            return_logs: Currently unused; kept for interface compatibility.

        Returns:
            A dict with keys ``"answer"`` (formatted sub-question and
            sub-answer text) and ``"titles"`` (titles returned by the
            retriever for this sub-question).
        """
        time_s = time.time()
        sub_retrieve, titles = self.retriever.retrieve(sub_question)
        sub_retrieve_reference = "".join(
            f"Related research: {ref}\n" for ref in sub_retrieve
        )
        sub_answerer_context = "Sub Question References: {sub_retrieve_reference}\nQuestion: {question}\n".format(sub_retrieve_reference=sub_retrieve_reference, question=sub_question)
        # Await the coroutine directly; wrapping it in a task only to await
        # it immediately adds nothing.
        sub_answer = await self.answerer.refine_async(sub_answerer_context, self.memory, image_paths)
        time_e = time.time()
        print(f"Time: {time_e-time_s}")
        print(f"Subanswer: {sub_answer}")
        return {
            "answer": "Subquestion: {sub_question}\nSubanswer: {sub_answer}\n\n\n".format(sub_question=sub_question, sub_answer=sub_answer), 
            "titles": titles
        }