import gc
import logging
import traceback
from pathlib import Path

import pandas as pd
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

from data_validation.data_validation import DataValidation
from data_cleaning.data_cleaning import DataCleaning
from model_inference.model_inference import predict
from utils.load_model import load_model

app = FastAPI()

# Loaded once at module import time and shared by every request handled by
# this process.
loaded_model = load_model(load_deployed_model=True, model_file_name=None)


class Data(BaseModel):
    """
    Request body schema for the prediction endpoint.

    Pydantic validates/coerces each field to its annotated type before the
    handler runs. The field order below is the column order of the one-row
    DataFrame built in ``prediction``.
    """

    age: int
    workclass: str
    fnlwgt: int
    education: str
    education_num: int
    marital_status: str
    occupation: str
    relationship: str
    race: str
    sex: str
    capital_gain: int
    capital_loss: int
    hours_per_week: int
    country: str


@app.post("/")
def prediction(data: Data):
    """
    Clean a single request record and return the model's predicted label.

    Args:
        data: The validated request body.

    Returns:
        ``{"result": label}`` where ``label`` is the first element returned by
        ``predict(...)`` with surrounding whitespace stripped (presumably a
        string class label -- confirm against ``model_inference.predict``).

    Raises:
        Any exception raised while building, cleaning or predicting is logged
        with its traceback and re-raised; unless an exception handler is
        registered elsewhere, FastAPI turns it into an HTTP 500 response.
    """
    logging.warning("entering prediction_api")
    try:
        # One request body becomes a single-row DataFrame.
        df = pd.DataFrame(data.dict(), index=[0])

        # NOTE(review): DataValidation is imported but schema validation is
        # currently disabled; re-enable it here if input checks are required.

        data_cleaning = DataCleaning()
        df = data_cleaning.clean_column_names(df).copy()
        df = data_cleaning.shorten_column_names(df).copy()
        df = data_cleaning.clean_nan(df).copy()

        pred = predict(
            loaded_model, df, predict_proba=False, predict_label=True
        )[0].strip()
    except Exception:
        # logging.exception logs at ERROR level and appends the traceback.
        logging.exception("unexpected error in prediction_api")
        raise
    finally:
        # Reclaim per-request allocations on both success and failure paths.
        gc.collect()

    logging.warning("exiting prediction_api")
    return {"result": pred}


if __name__ == '__main__':
    # uvicorn honours ``workers > 1`` only when the app is passed as a
    # "module:attribute" import string; given the app object it logs a
    # warning and exits. The module name is derived from this file so each
    # worker process can import it. Assumes the file name is a valid Python
    # module name and its directory is on sys.path (true when run as a script).
    uvicorn.run(
        app=f"{Path(__file__).stem}:app",
        host='0.0.0.0',
        port=7860,
        workers=3,
    )