abhishekrs4 commited on
Commit
6e19446
1 Parent(s): 71bf83e

updated fastapi app script

Browse files
Files changed (1) hide show
  1. app.py +42 -7
app.py CHANGED
@@ -10,15 +10,16 @@ from pydantic import BaseModel
10
  from config import settings
11
 
12
  try:
13
- path_mlflow_model = "trained_models/knn_random_forest"
14
  sklearn_pipeline = mlflow.sklearn.load_model(path_mlflow_model)
15
  except:
16
- path_mlflow_model = "/data/models/knn_random_forest"
17
  sklearn_pipeline = mlflow.sklearn.load_model(path_mlflow_model)
18
 
19
  app = FastAPI()
20
  logging.basicConfig(level=logging.INFO)
21
 
 
22
  class WaterPotabilityDataItem(BaseModel):
23
  ph: Union[float, None] = np.nan
24
  Hardness: Union[float, None] = np.nan
@@ -30,20 +31,53 @@ class WaterPotabilityDataItem(BaseModel):
30
  Trihalomethanes: Union[float, None] = np.nan
31
  Turbidity: Union[float, None] = np.nan
32
 
 
33
  def predict_pipeline(data_sample):
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  pred_sample = sklearn_pipeline.predict(data_sample)
35
  return pred_sample
36
 
 
37
  @app.get("/info")
38
  def get_app_info():
39
- dict_info = {
40
- "app_name": settings.app_name,
41
- "version": settings.version
42
- }
 
 
 
 
43
  return dict_info
44
 
 
45
  @app.post("/predict")
46
  def predict(wpd_item: WaterPotabilityDataItem):
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  wpd_arr = np.array(
48
  [
49
  wpd_item.ph,
@@ -60,4 +94,5 @@ def predict(wpd_item: WaterPotabilityDataItem):
60
  logging.info("data sample: %s", wpd_arr)
61
  pred_sample = predict_pipeline(wpd_arr)
62
  logging.info("Potability prediction: %s", pred_sample)
63
- return {"Potability": int(pred_sample)}
 
 
10
  from config import settings
11
 
12
  try:
13
+ path_mlflow_model = "./model_for_production/"
14
  sklearn_pipeline = mlflow.sklearn.load_model(path_mlflow_model)
15
  except:
16
+ path_mlflow_model = "/data/model_for_production/"
17
  sklearn_pipeline = mlflow.sklearn.load_model(path_mlflow_model)
18
 
19
  app = FastAPI()
20
  logging.basicConfig(level=logging.INFO)
21
 
22
+
23
  class WaterPotabilityDataItem(BaseModel):
24
  ph: Union[float, None] = np.nan
25
  Hardness: Union[float, None] = np.nan
 
31
  Trihalomethanes: Union[float, None] = np.nan
32
  Turbidity: Union[float, None] = np.nan
33
 
34
+
35
  def predict_pipeline(data_sample):
36
+ """
37
+ ---------
38
+ Arguments
39
+ ---------
40
+ data_sample : np.array
41
+ a numpy array of shape (num_samples, num_feats)
42
+
43
+ -------
44
+ Returns
45
+ -------
46
+ pred_sample : np.array
47
+ a numpy array of shape (num_samples) with predictions
48
+ """
49
  pred_sample = sklearn_pipeline.predict(data_sample)
50
  return pred_sample
51
 
52
+
53
  @app.get("/info")
54
  def get_app_info():
55
+ """
56
+ -------
57
+ Returns
58
+ -------
59
+ dict_info : dict
60
+ a dictionary with info to be sent as a response to get request
61
+ """
62
+ dict_info = {"app_name": settings.app_name, "version": settings.version}
63
  return dict_info
64
 
65
+
66
  @app.post("/predict")
67
  def predict(wpd_item: WaterPotabilityDataItem):
68
+ """
69
+ ---------
70
+ Arguments
71
+ ---------
72
+ wpd_item : object
73
+ an object of type WaterPotabilityDataItem
74
+
75
+ -------
76
+ Returns
77
+ -------
78
+ pred_dict : dict
79
+ a dictionary of prediction to be sent as a response to post request
80
+ """
81
  wpd_arr = np.array(
82
  [
83
  wpd_item.ph,
 
94
  logging.info("data sample: %s", wpd_arr)
95
  pred_sample = predict_pipeline(wpd_arr)
96
  logging.info("Potability prediction: %s", pred_sample)
97
+ pred_dict = {"Potability": int(pred_sample)}
98
+ return pred_dict