Daniel Marques committed
Commit 198843f
1 Parent(s): 8fa0233

fix: add streamer

Files changed (2)
  1. load_models.py +31 -1
  2. main.py +5 -8
load_models.py CHANGED
@@ -1,9 +1,15 @@
 import torch
+import asyncio
 import logging
+from typing import Any, Dict, List
+
 from auto_gptq import AutoGPTQForCausalLM
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp, HuggingFacePipeline
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.schema import LLMResult
+
+from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 
 from transformers import (
     AutoModelForCausalLM,
@@ -22,6 +28,29 @@ torch.set_grad_enabled(False)
 from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
 
 
+class MyCustomSyncHandler(BaseCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        print(f"Sync handler being called in a `thread_pool_executor`: token: {token}")
+
+class MyCustomAsyncHandler(AsyncCallbackHandler):
+    """Async callback handler that can be used to handle callbacks from langchain."""
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+        print("zzzz....")
+        await asyncio.sleep(0.3)
+        class_name = serialized["name"]
+        print("Hi! I just woke up. Your llm is starting")
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when chain ends running."""
+        print("zzzz....")
+        await asyncio.sleep(0.3)
+        print("Hi! I just woke up. Your llm is ending")
+
+
 def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging, stream = False):
     """
     Load a GGUF/GGML quantized model using LlamaCpp.
@@ -66,6 +95,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, loggin
 
         #add stream
        kwargs["stream"] = stream
+        kwargs["callbacks"] = [MyCustomSyncHandler(), MyCustomAsyncHandler()]
 
         return LlamaCpp(**kwargs)
     except:
@@ -220,7 +250,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
             repetition_penalty=1.0,
             generation_config=generation_config,
             streamer=streamer,
-            callbacks=[StreamingStdOutCallbackHandler()]
+            callbacks=[MyCustomSyncHandler(), MyCustomAsyncHandler()]
         )
 
         local_llm = HuggingFacePipeline(pipeline=pipe)
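For reference, a minimal sketch of how the two new handlers behave. This driver is illustrative only (not part of the commit); the prompt, tokens, and LLMResult below are made-up values, and it assumes load_models.py and its dependencies are importable. LangChain calls on_llm_new_token once per generated token when streaming is enabled, and the async hooks fire at the start and end of a generation.

import asyncio

from langchain.schema import Generation, LLMResult
from load_models import MyCustomAsyncHandler, MyCustomSyncHandler

# Exercise the sync handler the way LangChain would during a streamed
# generation: one call per token.
sync_handler = MyCustomSyncHandler()
for token in ["The", " answer", " is", " 42", "."]:
    sync_handler.on_llm_new_token(token)

async def main():
    async_handler = MyCustomAsyncHandler()
    # LangChain passes a serialized description of the LLM; only "name" is read above.
    await async_handler.on_llm_start({"name": "LlamaCpp"}, ["What is the answer?"])
    await async_handler.on_llm_end(
        LLMResult(generations=[[Generation(text="The answer is 42.")]])
    )

asyncio.run(main())

Running this just prints the handler messages to stdout, which is also all the LlamaCpp and pipeline callbacks wired in above do during generation.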
main.py CHANGED
@@ -42,10 +42,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-models = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
-
-LLM = models[0]
-STREAMER = models[1]
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
 
 template = """you are a helpful, respectful and honest assistant. You should only use the source documents provided to answer the questions.
 You should only respond only topics that contains in documents use to training.
@@ -182,10 +179,10 @@ async def predict(data: Predict):
                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
             )
 
-        generated_text = ""
-        for new_text in STREAMER:
-            generated_text += new_text
-            print(generated_text)
+        # generated_text = ""
+        # for new_text in STREAMER:
+        #     generated_text += new_text
+        #     print(generated_text)
 
         return {"response": prompt_response_dict}
     else:
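With the STREAMER loop commented out, token-level output from /predict now surfaces only on stdout, through the callbacks attached in load_models.py. If the tokens should eventually reach the HTTP client instead, one possible follow-up (hypothetical, not in this commit) is a callback that forwards tokens into an asyncio.Queue drained by a FastAPI StreamingResponse. In the sketch below, app and LLM refer to the existing objects in main.py, while QueueTokenHandler and the /stream route are invented names for illustration.

import asyncio

from fastapi.responses import StreamingResponse
from langchain.callbacks.base import BaseCallbackHandler


class QueueTokenHandler(BaseCallbackHandler):
    """Forward each generated token to an asyncio.Queue owned by the event loop."""

    def __init__(self, queue: asyncio.Queue, loop: asyncio.AbstractEventLoop) -> None:
        self.queue = queue
        self.loop = loop

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Generation runs in a worker thread, so hand tokens to the loop thread-safely.
        self.loop.call_soon_threadsafe(self.queue.put_nowait, token)

    def on_llm_end(self, response, **kwargs) -> None:
        self.loop.call_soon_threadsafe(self.queue.put_nowait, None)  # end-of-stream marker


@app.get("/stream")
async def stream(prompt: str):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    handler = QueueTokenHandler(queue, loop)

    # LLM is the object main.py builds with load_model(..., stream=True);
    # run the blocking call off the event loop and let the callback feed the queue.
    loop.run_in_executor(None, lambda: LLM(prompt, callbacks=[handler]))

    async def token_source():
        while True:
            token = await queue.get()
            if token is None:
                break
            yield token

    return StreamingResponse(token_source(), media_type="text/plain")

This keeps the synchronous LlamaCpp/pipeline call off the event loop while the per-token callback streams output to the client as it is produced.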