churn-model / handler.py
philschmid's picture
philschmid HF staff
Update handler.py
e01039c
raw history blame
No virus
1.06 kB
import pandas as pd
import pickle
from typing import Dict, List, Any
import numpy as np
import os
# Custom inference handler: loads a pickled churn model and serves its predictions.
class EndpointHandler():
    """Serve predictions from a pickled churn model.

    The model object is loaded once at construction time and must expose a
    ``predict(DataFrame)`` method.
    """

    def __init__(self, path=""):
        """Load the pickled model.

        Args:
            path: Directory containing ``churn.pkl``. With the default ``""``
                the file is resolved relative to the current working directory.
        """
        # NOTE: unpickling executes arbitrary code stored in the file; only
        # load churn.pkl from a trusted source (e.g. the model repository).
        model_file = os.path.join(path, "churn.pkl")
        self.pipe = pd.read_pickle(model_file)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model on a request payload.

        Args:
            data: Request payload. Records are read from ``data["inputs"]``;
                if that key is absent, the payload itself (minus any
                ``"parameters"`` entry) is used as the records. Records may be
                a list of row dicts, a dict of column lists, or a single row
                dict of scalar values. ``data`` is mutated: ``"inputs"`` and
                ``"parameters"`` are popped from it.

        Returns:
            ``{"pred": predictions}``, where ``predictions`` is whatever the
            loaded model's ``predict`` returns for the prepared DataFrame.
        """
        inputs = data.pop("inputs", data)
        # Popped even though it is not used: when the payload itself is the
        # input, this keeps "parameters" from becoming a DataFrame column.
        data.pop("parameters", None)

        # A single record given as {column: scalar} makes pd.DataFrame raise
        # ("must pass an index"); wrap it as a one-row list of records.
        if (
            isinstance(inputs, dict)
            and inputs
            and all(pd.api.types.is_scalar(v) for v in inputs.values())
        ):
            inputs = [inputs]
        df = pd.DataFrame(inputs)

        # Single-space TotalCharges entries become NaN so the column can be
        # cast to float.
        df["TotalCharges"] = (
            df["TotalCharges"].replace(" ", np.nan, regex=False).astype(float)
        )
        # Drop the identifier and target columns only if the caller sent them;
        # a prediction request need not carry the "Churn" label.
        df = df.drop(columns=["customerID", "Churn"], errors="ignore")

        # run inference pipeline
        pred = self.pipe.predict(df)
        return {"pred": pred}