import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces

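# LLMInference is an abstract base that pairs an LLMLoader with a cached LangChain
# Chain and drives both streaming and non-streaming question answering.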
class LLMInference(metaclass=abc.ABCMeta):
    llm_loader: LLMLoader
    chain: Chain

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

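    # Lazily build the chain on first use and cache it; when tracing is requested,
    # a LangChainTracer default session is loaded as a side effect before creation.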
    def get_chain(self, tracing: bool = False) -> Chain:
        if self.chain is None:
            if tracing:
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

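    # call_chain is the public entry point: it resets the token streamer, runs the
    # chain (streaming when a callback handler is given, blocking otherwise), then
    # post-processes the answer and source-document metadata.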
    def call_chain(
        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
    ):
        print(inputs)

        # Point the streamer at the caller's queue before starting a new generation.
        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            self.llm_loader.streamer.reset(q)

        chain = self.get_chain(tracing)
        result = (
            self._run_qa_chain(
                chain,
                inputs,
                streaming_handler,
            )
            if streaming_handler is not None
            else chain(inputs)
        )

        if "answer" in result:
            result["answer"] = remove_extra_spaces(result["answer"])

        # If PDF_FILE_BASE_URL is set, attach a download URL to each source document.
        base_url = os.environ.get("PDF_FILE_BASE_URL")
        if base_url is not None and len(base_url) > 0:
            documents = result["source_documents"]
            for doc in documents:
                source = doc.metadata["source"]
                title = source.split("/")[-1]
                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

        return result

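    # _run_qa_chain runs the chain on a worker thread while this thread pumps tokens
    # from the TextIteratorStreamer into the streaming callback handler.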
    def _run_qa_chain(self, qa, inputs, streaming_handler):
        que = Queue()

        t = Thread(
            target=lambda qa, inputs, q, sh: q.put(qa(inputs, callbacks=[sh])),
            args=(qa, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            # With chat history the chain presumably triggers two LLM generations
            # (question condensing, then answering), so the streamer is drained twice.
            count = 2 if len(inputs.get("chat_history", [])) > 0 else 1

            while count > 0:
                try:
                    for token in self.llm_loader.streamer:
                        streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()
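
# A minimal subclass sketch (hypothetical; the concrete chain depends on the app):
#
#     from langchain.chains import ConversationalRetrievalChain
#
#     class QAChain(LLMInference):
#         def __init__(self, vectorstore, llm_loader):
#             super().__init__(llm_loader)
#             self.vectorstore = vectorstore
#
#         def create_chain(self) -> Chain:
#             # llm_loader.llm is assumed to expose the loaded LangChain LLM.
#             return ConversationalRetrievalChain.from_llm(
#                 self.llm_loader.llm, self.vectorstore.as_retriever()
#             )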
|