Spaces:
Build error
Build error
Daniel Marques
committed on
Commit
·
70e00d3
1
Parent(s):
36d73c0
feat: v1
Browse files
- constants.py +2 -2
- main.py +22 -20
constants.py
CHANGED
@@ -32,12 +32,12 @@ CHROMA_SETTINGS = Settings(
|
|
32 |
)
|
33 |
|
34 |
# Context Window and Max New Tokens
|
35 |
-
CONTEXT_WINDOW_SIZE =
|
36 |
MAX_NEW_TOKENS = CONTEXT_WINDOW_SIZE # int(CONTEXT_WINDOW_SIZE/4)
|
37 |
|
38 |
#### If you get a "not enough space in the buffer" error, you should reduce the values below, start with half of the original values and keep halving the value until the error stops appearing
|
39 |
|
40 |
-
N_GPU_LAYERS =
|
41 |
N_BATCH = 2048
|
42 |
|
43 |
### From experimenting with the Llama-2-7B-Chat-GGML model on 8GB VRAM, these values work:
|
|
|
32 |
)
|
33 |
|
34 |
# Context Window and Max New Tokens
|
35 |
+
CONTEXT_WINDOW_SIZE = 2048
|
36 |
MAX_NEW_TOKENS = CONTEXT_WINDOW_SIZE # int(CONTEXT_WINDOW_SIZE/4)
|
37 |
|
38 |
#### If you get a "not enough space in the buffer" error, you should reduce the values below, start with half of the original values and keep halving the value until the error stops appearing
|
39 |
|
40 |
+
N_GPU_LAYERS = 83 # Llama-2-70B has 83 layers
|
41 |
N_BATCH = 2048
|
42 |
|
43 |
### From experimenting with the Llama-2-7B-Chat-GGML model on 8GB VRAM, these values work:
|
main.py
CHANGED
@@ -136,28 +136,33 @@ def delete_source_route(data: Delete):
|
|
136 |
raise HTTPException(status_code=400, detail=print(f"The file {file_to_delete} does not exist."))
|
137 |
|
138 |
@api_app.post('/predict')
|
139 |
-
|
140 |
global QA
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
144 |
|
145 |
-
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
prompt_response_dict["Sources"] = []
|
153 |
-
for document in docs:
|
154 |
-
prompt_response_dict["Sources"].append(
|
155 |
-
(os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
156 |
-
)
|
157 |
|
158 |
-
return {"response": prompt_response_dict}
|
159 |
-
else:
|
160 |
-
raise HTTPException(status_code=400, detail="Prompt Incorrect")
|
161 |
|
162 |
@api_app.post("/save_document/")
|
163 |
async def create_upload_file(file: UploadFile):
|
@@ -208,12 +213,9 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
|
|
208 |
try:
|
209 |
while True:
|
210 |
prompt = await websocket.receive_text()
|
211 |
-
|
212 |
response = QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
|
213 |
-
|
214 |
await websocket.send_text(f'{response}')
|
215 |
|
216 |
-
|
217 |
except WebSocketDisconnect:
|
218 |
print('disconnect')
|
219 |
except RuntimeError as error:
|
|
|
136 |
raise HTTPException(status_code=400, detail=print(f"The file {file_to_delete} does not exist."))
|
137 |
|
138 |
@api_app.post('/predict')
|
139 |
+
def predict(data: Predict):
|
140 |
global QA
|
141 |
+
try:
|
142 |
+
user_prompt = data.prompt
|
143 |
+
if user_prompt:
|
144 |
+
res = QA(user_prompt)
|
145 |
|
146 |
+
answer, docs = res["result"], res["source_documents"]
|
147 |
|
148 |
+
prompt_response_dict = {
|
149 |
+
"Prompt": user_prompt,
|
150 |
+
"Answer": answer,
|
151 |
+
}
|
152 |
+
|
153 |
+
prompt_response_dict["Sources"] = []
|
154 |
+
for document in docs:
|
155 |
+
prompt_response_dict["Sources"].append(
|
156 |
+
(os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
157 |
+
)
|
158 |
+
|
159 |
+
return {"response": prompt_response_dict}
|
160 |
+
else:
|
161 |
+
raise HTTPException(status_code=400, detail="Prompt Incorrect")
|
162 |
+
except Exception as e:
|
163 |
+
raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
|
164 |
|
|
|
|
|
|
|
|
|
|
|
165 |
|
|
|
|
|
|
|
166 |
|
167 |
@api_app.post("/save_document/")
|
168 |
async def create_upload_file(file: UploadFile):
|
|
|
213 |
try:
|
214 |
while True:
|
215 |
prompt = await websocket.receive_text()
|
|
|
216 |
response = QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
|
|
|
217 |
await websocket.send_text(f'{response}')
|
218 |
|
|
|
219 |
except WebSocketDisconnect:
|
220 |
print('disconnect')
|
221 |
except RuntimeError as error:
|