dh-mc committed on
Commit
30bf870
1 Parent(s): 4359eb6

fixed bug for lc_serve

Browse files
Files changed (3) hide show
  1. app_modules/init.py +2 -2
  2. app_modules/llm_loader.py +4 -2
  3. server.py +1 -1
app_modules/init.py CHANGED
@@ -23,7 +23,7 @@ load_dotenv(found_dotenv, override=False)
23
  init_settings()
24
 
25
 
26
- def app_init():
27
  # https://github.com/huggingface/transformers/issues/17611
28
  os.environ["CURL_CA_BUNDLE"] = ""
29
 
@@ -69,7 +69,7 @@ def app_init():
69
  print(f"Completed in {end - start:.3f}s")
70
 
71
  start = timer()
72
- llm_loader = LLMLoader(llm_model_type)
73
  llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
74
  qa_chain = QAChain(vectorstore, llm_loader)
75
  end = timer()
 
23
  init_settings()
24
 
25
 
26
+ def app_init(lc_serve: bool = False):
27
  # https://github.com/huggingface/transformers/issues/17611
28
  os.environ["CURL_CA_BUNDLE"] = ""
29
 
 
69
  print(f"Completed in {end - start:.3f}s")
70
 
71
  start = timer()
72
+ llm_loader = LLMLoader(llm_model_type, lc_serve)
73
  llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
74
  qa_chain = QAChain(vectorstore, llm_loader)
75
  end = timer()
app_modules/llm_loader.py CHANGED
@@ -90,10 +90,12 @@ class LLMLoader:
90
  streamer: any
91
  max_tokens_limit: int
92
 
93
- def __init__(self, llm_model_type, max_tokens_limit: int = 2048):
 
 
94
  self.llm_model_type = llm_model_type
95
  self.llm = None
96
- self.streamer = None
97
  self.max_tokens_limit = max_tokens_limit
98
  self.search_kwargs = {"k": 4}
99
 
 
90
  streamer: any
91
  max_tokens_limit: int
92
 
93
+ def __init__(
94
+ self, llm_model_type, max_tokens_limit: int = 2048, lc_serve: bool = False
95
+ ):
96
  self.llm_model_type = llm_model_type
97
  self.llm = None
98
+ self.streamer = None if lc_serve else TextIteratorStreamer("")
99
  self.max_tokens_limit = max_tokens_limit
100
  self.search_kwargs = {"k": 4}
101
 
server.py CHANGED
@@ -11,7 +11,7 @@ from app_modules.init import app_init
11
  from app_modules.llm_chat_chain import ChatChain
12
  from app_modules.utils import print_llm_response
13
 
14
- llm_loader, qa_chain = app_init()
15
 
16
  chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
17
 
 
11
  from app_modules.llm_chat_chain import ChatChain
12
  from app_modules.utils import print_llm_response
13
 
14
+ llm_loader, qa_chain = app_init(True)
15
 
16
  chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
17