dkdaniz committed on
Commit 3579c4a
1 Parent(s): 5ccd0e0

Update run_localGPT.py

Files changed (1)
  1. run_localGPT.py +212 -106
run_localGPT.py CHANGED
@@ -1,95 +1,163 @@
-import logging
 import os
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
-
 import click
 import torch
-from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 
-torch.cuda.empty_cache()
-torch.cuda.memory_summary(device=None, abbreviated=False)
 
 from constants import (
-    CHROMA_SETTINGS,
-    DOCUMENT_MAP,
     EMBEDDING_MODEL_NAME,
-    INGEST_THREADS,
     PERSIST_DIRECTORY,
-    SOURCE_DIRECTORY,
 )
 
 
-def load_single_document(file_path: str) -> Document:
-    # Loads a single document from a file path
-    file_extension = os.path.splitext(file_path)[1]
-    loader_class = DOCUMENT_MAP.get(file_extension)
-    if loader_class:
-        loader = loader_class(file_path)
-    else:
-        raise ValueError("Document type is undefined")
-    return loader.load()[0]
-
-
-def load_document_batch(filepaths):
-    logging.info("Loading document batch")
-    # create a thread pool
-    with ThreadPoolExecutor(len(filepaths)) as exe:
-        # load files
-        futures = [exe.submit(load_single_document, name) for name in filepaths]
-        # collect data
-        data_list = [future.result() for future in futures]
-        # return data and file paths
-        return (data_list, filepaths)
-
-
-def load_documents(source_dir: str) -> list[Document]:
-    # Loads all documents from the source documents directory, including nested folders
-    paths = []
-    for root, _, files in os.walk(source_dir):
-        for file_name in files:
-            file_extension = os.path.splitext(file_name)[1]
-            source_file_path = os.path.join(root, file_name)
-            if file_extension in DOCUMENT_MAP.keys():
-                paths.append(source_file_path)
-
-    # Have at least one worker and at most INGEST_THREADS workers
-    n_workers = min(INGEST_THREADS, max(len(paths), 1))
-    chunksize = round(len(paths) / n_workers)
-    docs = []
-    with ProcessPoolExecutor(n_workers) as executor:
-        futures = []
-        # split the load operations into chunks
-        for i in range(0, len(paths), chunksize):
-            # select a chunk of filenames
-            filepaths = paths[i : (i + chunksize)]
-            # submit the task
-            future = executor.submit(load_document_batch, filepaths)
-            futures.append(future)
-        # process all results
-        for future in as_completed(futures):
-            # open the file and load the data
-            contents, _ = future.result()
-            docs.extend(contents)
-
-    return docs
-
-
-def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
-    # Splits documents for correct Text Splitter
-    text_docs, python_docs = [], []
-    for doc in documents:
-        file_extension = os.path.splitext(doc.metadata["source"])[1]
-        if file_extension == ".py":
-            python_docs.append(doc)
         else:
-            text_docs.append(doc)
 
-    return text_docs, python_docs
 
 
 @click.command()
 @click.option(
     "--device_type",
@@ -119,40 +187,78 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
     ),
     help="Device to run on. (Default is cuda)",
 )
-def main(device_type):
-    # Load documents and split in chunks
-    logging.info(f"Loading documents from {SOURCE_DIRECTORY}")
-    documents = load_documents(SOURCE_DIRECTORY)
-    text_documents, python_documents = split_documents(documents)
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-    python_splitter = RecursiveCharacterTextSplitter.from_language(
-        language=Language.PYTHON, chunk_size=880, chunk_overlap=200
-    )
-    texts = text_splitter.split_documents(text_documents)
-    texts.extend(python_splitter.split_documents(python_documents))
-    logging.info(f"Loaded {len(documents)} documents from {SOURCE_DIRECTORY}")
-    logging.info(f"Split into {len(texts)} chunks of text")
-
-    # Create embeddings
-    embeddings = HuggingFaceInstructEmbeddings(
-        model_name=EMBEDDING_MODEL_NAME,
-        model_kwargs={"device": device_type},
-    )
-    # change the embedding type here if you are running into issues.
-    # These are much smaller embeddings and will work for most appications
-    # If you use HuggingFaceEmbeddings, make sure to also use the same in the
-    # run_localGPT.py file.
 
-    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
-    db = Chroma.from_documents(
-        texts,
-        embeddings,
-        persist_directory=PERSIST_DIRECTORY,
-        client_settings=CHROMA_SETTINGS,
 
-    )
-
 
 
 if __name__ == "__main__":
 
 
 import os
+import logging
 import click
 import torch
+from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
+from langchain.llms import HuggingFacePipeline
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
+from langchain.callbacks.manager import CallbackManager
+
+torch.set_grad_enabled(False)
+
+callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+
+from prompt_template_utils import get_prompt_template
+
+# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
+from transformers import (
+    GenerationConfig,
+    pipeline,
+)
 
+from load_models import (
+    load_quantized_model_gguf_ggml,
+    load_quantized_model_qptq,
+    load_full_model,
+)
 
 from constants import (
     EMBEDDING_MODEL_NAME,
     PERSIST_DIRECTORY,
+    MODEL_ID,
+    MODEL_BASENAME,
+    MAX_NEW_TOKENS,
+    MODELS_PATH,
 )
 
 
+def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
+    """
+    Select a model for text generation using the HuggingFace library.
+    If you are running this for the first time, it will download a model for you.
+    Subsequent runs will use the model from the disk.
+
+    Args:
+        device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
+        model_id (str): Identifier of the model to load from HuggingFace's model hub.
+        model_basename (str, optional): Basename of the model if using quantized models.
+            Defaults to None.
+
+    Returns:
+        HuggingFacePipeline: A pipeline object for text generation using the loaded model.
+
+    Raises:
+        ValueError: If an unsupported model or device type is provided.
+    """
+    logging.info(f"Loading Model: {model_id}, on: {device_type}")
+    logging.info("This action can take a few minutes!")
+
+    if model_basename is not None:
+        if ".gguf" in model_basename.lower():
+            llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
+            return llm
+        elif ".ggml" in model_basename.lower():
+            model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
         else:
+            model, tokenizer = load_quantized_model_qptq(model_id, model_basename, device_type, LOGGING)
+    else:
+        model, tokenizer = load_full_model(model_id, model_basename, device_type, LOGGING)
+
+    # Load configuration from the model to avoid warnings
+    generation_config = GenerationConfig.from_pretrained(model_id)
+    # see here for details:
+    # https://huggingface.co/docs/transformers/
+    # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
+
+    # Create a pipeline for text generation
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_length=50,
+        temperature=0.2,
+        # top_p=0.95,
+        repetition_penalty=1.15,
+        generation_config=generation_config,
+    )
+
+    local_llm = HuggingFacePipeline(pipeline=pipe)
+    logging.info("Local LLM Loaded")
+
+    return local_llm
+
+
+def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
+    """
+    Initializes and returns a retrieval-based Question Answering (QA) pipeline.
+
+    This function sets up a QA system that retrieves relevant information using embeddings
+    from the HuggingFace library. It then answers questions based on the retrieved information.
+
+    Parameters:
+    - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'cuda', etc.
+    - use_history (bool): Flag to determine whether to use chat history or not.
+
+    Returns:
+    - RetrievalQA: An initialized retrieval-based QA system.
+
+    Notes:
+    - The function uses embeddings from the HuggingFace library, either instruction-based or regular.
+    - The Chroma class is used to load a vector store containing pre-computed embeddings.
+    - The retriever fetches relevant documents or data based on a query.
+    - The prompt and memory, obtained from the `get_prompt_template` function, might be used in the QA system.
+    - The model is loaded onto the specified device using its ID and basename.
+    - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
+    """
+
+    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
+    # uncomment the following line if you used HuggingFaceEmbeddings in ingest.py
+    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
+
+    # load the vectorstore
+    db = Chroma(
+        persist_directory=PERSIST_DIRECTORY,
+        embedding_function=embeddings,
+    )
+    retriever = db.as_retriever()
+
+    # get the prompt template and memory if set by the user.
+    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)
+
+    # load the llm pipeline
+    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)
+
+    if use_history:
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chain types as well: refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={"prompt": prompt, "memory": memory},
+        )
+    else:
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chain types as well: refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={
+                "prompt": prompt,
+            },
+        )
 
+    return qa
 
 
+# choose which device type to run on, as well as whether to show source documents.
 @click.command()
 @click.option(
     "--device_type",
 
     ),
     help="Device to run on. (Default is cuda)",
 )
+@click.option(
+    "--show_sources",
+    "-s",
+    is_flag=True,
+    help="Show sources along with answers (Default is False)",
+)
+@click.option(
+    "--use_history",
+    "-h",
+    is_flag=True,
+    help="Use history (Default is False)",
+)
+@click.option(
+    "--model_type",
+    default="llama",
+    type=click.Choice(
+        ["llama", "mistral", "non_llama"],
+    ),
+    help="Model type: llama, mistral or non_llama",
+)
+def main(device_type, show_sources, use_history, model_type):
+    """
+    Implements the main information retrieval task for a localGPT.
 
+    This function sets up the QA system by loading the necessary embeddings, vectorstore, and LLM model.
+    It then enters an interactive loop where the user can input queries and receive answers. Optionally,
+    the source documents used to derive the answers can also be displayed.
 
+    Parameters:
+    - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'mps', 'cuda', etc.
+    - show_sources (bool): Flag to determine whether to display the source documents used for answering.
+    - use_history (bool): Flag to determine whether to use chat history or not.
 
+    Notes:
+    - Logging information includes the device type, whether source documents are displayed, and the use of history.
+    - If the models directory does not exist, it creates a new one to store models.
+    - The user can exit the interactive loop by entering "exit".
+    - The source documents are displayed if the show_sources flag is set to True.
+
+    """
+
+    logging.info(f"Running on: {device_type}")
+    logging.info(f"Display Source Documents set to: {show_sources}")
+    logging.info(f"Use history set to: {use_history}")
+
+    # if the models directory does not exist, create it to store the downloaded models
+    if not os.path.exists(MODELS_PATH):
+        os.mkdir(MODELS_PATH)
+
+    qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type=model_type)
+    # Interactive questions and answers
+    while True:
+        query = input("\nEnter a query: ")
+        if query == "exit":
+            break
+        # Get the answer from the chain
+        res = qa(query)
+        answer, docs = res["result"], res["source_documents"]
+
+        # Print the result
+        print("\n\n> Question:")
+        print(query)
+        print("\n> Answer:")
+        print(answer)
+
+        if show_sources:  # flag that controls whether the source documents are printed
+            # Print the relevant sources used for the answer
+            print("----------------------------------SOURCE DOCUMENTS---------------------------")
+            for document in docs:
+                print("\n> " + document.metadata["source"] + ":")
+                print(document.page_content)
+            print("----------------------------------SOURCE DOCUMENTS---------------------------")
 
 
 if __name__ == "__main__":
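
For reference, a minimal usage sketch of the updated script, mirroring what the new main() does without the Click CLI. The import path, device string, and query below are illustrative assumptions; retrieval_qa_pipline, the result keys, and the CLI flags come from the diff above, and the sketch assumes ingest.py has already populated the Chroma store in PERSIST_DIRECTORY and that constants.py, load_models.py and prompt_template_utils.py are importable.

import logging
from run_localGPT import retrieval_qa_pipline

logging.basicConfig(level=logging.INFO)

# Roughly equivalent CLI call:
#   python run_localGPT.py --device_type cuda --use_history --model_type llama --show_sources
qa = retrieval_qa_pipline("cuda", use_history=True, promptTemplate_type="llama")

res = qa("What does this repository do?")   # hypothetical query
print(res["result"])                        # the generated answer
for doc in res["source_documents"]:         # retrieved chunks used as context
    print(doc.metadata["source"])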