in_pr_test / api.py
manivannan023's picture
Update api.py
694d708 verified
from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel
import pandas as pd
from data_validation.data_validation import DataValidation
from data_cleaning.data_cleaning import DataCleaning
from model_inference.model_inference import predict
import traceback
import gc
from utils.load_model import load_model
import logging
app = FastAPI()
loaded_model = load_model(load_deployed_model=True, model_file_name=None)
class Data(BaseModel):
"""
Data dictionary for data type validation
"""
age: int
workclass: str
fnlwgt: int
education: str
education_num: int
marital_status: str
occupation: str
relationship: str
race: str
sex: str
capital_gain: int
capital_loss: int
hours_per_week: int
country: str
@app.post("/")
def prediction(data: Data):
"""
Processes the API request and returns a prediction
"""
logging.warning("entering prediction_api")
try:
df = pd.DataFrame(data.dict(), index=[0]) # converting api data dict to df
# dv = DataValidation(input_df=df, dataset="prediction") # validating the data
# validation_status = dv.validate_data() # status of validation. 1=passed, 0=failed
# validation_status = 1
# if validation_status != 0:
data_cleaning = DataCleaning()
# cleaning the data
df = data_cleaning.clean_column_names(df).copy()
df = data_cleaning.shorten_column_names(df).copy()
df = data_cleaning.clean_nan(df).copy()
# calling the 'model_inference.model_inference.predict' function
pred = predict(loaded_model, df, predict_proba=False, predict_label=True)[0].strip()
# else:
# # executes when data validation fails
# pred = "data validation failed"
except Exception as e:
# executes in case of any exception
pred = e
logging.warning(f"unexpected error in prediction_api: {traceback.format_exc()}")
raise
gc.collect()
logging.warning("exiting prediction_api")
return {"result": pred}
if __name__ == '__main__':
uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)