Daniel Marques committed
Commit e72e226
1 Parent(s): dc8d635

fix: add callback

Files changed (2)
  1. load_models.py +3 -3
  2. main.py +0 -3
load_models.py CHANGED
@@ -3,6 +3,7 @@ import logging
 from auto_gptq import AutoGPTQForCausalLM
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp, HuggingFacePipeline
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 from transformers import (
     AutoModelForCausalLM,
@@ -204,8 +205,6 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
 
     streamer = TextStreamer(tokenizer, skip_prompt=True)
 
-    logging.info(streamer)
-
     pipe = pipeline(
         "text-generation",
         model=model,
@@ -217,6 +216,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
         repetition_penalty=1.0,
         generation_config=generation_config,
-        streamer=streamer
+        streamer=streamer,
+        callbacks=[StreamingStdOutCallbackHandler()]
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
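For context, a minimal sketch of how these two streaming mechanisms typically fit together; this is not the repo's exact code, and it assumes a langchain 0.0.x API with the tiny "gpt2" checkpoint and the prompt below as stand-ins. transformers' TextStreamer already prints tokens to stdout as they are decoded, while LangChain's StreamingStdOutCallbackHandler only receives tokens through LangChain's own callback system, so it is conventionally attached to the LangChain wrapper rather than to transformers' pipeline():

    # Minimal sketch, not the repo's code. Assumptions: langchain 0.0.x,
    # "gpt2" as a small stand-in for whatever model load_model() resolves.
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline
    from langchain.llms import HuggingFacePipeline
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

    model_id = "gpt2"  # stand-in model id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # transformers-side streaming: TextStreamer writes each decoded token to stdout.
    streamer = TextStreamer(tokenizer, skip_prompt=True)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=64,
        streamer=streamer,
    )

    # LangChain-side streaming: callbacks hang off the LLM wrapper and only
    # fire when the wrapper emits callback events, not on every decoded token.
    local_llm = HuggingFacePipeline(
        pipeline=pipe,
        callbacks=[StreamingStdOutCallbackHandler()],
    )

    print(local_llm("Explain token streaming in one sentence."))

Under those assumptions, the live stdout output during generation comes from TextStreamer; the callback handler only becomes active once the surrounding code routes token events through LangChain, for example via a streaming-capable LLM wrapper or chain.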
main.py CHANGED
@@ -179,9 +179,6 @@ async def predict(data: Predict):
                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
             )
 
-
-
-
         return {"response": prompt_response_dict}
     else:
         raise HTTPException(status_code=400, detail="Prompt Incorrect")