File size: 2,963 Bytes
328b268
 
 
 
 
 
 
 
 
 
e182c41
 
328b268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2826548
 
 
 
 
 
 
 
 
 
328b268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef1ef76
 
 
 
 
328b268
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import abc
import os
import time
import urllib
from queue import Queue
from threading import Thread

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces


class LLMInference(metaclass=abc.ABCMeta):
    """Abstract base for running a lazily created chain, optionally streaming tokens.

    Subclasses implement :meth:`create_chain`. The chain is built on first use
    by :meth:`get_chain`, cached on the instance, and invoked through
    :meth:`call_chain`.
    """

    # Annotations are written as strings so the ``| None`` form does not have to
    # be evaluated at class-creation time.
    # Loader object; within this class only its ``streamer`` attribute is read.
    llm_loader: "LLMLoader"
    # Cached chain; stays ``None`` until the first ``get_chain`` call.
    chain: "Chain | None"

    def __init__(self, llm_loader):
        """Store the loader and start with no chain built."""
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> "Chain":
        """Build and return the chain to be cached by :meth:`get_chain`."""

    def get_chain(self, tracing: bool = False) -> "Chain":
        """Return the cached chain, creating it on the first call.

        Args:
            tracing: Only consulted when the chain has not been built yet; in
                that case a ``LangChainTracer`` is created and its default
                session is loaded before the chain is created.

        Returns:
            The chain produced by :meth:`create_chain`.
        """
        if self.chain is None:
            if tracing:
                # NOTE(review): the tracer is not passed to the chain or stored
                # anywhere in this class - confirm that loading the default
                # session is the only effect that is wanted here.
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

    def call_chain(
        self, inputs, streaming_handler, q: "Queue | None" = None, tracing: bool = False
    ):
        """Run the chain on ``inputs`` and post-process an ``"answer"`` result.

        Args:
            inputs: Mapping handed to the chain unchanged.
            streaming_handler: Callback handler. When not ``None`` the chain
                runs on a worker thread with this handler as its only callback;
                otherwise the chain is called directly on this thread.
            q: Passed to ``streamer.reset`` when the loader's streamer is a
                ``TextIteratorStreamer``.
            tracing: Forwarded to :meth:`get_chain`.

        Returns:
            The chain's result. When it contains ``"answer"``, the answer is
            passed through ``remove_extra_spaces`` and, if the
            ``PDF_FILE_BASE_URL`` environment variable is non-empty, every
            source document receives a ``"url"`` metadata entry made of that
            base URL plus the URL-quoted last path segment of its ``"source"``.
        """
        print(inputs)

        if self._has_text_iterator_streamer():
            self.llm_loader.streamer.reset(q)

        chain = self.get_chain(tracing)
        if streaming_handler is not None:
            result = self._run_qa_chain(chain, inputs, streaming_handler)
        else:
            result = chain(inputs)

        if "answer" in result:
            result["answer"] = remove_extra_spaces(result["answer"])

            base_url = os.environ.get("PDF_FILE_BASE_URL")
            if base_url:
                # Imported explicitly: a bare ``import urllib`` does not
                # guarantee that the ``urllib.parse`` submodule is loaded.
                from urllib.parse import quote

                # A result without source documents simply gets no URLs.
                for doc in result.get("source_documents") or []:
                    title = doc.metadata["source"].split("/")[-1]
                    doc.metadata["url"] = f"{base_url}{quote(title)}"

        return result

    def _has_text_iterator_streamer(self) -> bool:
        """Tell whether the loader's streamer is a ``TextIteratorStreamer``."""
        streamer = self.llm_loader.streamer
        return streamer is not None and isinstance(streamer, TextIteratorStreamer)

    def _run_qa_chain(self, qa, inputs, streaming_handler):
        """Run ``qa`` on a worker thread while relaying streamed tokens.

        Args:
            qa: Callable chain, invoked as ``qa(inputs, callbacks=[handler])``.
            inputs: Mapping handed to ``qa``; its ``"chat_history"`` entry
                decides how many times the streamer is drained.
            streaming_handler: Receives each streamed token through
                ``on_llm_new_token``.

        Returns:
            Whatever ``qa`` returned.

        Raises:
            BaseException: Any exception raised by ``qa`` on the worker thread
                is re-raised here instead of leaving the caller blocked.
        """
        outcome = {}

        def invoke():
            try:
                outcome["result"] = qa(inputs, callbacks=[streaming_handler])
            except BaseException as exc:  # re-raised on the calling thread
                outcome["error"] = exc

        worker = Thread(target=invoke)
        worker.start()

        if self._has_text_iterator_streamer():
            # Two drain passes when there is chat history, otherwise one.
            # NOTE(review): presumably the chain makes an extra LLM call when
            # history is present - confirm against the chain implementation.
            remaining = 2 if inputs.get("chat_history") else 1

            while remaining > 0:
                # Sampled before draining: if the worker had already exited
                # and the drain still fails, no further tokens can arrive.
                worker_was_alive = worker.is_alive()
                try:
                    for token in self.llm_loader.streamer:
                        streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    remaining -= 1
                except Exception:
                    if not worker_was_alive:
                        break
                    print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        worker.join()
        if "error" in outcome:
            raise outcome["error"]
        return outcome["result"]