File size: 3,249 Bytes
328b268
 
 
 
 
 
6011708
328b268
 
 
e182c41
 
328b268
 
 
 
 
 
 
 
 
 
 
 
 
 
bf1e59b
328b268
 
 
 
 
6011708
 
 
328b268
6b469d2
 
 
 
 
328b268
 
d5af465
 
328b268
3ca5bd8
328b268
 
bf1e59b
3ca5bd8
6011708
 
 
3ca5bd8
6011708
328b268
 
3ca5bd8
 
2826548
3ca5bd8
 
 
 
 
 
 
328b268
3ca5bd8
 
d5af465
 
328b268
95d2e5f
6011708
95d2e5f
6011708
 
 
3ca5bd8
328b268
3ca5bd8
 
 
 
 
4cae0a4
bf1e59b
734948a
 
 
 
 
4cae0a4
734948a
 
 
bf1e59b
 
4cae0a4
734948a
 
 
bf1e59b
 
734948a
3ca5bd8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import abc
import os
import time
import urllib
import urllib.parse
from queue import Queue
from threading import Thread
from typing import List, Optional

from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces


class LLMInference(metaclass=abc.ABCMeta):
    """Base class that runs a LangChain chain built by a subclass.

    Subclasses implement :meth:`create_chain`; the chain is created lazily on
    first use. While ``llm_loader.streamer.for_huggingface`` is set, each call
    holds ``llm_loader.lock`` and streamed tokens are read from
    ``llm_loader.streamer`` and forwarded to a caller-supplied handler.
    """

    llm_loader: LLMLoader
    chain: Optional[Chain]

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        """Build and return the chain this inference object runs."""

    def get_chain(self) -> Chain:
        """Return the chain, creating it via :meth:`create_chain` on first use."""
        if self.chain is None:
            self.chain = self.create_chain()
        return self.chain

    def run_chain(self, chain, inputs, callbacks: Optional[List] = None):
        """Invoke ``chain`` on ``inputs`` with the given callback handlers.

        ``callbacks`` is passed by keyword on purpose: the second positional
        parameter of ``Chain.__call__`` is ``return_only_outputs``, so passing
        the list positionally would drop the handlers and flip that flag.
        """
        return chain(inputs, callbacks=callbacks)

    def call_chain(
        self,
        inputs,
        streaming_handler,
        q: Optional[Queue] = None,
        testing: bool = False,
    ):
        """Run the chain on ``inputs`` and post-process its result.

        Args:
            inputs: Input mapping handed to the chain.
            streaming_handler: Callback handler or None. When not None, the
                chain runs on a worker thread and streamed tokens are
                forwarded to it (see ``_run_chain_with_streaming_handler``).
            q: Passed to ``llm_loader.streamer.reset`` before the chain runs.
            testing: Forwarded to the streaming helper; suppresses token
                forwarding and retry messages.

        Returns:
            The chain's result dict. If it contains ``"answer"``, the answer
            is passed through ``remove_extra_spaces``. When the
            ``PDF_FILE_BASE_URL`` environment variable is also non-empty, each
            source document gets a ``"url"`` metadata entry.

        Raises:
            Whatever the chain raises.
        """
        print(inputs)
        # Read the flag once so the release in ``finally`` always matches the
        # acquire, even if the flag were to change while the chain runs.
        use_lock = self.llm_loader.streamer.for_huggingface
        if use_lock:
            self.llm_loader.lock.acquire()

        try:
            self.llm_loader.streamer.reset(q)

            chain = self.get_chain()
            if streaming_handler is not None:
                result = self._run_chain_with_streaming_handler(
                    chain, inputs, streaming_handler, testing
                )
            else:
                result = self.run_chain(chain, inputs)

            if "answer" in result:
                result["answer"] = remove_extra_spaces(result["answer"])

                base_url = os.environ.get("PDF_FILE_BASE_URL")
                if base_url:
                    self._attach_source_urls(result, base_url)

            return result
        finally:
            if use_lock:
                self.llm_loader.lock.release()

    @staticmethod
    def _attach_source_urls(result, base_url):
        """Set ``metadata["url"]`` on each source document in ``result``.

        The URL is ``base_url`` followed by the percent-quoted last path
        component of the document's ``metadata["source"]``. A result without
        ``"source_documents"`` is left unchanged.
        """
        for doc in result.get("source_documents") or []:
            file_name = doc.metadata["source"].split("/")[-1]
            doc.metadata["url"] = f"{base_url}{urllib.parse.quote(file_name)}"

    def _execute_chain(self, chain, inputs, q, sh):
        """Worker-thread body: put the chain result, or the raised exception, on ``q``."""
        try:
            q.put(self.run_chain(chain, inputs, callbacks=[sh]))
        except Exception as exc:
            # Hand the failure to the waiting caller; without this the caller
            # would block forever on an empty queue.
            q.put(exc)

    def _run_chain_with_streaming_handler(
        self, chain, inputs, streaming_handler, testing
    ):
        """Run the chain on a worker thread while forwarding streamed tokens.

        When ``llm_loader.streamer.for_huggingface`` is set, tokens are read
        by iterating the streamer and passed to
        ``streaming_handler.on_llm_new_token`` (skipped when ``testing``). The
        streamer is drained twice if ``inputs`` has a non-empty
        ``chat_history`` and once otherwise, with ``streamer.reset()`` after
        each pass.

        Returns:
            The chain's result.

        Raises:
            Whatever the chain raised on the worker thread.
        """
        results = Queue()

        worker = Thread(
            target=self._execute_chain,
            args=(chain, inputs, results, streaming_handler),
        )
        worker.start()

        streamer = self.llm_loader.streamer
        if streamer.for_huggingface:
            # NOTE(review): presumably the chain makes two LLM calls when there
            # is chat history (e.g. question condensing) - confirm.
            remaining = 2 if inputs.get("chat_history") else 1

            while remaining > 0:
                try:
                    for token in streamer:
                        if not testing:
                            streaming_handler.on_llm_new_token(token)

                    streamer.reset()
                    remaining -= 1
                except Exception:
                    if not testing:
                        print("nothing generated yet - retry in 0.5s")
                    # Wait up to 0.5s for the worker. Once it has exited no
                    # further tokens can arrive, so retrying would spin forever.
                    worker.join(timeout=0.5)
                    if not worker.is_alive():
                        break

        worker.join()
        result = results.get()
        if isinstance(result, Exception):
            raise result
        return result