File size: 2,182 Bytes
4a58702
 
 
 
 
 
 
 
ea17396
ad49b5e
4f23d5f
625d083
4a58702
 
ad49b5e
4a58702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f23d5f
4a58702
4f23d5f
e563596
 
625d083
 
694d708
 
 
 
 
4a58702
4f23d5f
4a58702
625d083
 
 
4a58702
 
 
 
4f23d5f
4a58702
ea17396
4f23d5f
4a58702
 
 
 
ea17396
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel
import pandas as pd
from data_validation.data_validation import DataValidation
from data_cleaning.data_cleaning import DataCleaning
from model_inference.model_inference import predict
import traceback
import gc
from utils.load_model import load_model
import logging


# ASGI application object; the endpoint below is registered on it via @app.post.
app = FastAPI()
# The deployed model is loaded once, at module import, and shared by every request.
loaded_model = load_model(model_file_name=None, load_deployed_model=True)

class Data(BaseModel):
    """
    Request body schema for the prediction endpoint (one record per request).

    pydantic validates each field against its annotated type when the request
    body is parsed. Field declaration order matters: ``prediction`` builds its
    DataFrame from ``data.dict()``, so the order below becomes the column order
    handed to the cleaning steps and then to the model.
    """
    # NOTE(review): this field set looks like the UCI "Adult" census features;
    # confirm names/order against the schema the model was trained on.
    age: int
    workclass: str
    fnlwgt: int
    education: str
    education_num: int
    marital_status: str
    occupation: str
    relationship: str
    race: str
    sex: str
    capital_gain: int
    capital_loss: int
    hours_per_week: int
    country: str


@app.post("/")
def prediction(data: Data):
    """
    Score one record posted to the API and return the model's prediction.

    The parsed request body is converted to a single-row DataFrame, passed
    through the ``DataCleaning`` steps (column-name cleaning, column-name
    shortening, NaN cleaning) and scored with
    ``model_inference.model_inference.predict``.

    :param data: request payload, already type-validated by pydantic.
    :return: ``{"result": pred}`` where ``pred`` is the first element returned
        by ``predict`` with surrounding whitespace stripped.
    :raises Exception: any error raised while cleaning or predicting is logged
        with its traceback and re-raised, so FastAPI answers with a server error.
    """
    # warning level is used so these trace lines show up under logging's default
    # (WARNING) root configuration
    logging.warning("entering prediction_api")
    try:
        # single-row DataFrame; columns follow the field order of ``Data``
        df = pd.DataFrame(data.dict(), index=[0])

        # NOTE: schema validation (DataValidation(input_df=df, dataset="prediction"))
        # is currently disabled; re-enable it here to reject invalid records
        # before cleaning and inference.

        data_cleaning = DataCleaning()
        df = data_cleaning.clean_column_names(df).copy()
        df = data_cleaning.shorten_column_names(df).copy()
        df = data_cleaning.clean_nan(df).copy()

        predictions = predict(loaded_model, df, predict_proba=False, predict_label=True)
        # one input row -> take the first prediction; presumably a label string
        pred = predictions[0].strip()
    except Exception:
        # logging.exception records the active traceback; re-raise so the caller
        # sees the failure instead of a bogus result
        logging.exception("unexpected error in prediction_api")
        raise
    finally:
        # runs on both success and failure so per-request garbage is always
        # collected and the exit trace is always logged
        gc.collect()
        logging.warning("exiting prediction_api")
    return {"result": pred}


if __name__ == '__main__':
    # uvicorn only supports workers > 1 (and reload) when the application is
    # passed as a "module:attribute" import string; passing the app object
    # itself makes uvicorn.run() log a warning and exit without serving.
    # Each worker process then imports this module on its own.
    from pathlib import Path

    # __spec__ is set when launched as ``python -m pkg.module``; for
    # ``python file.py`` it is None and the file name is the importable module.
    module_name = __spec__.name if __spec__ is not None else Path(__file__).stem
    uvicorn.run(app=f"{module_name}:app", host='0.0.0.0', port=7860, workers=3)