# app_modules/llm_inference.py
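"""Base class for LLM inference: lazily builds a LangChain chain and runs it,
optionally streaming generated tokens through a TextIteratorStreamer."""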
import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces
class LLMInference(metaclass=abc.ABCMeta):
    llm_loader: LLMLoader
    chain: Chain

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self, tracing: bool = False) -> Chain:
        # Build the chain lazily on first use and cache it.
        if self.chain is None:
            if tracing:
                # Set up a LangChain tracing session before building the chain.
                tracer = LangChainTracer()
                tracer.load_default_session()
            self.chain = self.create_chain()
        return self.chain
    def call_chain(
        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
    ):
        print(inputs)
        # Point the streamer at the caller's queue so generated tokens can be
        # consumed outside this thread.
        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            self.llm_loader.streamer.reset(q)

        chain = self.get_chain(tracing)
        # Stream via a background thread when a handler is supplied;
        # otherwise run the chain synchronously.
        result = (
            self._run_qa_chain(chain, inputs, streaming_handler)
            if streaming_handler is not None
            else chain(inputs)
        )

        if "answer" in result:
            result["answer"] = remove_extra_spaces(result["answer"])

        # If a base URL is configured, attach a URL to each source document
        # so the UI can link back to the original PDF.
        base_url = os.environ.get("PDF_FILE_BASE_URL")
        if base_url is not None and len(base_url) > 0:
            documents = result["source_documents"]
            for doc in documents:
                source = doc.metadata["source"]
                title = source.split("/")[-1]
                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

        return result
    def _run_qa_chain(self, qa, inputs, streaming_handler):
        que = Queue()
        # Run the chain in a background thread and deliver its result via a
        # queue, so this thread can relay streamed tokens in the meantime.
        t = Thread(
            target=lambda qa, inputs, q, sh: q.put(qa(inputs, callbacks=[sh])),
            args=(qa, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            # With chat history the chain typically makes two LLM calls
            # (question condensing plus answering), so drain the streamer twice.
            count = 2 if inputs.get("chat_history") else 1

            while count > 0:
                try:
                    for token in self.llm_loader.streamer:
                        streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()