Daniel Marques committed on
Commit 1d1dd8d • 1 Parent(s): 48e8fbb

fix: add trupple

Files changed (5)
  1. README.md +1 -2
  2. load_models.py +73 -1
  3. main.py +4 -5
  4. run_localGPT.py +0 -273
  5. run_localGPT_API.py +0 -184
README.md CHANGED
@@ -5,5 +5,4 @@ sdk: docker
 emoji: 🚀
 colorFrom: yellow
 colorTo: yellow
-# duplicated_from: radames/llama-cpp-python-cuda-gradio
----
+---
load_models.py CHANGED
@@ -1,14 +1,23 @@
 import torch
+import logging
 from auto_gptq import AutoGPTQForCausalLM
 from huggingface_hub import hf_hub_download
-from langchain.llms import LlamaCpp
+from langchain.llms import LlamaCpp, HuggingFacePipeline
 
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     LlamaForCausalLM,
     LlamaTokenizer,
+    GenerationConfig,
+    pipeline,
+    TextStreamer
 )
+
+
+torch.set_grad_enabled(False)
+
+
 from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
 
 
@@ -149,3 +158,66 @@ def load_full_model(model_id, model_basename, device_type, logging):
     )
     model.tie_weights()
     return model, tokenizer
+
+
+def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False):
+    """
+    Select a model for text generation using the HuggingFace library.
+    If you are running this for the first time, it will download a model for you.
+    subsequent runs will use the model from the disk.
+
+    Args:
+        device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
+        model_id (str): Identifier of the model to load from HuggingFace's model hub.
+        model_basename (str, optional): Basename of the model if using quantized models.
+            Defaults to None.
+
+    Returns:
+        HuggingFacePipeline: A pipeline object for text generation using the loaded model.
+
+    Raises:
+        ValueError: If an unsupported model or device type is provided.
+    """
+
+    logging.info(f"Loading Model: {model_id}, on: {device_type}")
+    logging.info("This action can take a few minutes!")
+
+    if model_basename is not None:
+        if ".gguf" in model_basename.lower():
+            llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
+            return llm
+        elif ".ggml" in model_basename.lower():
+            model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
+        else:
+            model, tokenizer = load_quantized_model_qptq(model_id, model_basename, device_type, LOGGING)
+    else:
+        model, tokenizer = load_full_model(model_id, model_basename, device_type, LOGGING)
+
+    # Load configuration from the model to avoid warnings
+    generation_config = GenerationConfig.from_pretrained(model_id)
+    # see here for details:
+    # https://huggingface.co/docs/transformers/
+    # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
+
+    # Create a pipeline for text generation
+
+
+    streamer = TextStreamer(tokenizer, skip_prompt=True)
+
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_length=50,
+        temperature=0.15,
+        top_p=0.1,
+        top_k=40,
+        repetition_penalty=1.0,
+        generation_config=generation_config,
+        streamer=streamer
+    )
+
+    local_llm = HuggingFacePipeline(pipeline=pipe)
+    logging.info("Local LLM Loaded")
+
+    return local_llm, streamer
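For context, a minimal usage sketch of the new load_model in load_models.py, which now hands back a (llm, streamer) tuple rather than a bare pipeline (presumably the "trupple" the commit message refers to). This sketch is not part of the commit; it assumes MODEL_ID and MODEL_BASENAME come from constants.py, as they do in main.py.

# Hypothetical usage sketch, not part of this commit.
from load_models import load_model
from constants import MODEL_ID, MODEL_BASENAME

# load_model now returns a tuple: the LangChain HuggingFacePipeline wrapper and the TextStreamer.
# Note: for ".gguf" basenames the function still returns a single LlamaCpp object instead of a tuple.
llm, streamer = load_model(
    device_type="cuda",  # main.py hard-codes "cuda"; use "cpu" or "mps" if no NVIDIA GPU is available
    model_id=MODEL_ID,
    model_basename=MODEL_BASENAME,
    stream=False,
)

# The streamer echoes generated tokens to stdout while the call runs.
print(llm("Hello, who are you?"))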
main.py CHANGED
@@ -16,7 +16,7 @@ from langchain.memory import ConversationBufferMemory
 
 
 # from langchain.embeddings import HuggingFaceEmbeddings
-from run_localGPT import load_model
+from load_models import load_model
 
 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
@@ -31,7 +31,6 @@ from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY,
 # DEVICE_TYPE = "cpu"
 
 DEVICE_TYPE = "cuda"
-
 SHOW_SOURCES = True
 
 EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
@@ -45,10 +44,10 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-LLM, STREAMER = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
+models = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
+LLM, STREAMER = models
 
-template = """you are a helpful, respectful and honest assistant.
-Your name is Katara llma. You should only use the source documents provided to answer the questions.
+template = """Your name is Katara and you are a helpful, respectful and honest assistant. You should only use the source documents provided to answer the questions.
 You should only respond only topics that contains in documents use to training.
 Use the following pieces of context to answer the question at the end.
 Always answer in the most helpful and safe way possible.
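For reference, a minimal sketch of how the template, RETRIEVER, and the LLM half of the tuple are typically wired into a chain, mirroring the RetrievalQA construction in the deleted run_localGPT.py and run_localGPT_API.py below. The PromptTemplate input variables and the context/question scaffold appended to the template are assumptions; main.py's actual chain construction is outside this hunk.

# Hypothetical wiring sketch; assumes the LangChain APIs already used elsewhere in this repo.
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=template + "\n{context}\nQuestion: {question}\nHelpful Answer:",
)

QA = RetrievalQA.from_chain_type(
    llm=LLM,  # first element of the tuple returned by load_model
    chain_type="stuff",
    retriever=RETRIEVER,
    return_source_documents=SHOW_SOURCES,
    chain_type_kwargs={"prompt": prompt},
)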
run_localGPT.py DELETED
@@ -1,273 +0,0 @@
-import os
-import logging
-import click
-import torch
-from langchain.chains import RetrievalQA
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.llms import HuggingFacePipeline
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
-from langchain.callbacks.manager import CallbackManager
-
-torch.set_grad_enabled(False)
-
-from prompt_template_utils import get_prompt_template
-
-from langchain.vectorstores import Chroma
-from transformers import (
-    GenerationConfig,
-    pipeline,
-    TextStreamer
-)
-
-from load_models import (
-    load_quantized_model_gguf_ggml,
-    load_quantized_model_qptq,
-    load_full_model,
-)
-
-from constants import (
-    EMBEDDING_MODEL_NAME,
-    PERSIST_DIRECTORY,
-    MODEL_ID,
-    MODEL_BASENAME,
-    MAX_NEW_TOKENS,
-    MODELS_PATH,
-)
-
-
-def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False):
-    """
-    Select a model for text generation using the HuggingFace library.
-    If you are running this for the first time, it will download a model for you.
-    subsequent runs will use the model from the disk.
-
-    Args:
-        device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
-        model_id (str): Identifier of the model to load from HuggingFace's model hub.
-        model_basename (str, optional): Basename of the model if using quantized models.
-            Defaults to None.
-
-    Returns:
-        HuggingFacePipeline: A pipeline object for text generation using the loaded model.
-
-    Raises:
-        ValueError: If an unsupported model or device type is provided.
-    """
-
-    logging.info(f"Loading Model: {model_id}, on: {device_type}")
-    logging.info("This action can take a few minutes!")
-
-    if model_basename is not None:
-        if ".gguf" in model_basename.lower():
-            llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
-            return llm
-        elif ".ggml" in model_basename.lower():
-            model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING)
-        else:
-            model, tokenizer = load_quantized_model_qptq(model_id, model_basename, device_type, LOGGING)
-    else:
-        model, tokenizer = load_full_model(model_id, model_basename, device_type, LOGGING)
-
-    # Load configuration from the model to avoid warnings
-    generation_config = GenerationConfig.from_pretrained(model_id)
-    # see here for details:
-    # https://huggingface.co/docs/transformers/
-    # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns
-
-    # Create a pipeline for text generation
-
-
-    streamer = TextStreamer(tokenizer, skip_prompt=True)
-
-    pipe = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_length=50,
-        temperature=0.15,
-        top_p=0.1,
-        top_k=40,
-        repetition_penalty=1.0,
-        generation_config=generation_config,
-        streamer=streamer
-    )
-
-    local_llm = HuggingFacePipeline(pipeline=pipe)
-    logging.info("Local LLM Loaded")
-
-    return local_llm, streamer
-
-
-def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
-    """
-    Initializes and returns a retrieval-based Question Answering (QA) pipeline.
-
-    This function sets up a QA system that retrieves relevant information using embeddings
-    from the HuggingFace library. It then answers questions based on the retrieved information.
-
-    Parameters:
-    - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'cuda', etc.
-    - use_history (bool): Flag to determine whether to use chat history or not.
-
-    Returns:
-    - RetrievalQA: An initialized retrieval-based QA system.
-
-    Notes:
-    - The function uses embeddings from the HuggingFace library, either instruction-based or regular.
-    - The Chroma class is used to load a vector store containing pre-computed embeddings.
-    - The retriever fetches relevant documents or data based on a query.
-    - The prompt and memory, obtained from the `get_prompt_template` function, might be used in the QA system.
-    - The model is loaded onto the specified device using its ID and basename.
-    - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
-    """
-
-    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
-    # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
-    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
-
-    # load the vectorstore
-    db = Chroma(
-        persist_directory=PERSIST_DIRECTORY,
-        embedding_function=embeddings,
-    )
-    retriever = db.as_retriever()
-
-    # get the prompt template and memory if set by the user.
-    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)
-
-    # load the llm pipeline
-    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)
-
-    if use_history:
-        qa = RetrievalQA.from_chain_type(
-            llm=llm,
-            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
-            retriever=retriever,
-            return_source_documents=True,  # verbose=True,
-            callbacks=callback_manager,
-            chain_type_kwargs={"prompt": prompt, "memory": memory},
-        )
-    else:
-        qa = RetrievalQA.from_chain_type(
-            llm=llm,
-            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
-            retriever=retriever,
-            return_source_documents=True,  # verbose=True,
-            callbacks=callback_manager,
-            chain_type_kwargs={
-                "prompt": prompt,
-            },
-        )
-
-    return qa
-
-
-# chose device typ to run on as well as to show source documents.
-@click.command()
-@click.option(
-    "--device_type",
-    default="cuda" if torch.cuda.is_available() else "cpu",
-    type=click.Choice(
-        [
-            "cpu",
-            "cuda",
-            "ipu",
-            "xpu",
-            "mkldnn",
-            "opengl",
-            "opencl",
-            "ideep",
-            "hip",
-            "ve",
-            "fpga",
-            "ort",
-            "xla",
-            "lazy",
-            "vulkan",
-            "mps",
-            "meta",
-            "hpu",
-            "mtia",
-        ],
-    ),
-    help="Device to run on. (Default is cuda)",
-)
-@click.option(
-    "--show_sources",
-    "-s",
-    is_flag=True,
-    help="Show sources along with answers (Default is False)",
-)
-@click.option(
-    "--use_history",
-    "-h",
-    is_flag=True,
-    help="Use history (Default is False)",
-)
-@click.option(
-    "--model_type",
-    default="llama",
-    type=click.Choice(
-        ["llama", "mistral", "non_llama"],
-    ),
-    help="model type, llama, mistral or non_llama",
-)
-def main(device_type, show_sources, use_history, model_type):
-    """
-    Implements the main information retrieval task for a localGPT.
-
-    This function sets up the QA system by loading the necessary embeddings, vectorstore, and LLM model.
-    It then enters an interactive loop where the user can input queries and receive answers. Optionally,
-    the source documents used to derive the answers can also be displayed.
-
-    Parameters:
-    - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'mps', 'cuda', etc.
-    - show_sources (bool): Flag to determine whether to display the source documents used for answering.
-    - use_history (bool): Flag to determine whether to use chat history or not.
-
-    Notes:
-    - Logging information includes the device type, whether source documents are displayed, and the use of history.
-    - If the models directory does not exist, it creates a new one to store models.
-    - The user can exit the interactive loop by entering "exit".
-    - The source documents are displayed if the show_sources flag is set to True.
-
-    """
-
-    logging.info(f"Running on: {device_type}")
-    logging.info(f"Display Source Documents set to: {show_sources}")
-    logging.info(f"Use history set to: {use_history}")
-
-    # check if models directory do not exist, create a new one and store models here.
-    if not os.path.exists(MODELS_PATH):
-        os.mkdir(MODELS_PATH)
-
-    qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type=model_type)
-    # Interactive questions and answers
-    while True:
-        query = input("\nEnter a query: ")
-        if query == "exit":
-            break
-        # Get the answer from the chain
-        res = qa(query)
-        answer, docs = res["result"], res["source_documents"]
-
-        # Print the result
-        print("\n\n> Question:")
-        print(query)
-        print("\n> Answer:")
-        print(answer)
-
-        if show_sources:  # this is a flag that you can set to disable showing answers.
-            # # Print the relevant sources used for the answer
-            print("----------------------------------SOURCE DOCUMENTS---------------------------")
-            for document in docs:
-                print("\n> " + document.metadata["source"] + ":")
-                print(document.page_content)
-            print("----------------------------------SOURCE DOCUMENTS---------------------------")
-
-
-if __name__ == "__main__":
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
-    )
-    main()
run_localGPT_API.py DELETED
@@ -1,184 +0,0 @@
-import logging
-import os
-import shutil
-import subprocess
-
-import torch
-from flask import Flask, jsonify, request
-from langchain.chains import RetrievalQA
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-
-# from langchain.embeddings import HuggingFaceEmbeddings
-from run_localGPT import load_model
-from prompt_template_utils import get_prompt_template
-
-# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.vectorstores import Chroma
-from werkzeug.utils import secure_filename
-
-from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME
-
-if torch.backends.mps.is_available():
-    DEVICE_TYPE = "mps"
-elif torch.cuda.is_available():
-    DEVICE_TYPE = "cuda"
-else:
-    DEVICE_TYPE = "cpu"
-
-SHOW_SOURCES = True
-logging.info(f"Running on: {DEVICE_TYPE}")
-logging.info(f"Display Source Documents set to: {SHOW_SOURCES}")
-
-EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
-
-# uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
-# EMBEDDINGS = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
-# if os.path.exists(PERSIST_DIRECTORY):
-#     try:
-#         shutil.rmtree(PERSIST_DIRECTORY)
-#     except OSError as e:
-#         print(f"Error: {e.filename} - {e.strerror}.")
-# else:
-#     print("The directory does not exist")
-
-# run_langest_commands = ["python", "ingest.py"]
-# if DEVICE_TYPE == "cpu":
-#     run_langest_commands.append("--device_type")
-#     run_langest_commands.append(DEVICE_TYPE)
-
-# result = subprocess.run(run_langest_commands, capture_output=True)
-# if result.returncode != 0:
-#     raise FileNotFoundError(
-#         "No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!"
-#     )
-
-# load the vectorstore
-DB = Chroma(
-    persist_directory=PERSIST_DIRECTORY,
-    embedding_function=EMBEDDINGS,
-    client_settings=CHROMA_SETTINGS,
-)
-
-RETRIEVER = DB.as_retriever()
-
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
-prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
-
-QA = RetrievalQA.from_chain_type(
-    llm=LLM,
-    chain_type="stuff",
-    retriever=RETRIEVER,
-    return_source_documents=SHOW_SOURCES,
-    chain_type_kwargs={
-        "prompt": prompt,
-    },
-)
-
-app = Flask(__name__)
-
-
-@app.route("/api/delete_source", methods=["GET"])
-def delete_source_route():
-    folder_name = "SOURCE_DOCUMENTS"
-
-    if os.path.exists(folder_name):
-        shutil.rmtree(folder_name)
-
-    os.makedirs(folder_name)
-
-    return jsonify({"message": f"Folder '{folder_name}' successfully deleted and recreated."})
-
-
-@app.route("/api/save_document", methods=["GET", "POST"])
-def save_document_route():
-    if "document" not in request.files:
-        return "No document part", 400
-    file = request.files["document"]
-    if file.filename == "":
-        return "No selected file", 400
-    if file:
-        filename = secure_filename(file.filename)
-        folder_path = "SOURCE_DOCUMENTS"
-        if not os.path.exists(folder_path):
-            os.makedirs(folder_path)
-        file_path = os.path.join(folder_path, filename)
-        file.save(file_path)
-        return "File saved successfully", 200
-
-
-@app.route("/api/run_ingest", methods=["GET"])
-def run_ingest_route():
-    global DB
-    global RETRIEVER
-    global QA
-    try:
-        if os.path.exists(PERSIST_DIRECTORY):
-            try:
-                shutil.rmtree(PERSIST_DIRECTORY)
-            except OSError as e:
-                print(f"Error: {e.filename} - {e.strerror}.")
-        else:
-            print("The directory does not exist")
-
-        run_langest_commands = ["python", "ingest.py"]
-        if DEVICE_TYPE == "cpu":
-            run_langest_commands.append("--device_type")
-            run_langest_commands.append(DEVICE_TYPE)
-
-        result = subprocess.run(run_langest_commands, capture_output=True)
-        if result.returncode != 0:
-            return "Script execution failed: {}".format(result.stderr.decode("utf-8")), 500
-        # load the vectorstore
-        DB = Chroma(
-            persist_directory=PERSIST_DIRECTORY,
-            embedding_function=EMBEDDINGS,
-            client_settings=CHROMA_SETTINGS,
-        )
-        RETRIEVER = DB.as_retriever()
-        prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
-
-        QA = RetrievalQA.from_chain_type(
-            llm=LLM,
-            chain_type="stuff",
-            retriever=RETRIEVER,
-            return_source_documents=SHOW_SOURCES,
-            chain_type_kwargs={
-                "prompt": prompt,
-            },
-        )
-        return "Script executed successfully: {}".format(result.stdout.decode("utf-8")), 200
-    except Exception as e:
-        return f"Error occurred: {str(e)}", 500
-
-
-@app.route("/api/prompt_route", methods=["GET", "POST"])
-def prompt_route():
-    global QA
-    user_prompt = request.form.get("user_prompt")
-    if user_prompt:
-        # print(f'User Prompt: {user_prompt}')
-        # Get the answer from the chain
-        res = QA(user_prompt)
-        answer, docs = res["result"], res["source_documents"]
-
-        prompt_response_dict = {
-            "Prompt": user_prompt,
-            "Answer": answer,
-        }
-
-        prompt_response_dict["Sources"] = []
-        for document in docs:
-            prompt_response_dict["Sources"].append(
-                (os.path.basename(str(document.metadata["source"])), str(document.page_content))
-            )
-
-        return jsonify(prompt_response_dict), 200
-    else:
-        return "No user prompt received", 400
-
-
-if __name__ == "__main__":
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
-    )
-    app.run(debug=False, port=5110)
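Since run_localGPT_API.py is deleted in this commit, here is a minimal client sketch, for anyone migrating off it, of how its /api/prompt_route endpoint was called. The port (5110), form field (user_prompt), and the Prompt/Answer/Sources response keys are taken from the deleted code above; the requests dependency is an assumption.

# Hypothetical client for the now-removed Flask API, shown only as a migration reference.
import requests

resp = requests.post(
    "http://localhost:5110/api/prompt_route",
    data={"user_prompt": "What do the source documents cover?"},
)
payload = resp.json()
print(payload["Answer"])                  # the model's answer
for source_name, content in payload["Sources"]:
    print(source_name, content[:80])      # source file name and a snippet of the matched text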