Daniel Marques committed on
Commit
66a4e8f
·
1 Parent(s): b606edb

feat: add stream

Browse files
Files changed (2) hide show
  1. main.py +2 -2
  2. run_localGPT.py +3 -7
main.py CHANGED
@@ -14,9 +14,9 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
14
  from langchain.prompts import PromptTemplate
15
  from langchain.memory import ConversationBufferMemory
16
 
 
17
  # from langchain.embeddings import HuggingFaceEmbeddings
18
  from run_localGPT import load_model
19
- from prompt_template_utils import get_prompt_template
20
 
21
  # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
22
  from langchain.vectorstores import Chroma
@@ -45,7 +45,7 @@ DB = Chroma(
45
 
46
  RETRIEVER = DB.as_retriever()
47
 
48
- LLM, StreamData = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
49
 
50
  template = """you are a helpful, respectful and honest assistant.
51
  Your name is Katara llma. You should only use the source documents provided to answer the questions.
 
14
  from langchain.prompts import PromptTemplate
15
  from langchain.memory import ConversationBufferMemory
16
 
17
+
18
  # from langchain.embeddings import HuggingFaceEmbeddings
19
  from run_localGPT import load_model
 
20
 
21
  # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
22
  from langchain.vectorstores import Chroma
 
45
 
46
  RETRIEVER = DB.as_retriever()
47
 
48
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
49
 
50
  template = """you are a helpful, respectful and honest assistant.
51
  Your name is Katara llma. You should only use the source documents provided to answer the questions.
run_localGPT.py CHANGED
@@ -10,8 +10,6 @@ from langchain.callbacks.manager import CallbackManager
10
 
11
  torch.set_grad_enabled(False)
12
 
13
- callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
14
-
15
  from prompt_template_utils import get_prompt_template
16
 
17
  from langchain.vectorstores import Chroma
@@ -38,7 +36,7 @@ from constants import (
38
 
39
 
40
 
41
- def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
42
  """
43
  Select a model for text generation using the HuggingFace library.
44
  If you are running this for the first time, it will download a model for you.
@@ -91,15 +89,13 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
91
  top_k=40,
92
  repetition_penalty=1.0,
93
  generation_config=generation_config,
94
- streamer=streamer,
95
- num_return_sequences=1,
96
- eos_token_id=tokenizer.eos_token_id
97
  )
98
 
99
  local_llm = HuggingFacePipeline(pipeline=pipe)
100
  logging.info("Local LLM Loaded")
101
 
102
- return (local_llm, streamer)
103
 
104
 
105
  def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
 
10
 
11
  torch.set_grad_enabled(False)
12
 
 
 
13
  from prompt_template_utils import get_prompt_template
14
 
15
  from langchain.vectorstores import Chroma
 
36
 
37
 
38
 
39
+ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False):
40
  """
41
  Select a model for text generation using the HuggingFace library.
42
  If you are running this for the first time, it will download a model for you.
 
89
  top_k=40,
90
  repetition_penalty=1.0,
91
  generation_config=generation_config,
92
+ callback=[StreamingStdOutCallbackHandler()]
 
 
93
  )
94
 
95
  local_llm = HuggingFacePipeline(pipeline=pipe)
96
  logging.info("Local LLM Loaded")
97
 
98
+ return local_llm
99
 
100
 
101
  def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):