Daniel Marques committed
Commit · 66a4e8f · 1 Parent(s): b606edb

feat: add stream

- main.py +2 -2
- run_localGPT.py +3 -7
main.py
CHANGED
@@ -14,9 +14,9 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
 
+
 # from langchain.embeddings import HuggingFaceEmbeddings
 from run_localGPT import load_model
-from prompt_template_utils import get_prompt_template
 
 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
@@ -45,7 +45,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-LLM
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
 
 template = """you are a helpful, respectful and honest assistant.
 Your name is Katara llma. You should only use the source documents provided to answer the questions.
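The net effect in main.py is that the module-level LLM is now built through load_model with the new stream flag. As a minimal sketch, flipping the flag would enable streaming; note that stream=True is an assumption here, since this commit only wires up stream=False:

# Hypothetical usage of the updated signature. Only stream=False is
# actually passed in this commit; stream=True is an illustrative assumption.
LLM = load_model(
    device_type=DEVICE_TYPE,
    model_id=MODEL_ID,
    model_basename=MODEL_BASENAME,
    stream=True,
)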
run_localGPT.py
CHANGED
@@ -10,8 +10,6 @@ from langchain.callbacks.manager import CallbackManager
 
 torch.set_grad_enabled(False)
 
-callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
 from prompt_template_utils import get_prompt_template
 
 from langchain.vectorstores import Chroma
@@ -38,7 +36,7 @@ from constants import (
 
 
 
-def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
+def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False):
     """
     Select a model for text generation using the HuggingFace library.
     If you are running this for the first time, it will download a model for you.
@@ -91,15 +89,13 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
         top_k=40,
         repetition_penalty=1.0,
         generation_config=generation_config,
-
-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id
+        callback=[StreamingStdOutCallbackHandler()]
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
     logging.info("Local LLM Loaded")
 
-    return
+    return local_llm
 
 
 def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
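For context: the committed change passes callback=[StreamingStdOutCallbackHandler()] into transformers' pipeline(), but callback is a LangChain callback-handler idiom, not a transformers pipeline keyword, and the new stream parameter is not referenced inside the hunks shown. A minimal sketch of how streaming is usually gated on such a flag with transformers' own TextStreamer follows; the function name load_model_sketch and the "gpt2" placeholder model are illustrative assumptions, not part of the commit:

import logging

from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline


def load_model_sketch(model_id="gpt2", stream=False):
    # "gpt2" is a small placeholder; the real code resolves MODEL_ID/MODEL_BASENAME.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # TextStreamer writes tokens to stdout as they are generated;
    # skip_prompt=True suppresses the echoed input prompt.
    streamer = TextStreamer(tokenizer, skip_prompt=True) if stream else None

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        streamer=streamer,  # gate streaming on the flag
    )

    local_llm = HuggingFacePipeline(pipeline=pipe)
    logging.info("Local LLM Loaded")
    return local_llm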