inflaton committed on
Commit
328b268
1 Parent(s): 81a80b7
app_modules/llm_inference.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import *
from app_modules.utils import *
13
+
14
+
15
class LLMInference(metaclass=abc.ABCMeta):
    """Abstract base for LLM inference pipelines.

    Subclasses implement :meth:`create_chain` to build a LangChain
    ``Chain``; this base class handles lazy chain construction,
    invocation, optional token streaming, and result post-processing.
    """

    llm_loader: LLMLoader  # supplies the underlying LLM and its streamer
    chain: Chain  # lazily built on first use, then cached

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        """Build and return the LangChain chain this instance runs."""
        pass

    def get_chain(self, tracing: bool = False) -> Chain:
        """Return the cached chain, creating it on first call.

        When *tracing* is true, a ``LangChainTracer`` default session is
        loaded before the chain is created.
        """
        if self.chain is None:
            if tracing:
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

    def call_chain(
        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
    ):
        """Run the chain on *inputs*, optionally streaming tokens.

        When *streaming_handler* is provided, the chain runs in a worker
        thread and generated tokens are relayed to the handler; otherwise
        the chain is invoked synchronously.  The answer is
        whitespace-normalized, and when the ``PDF_FILE_BASE_URL``
        environment variable is set, each source document receives a
        ``url`` metadata entry built from its ``source`` filename.
        """
        print(inputs)

        # Point the token streamer at the caller-supplied queue (if any)
        # before generation starts.
        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            self.llm_loader.streamer.reset(q)

        chain = self.get_chain(tracing)
        result = (
            self._run_qa_chain(chain, inputs, streaming_handler)
            if streaming_handler is not None
            else chain(inputs)
        )

        result["answer"] = remove_extra_spaces(result["answer"])

        base_url = os.environ.get("PDF_FILE_BASE_URL")
        if base_url:  # treat unset and empty string the same way
            for doc in result["source_documents"]:
                title = doc.metadata["source"].split("/")[-1]
                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

        return result

    def _run_qa_chain(self, qa, inputs, streaming_handler):
        """Run *qa* in a worker thread while relaying streamed tokens.

        Blocks until generation finishes and returns the chain's result.
        """
        que = Queue()

        t = Thread(
            target=lambda qa, inputs, q, sh: q.put(qa(inputs, callbacks=[sh])),
            args=(qa, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            # With chat history the conversational chain makes two LLM
            # calls (question condensation + answer), so two streaming
            # passes are expected.  Bug fix: the original did
            # ``len(inputs.get("chat_history"))`` which raises TypeError
            # when the key is absent; a plain truthiness check handles
            # missing, None, and empty history identically.
            count = 2 if inputs.get("chat_history") else 1

            while count > 0:
                try:
                    for token in self.llm_loader.streamer:
                        streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    # Streamer not populated yet — NOTE(review): this
                    # retries indefinitely if generation never starts.
                    print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()
app_modules/llm_loader.py CHANGED
@@ -88,6 +88,7 @@ class LLMLoader:
88
  llm_model_type: str
89
  llm: any
90
  streamer: any
 
91
 
92
  def __init__(self, llm_model_type):
93
  self.llm_model_type = llm_model_type
 
88
  llm_model_type: str
89
  llm: any
90
  streamer: any
91
+ max_tokens_limit: int
92
 
93
  def __init__(self, llm_model_type):
94
  self.llm_model_type = llm_model_type
app_modules/llm_qa_chain.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import ConversationalRetrievalChain
2
+ from langchain.chains.base import Chain
3
+ from langchain.vectorstores.base import VectorStore
4
+
5
+ from app_modules.llm_inference import LLMInference
6
+
7
+
8
class QAChain(LLMInference):
    """Conversational retrieval question-answering chain backed by a
    vector store.
    """

    vectorstore: VectorStore  # document store queried for relevant context

    def __init__(self, vectorstore, llm_loader):
        # Bug fixes vs. the original:
        #  1. ``super.__init__(llm_loader)`` was missing the call
        #     parentheses and raised AttributeError on every
        #     construction; it must be ``super().__init__(...)``.
        #  2. ``llm_loader: int = 2048`` was a bogus annotation/default —
        #     an LLMLoader instance is expected here (2048 looks like a
        #     stray max_tokens_limit value mis-merged into the
        #     signature; that limit is read from the loader below).
        super().__init__(llm_loader)
        self.vectorstore = vectorstore

    def create_chain(self) -> Chain:
        """Build a ConversationalRetrievalChain over the vector store,
        returning source documents with each answer.
        """
        qa = ConversationalRetrievalChain.from_llm(
            self.llm_loader.llm,
            self.vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
            max_tokens_limit=self.llm_loader.max_tokens_limit,
            return_source_documents=True,
        )

        return qa