fixed bug for lc_serve

Files changed:
- app_modules/init.py       +2 -2
- app_modules/llm_loader.py +4 -2
- server.py                 +1 -1

app_modules/init.py CHANGED
@@ -23,7 +23,7 @@ load_dotenv(found_dotenv, override=False)
 init_settings()
 
 
-def app_init():
+def app_init(lc_serve: bool = False):
     # https://github.com/huggingface/transformers/issues/17611
     os.environ["CURL_CA_BUNDLE"] = ""
 
@@ -69,7 +69,7 @@ def app_init():
     print(f"Completed in {end - start:.3f}s")
 
     start = timer()
-    llm_loader = LLMLoader(llm_model_type)
+    llm_loader = LLMLoader(llm_model_type, lc_serve)
     llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
     qa_chain = QAChain(vectorstore, llm_loader)
     end = timer()
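Since lc_serve defaults to False, existing callers of app_init() keep the old behaviour; only the langchain-serve path opts in. A minimal sketch of the two call paths (app_init and its return value are from this commit; the comments are illustrative):

from app_modules.init import app_init

# Default path (the existing apps): lc_serve stays False and LLMLoader
# sets up its streamer as before.
llm_loader, qa_chain = app_init()

# langchain-serve path: the flag is forwarded to LLMLoader(llm_model_type, lc_serve),
# which skips the streamer (see app_modules/llm_loader.py below).
llm_loader, qa_chain = app_init(lc_serve=True)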
app_modules/llm_loader.py CHANGED

@@ -90,10 +90,12 @@ class LLMLoader:
     streamer: any
     max_tokens_limit: int
 
-    def __init__(self, llm_model_type, max_tokens_limit: int = 2048):
+    def __init__(
+        self, llm_model_type, max_tokens_limit: int = 2048, lc_serve: bool = False
+    ):
         self.llm_model_type = llm_model_type
         self.llm = None
-        self.streamer = None
+        self.streamer = None if lc_serve else TextIteratorStreamer("")
         self.max_tokens_limit = max_tokens_limit
         self.search_kwargs = {"k": 4}
 
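With lc_serve=True the loader now leaves streamer as None; otherwise it eagerly constructs a transformers TextIteratorStreamer (here with an empty string standing in for the tokenizer argument the class normally takes). For reference, a minimal sketch of the usual TextIteratorStreamer pattern that the non-lc-serve path relies on; the model and tokenizer names are hypothetical, not from this repo:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# generate() blocks, so it runs in a background thread while the caller
# iterates over the streamer, which yields decoded text as tokens arrive.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer("Hello", return_tensors="pt")

thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20),
)
thread.start()
for text in streamer:
    print(text, end="", flush=True)
thread.join()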
server.py CHANGED

@@ -11,7 +11,7 @@ from app_modules.init import app_init
 from app_modules.llm_chat_chain import ChatChain
 from app_modules.utils import print_llm_response
 
-llm_loader, qa_chain = app_init()
+llm_loader, qa_chain = app_init(True)
 
 chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
 
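server.py is the langchain-serve entry point, so it opts in with app_init(True); the streamer is skipped there, presumably because lc-serve streams responses through its own machinery rather than through the transformers streamer. A hypothetical sketch of an endpoint built on top of this (the @serving decorator comes from the lcserve package; the ask endpoint and the call_chain signature are assumptions, not taken from this diff):

from lcserve import serving

from app_modules.init import app_init

llm_loader, qa_chain = app_init(True)  # lc_serve=True: streamer stays None

@serving
def ask(question: str) -> str:
    # Chain interface assumed for illustration only.
    result = qa_chain.call_chain({"question": question, "chat_history": []}, None)
    return result["answer"]

With langchain-serve installed, a module like this would typically be launched with something along the lines of lc-serve deploy local server (CLI form assumed from the langchain-serve docs).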