WIP

- app_modules/llm_inference.py  +96 -0
- app_modules/llm_loader.py  +1 -0
- app_modules/llm_qa_chain.py  +23 -0
app_modules/llm_inference.py  ADDED
@@ -0,0 +1,96 @@

import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import *
from app_modules.utils import *


class LLMInference(metaclass=abc.ABCMeta):
    """Abstract base for chain runners; subclasses supply the chain via create_chain()."""

    llm_loader: LLMLoader
    chain: Chain

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self, tracing: bool = False) -> Chain:
        # Create the chain lazily on first use.
        if self.chain is None:
            if tracing:
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

    def call_chain(
        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
    ):
        print(inputs)

        # Point the token streamer at the caller's queue before the run starts.
        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            self.llm_loader.streamer.reset(q)

        chain = self.get_chain(tracing)
        result = (
            self._run_qa_chain(
                chain,
                inputs,
                streaming_handler,
            )
            if streaming_handler is not None
            else chain(inputs)
        )

        result["answer"] = remove_extra_spaces(result["answer"])

        # Optionally rewrite source-document paths into downloadable URLs.
        base_url = os.environ.get("PDF_FILE_BASE_URL")
        if base_url is not None and len(base_url) > 0:
            documents = result["source_documents"]
            for doc in documents:
                source = doc.metadata["source"]
                title = source.split("/")[-1]
                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

        return result

    def _run_qa_chain(self, qa, inputs, streaming_handler):
        que = Queue()

        # Run the chain in a worker thread so tokens can be consumed here.
        t = Thread(
            target=lambda qa, inputs, q, sh: q.put(qa(inputs, callbacks=[sh])),
            args=(qa, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            # With chat history the chain makes two LLM calls (question
            # condensing, then answering), so the streamer is drained twice.
            count = 2 if len(inputs.get("chat_history")) > 0 else 1

            while count > 0:
                try:
                    for token in self.llm_loader.streamer:
                        streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()
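The streaming path above only assumes a handler that implements on_llm_new_token. A minimal sketch of such a handler, built on LangChain's BaseCallbackHandler interface; the class name and queue-forwarding behavior are illustrative, not part of this commit:

from queue import Queue

from langchain.callbacks.base import BaseCallbackHandler


class QueueStreamingHandler(BaseCallbackHandler):
    """Hypothetical handler: forwards each generated token to a queue a frontend can poll."""

    def __init__(self, q: Queue):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.q.put(token)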
app_modules/llm_loader.py  CHANGED
@@ -88,6 +88,7 @@ class LLMLoader:
     llm_model_type: str
     llm: any
     streamer: any
+    max_tokens_limit: int

     def __init__(self, llm_model_type):
         self.llm_model_type = llm_model_type
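The new max_tokens_limit field is consumed by QAChain below, where ConversationalRetrievalChain uses it to trim retrieved documents so the prompt fits the model's context window. A caller might set it like this (the model type string and the value 2048 are illustrative, not from this commit):

llm_loader = LLMLoader("huggingface")  # hypothetical model type
llm_loader.max_tokens_limit = 2048     # leave headroom for question + answer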
app_modules/llm_qa_chain.py  ADDED
@@ -0,0 +1,23 @@

from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.vectorstores.base import VectorStore

from app_modules.llm_inference import LLMInference


class QAChain(LLMInference):
    vectorstore: VectorStore

    def __init__(self, vectorstore, llm_loader):
        super().__init__(llm_loader)
        self.vectorstore = vectorstore

    def create_chain(self) -> Chain:
        qa = ConversationalRetrievalChain.from_llm(
            self.llm_loader.llm,
            self.vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
            max_tokens_limit=self.llm_loader.max_tokens_limit,
            return_source_documents=True,
        )

        return qa
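Putting the pieces together, a caller might drive the chain like this. The vectorstore construction is not shown in this commit; the input keys follow ConversationalRetrievalChain's standard interface ("question" and "chat_history"):

qa_chain = QAChain(vectorstore, llm_loader)

result = qa_chain.call_chain(
    {"question": "What is attention?", "chat_history": []},
    streaming_handler=None,  # non-streaming path: chain(inputs) is called directly
)
print(result["answer"])
for doc in result["source_documents"]:
    print(doc.metadata.get("url", doc.metadata["source"]))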