dkdaniz committed
Commit 0770449
1 Parent(s): b34b7d7

Update main.py

Files changed (1):
  1. main.py +83 -5
main.py CHANGED
@@ -1,16 +1,94 @@
  from fastapi import FastAPI
  import pickle
  import uvicorn
- import pandas as pd

- app = FastAPI()
+ import logging
+ import os
+ import shutil
+ import subprocess
+
+ import torch
+ from fastapi import HTTPException
+ from pydantic import BaseModel
+ from langchain.chains import RetrievalQA
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+
+ # from langchain.embeddings import HuggingFaceEmbeddings
+ from run_localGPT import load_model
+ from prompt_template_utils import get_prompt_template
+
+ # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.vectorstores import Chroma
+ from werkzeug.utils import secure_filename
+
+ from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME
+
+ # Pick the best available torch device: Apple Metal (MPS), CUDA GPU, or CPU.
+ if torch.backends.mps.is_available():
+     DEVICE_TYPE = "mps"
+ elif torch.cuda.is_available():
+     DEVICE_TYPE = "cuda"
+ else:
+     DEVICE_TYPE = "cpu"
+
+ SHOW_SOURCES = True
+ logging.info(f"Running on: {DEVICE_TYPE}")
+ logging.info(f"Display Source Documents set to: {SHOW_SOURCES}")
+
+ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
+
+ # Load the persisted vectorstore and expose it as a retriever.
+ DB = Chroma(
+     persist_directory=PERSIST_DIRECTORY,
+     embedding_function=EMBEDDINGS,
+     client_settings=CHROMA_SETTINGS,
+ )
+ RETRIEVER = DB.as_retriever()
+
+ LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
+ prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)
+
+ # Build the question-answering chain over the retriever.
+ QA = RetrievalQA.from_chain_type(
+     llm=LLM,
+     chain_type="stuff",
+     retriever=RETRIEVER,
+     return_source_documents=SHOW_SOURCES,
+     chain_type_kwargs={
+         "prompt": prompt,
+     },
+ )
+
+
+ # Request body for the /predict endpoint.
+ class Predict(BaseModel):
+     prompt: str
+
+
+ app = FastAPI()

  @app.get("/")
  def root():
      return {"API": "An API for Sepsis Prediction."}


- @app.get('/Predict')
- async def predict():
-     return 'output_pred = "Sepsis status is Negative"'
+ @app.post('/predict')
+ async def predict(data: Predict):
+     user_prompt = data.prompt
+     if user_prompt:
+         # Get the answer and its supporting documents from the chain.
+         res = QA(user_prompt)
+         answer, docs = res["result"], res["source_documents"]
+
+         prompt_response_dict = {
+             "Prompt": user_prompt,
+             "Answer": answer,
+         }
+
+         prompt_response_dict["Sources"] = []
+         for document in docs:
+             prompt_response_dict["Sources"].append(
+                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
+             )
+
+         # FastAPI serializes the dict to JSON with a 200 response automatically.
+         return prompt_response_dict
+     else:
+         raise HTTPException(status_code=400, detail="No user prompt received")