Arafath10's picture
Update main.py
b6cd364 verified
raw
history blame
No virus
2.19 kB
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import pandas as pd
import numpy as np
import joblib
# Load the trained XGBoost model and the dict of per-column encoders once at
# import time so every request reuses them.
# NOTE(review): joblib.load unpickles the file, so only point these paths at
# trusted, locally produced artifacts.
xgb_model, encoders = (
    joblib.load("model/transexpress_xgb_model.joblib"),
    joblib.load("model/transexpress_encoders.joblib"),
)
# Function to handle unseen labels during encoding
def safe_transform(encoder, column):
classes = encoder.classes_
return [encoder.transform([x])[0] if x in classes else -1 for x in column]
# Define FastAPI app
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Endpoint for making predictions
@app.post("/predict")
def predict(
    customer_name: str,
    customer_address: str,
    customer_phone: str,
    customer_email: str,
    weight: str,
    cod: str,
    pickup_address: str,
    destination_city_name: str,
):
    """Score one shipment with the loaded model and return a percentage.

    All arguments are plain ``str`` values, so FastAPI reads them as query
    parameters. ``weight`` and ``cod`` must parse as floats. Columns that
    have an entry in ``encoders`` are encoded with ``safe_transform``, so
    values unseen at training time become -1.

    Returns:
        ``{"Probability": p}``. ``p`` is the model's probability (percent,
        rounded to 2 decimals) for the predicted class. It is replaced by
        ``100 - p`` when the predicted status is "Returned to Client".
        ``{"Probability": "Unknown"}`` is returned when the prediction is -1.

    Raises:
        HTTPException: 422 when ``weight`` or ``cod`` is not a valid number.
    """
    # An empty destination is scored under the 'Missing' placeholder value.
    if destination_city_name == "":
        destination_city_name = 'Missing'

    # Numeric fields arrive as strings. Turn a parse failure into a client
    # error instead of an unhandled ValueError (which FastAPI reports as 500).
    try:
        cod_value = float(cod)
        weight_value = float(weight)
    except ValueError:
        raise HTTPException(
            status_code=422,
            detail="'weight' and 'cod' must be numeric",
        ) from None

    # Column names and order mirror the features the model was trained on;
    # do not reorder or rename them.
    input_data = {
        'customer_name': customer_name,
        'customer_address': customer_address,
        'customer_phone': customer_phone,
        'customer_email': customer_email,
        'cod': cod_value,
        'weight': weight_value,
        'pickup_address': pickup_address,
        'destination_city.name': destination_city_name,
    }
    input_df = pd.DataFrame([input_data])

    # Encode categorical variables with the same encoders used in training.
    for col in input_df.columns:
        if col in encoders:
            input_df[col] = safe_transform(encoders[col], input_df[col])

    # Predicted class index and per-class probabilities for the single row.
    pred = xgb_model.predict(input_df)
    pred_proba = xgb_model.predict_proba(input_df)

    label = pred[0]
    if label == -1:
        # Previously this path reached round("Unknown", 2) and raised TypeError.
        return {"Probability": "Unknown"}

    # `pred` is already a 1-D array; wrapping it in another list made the
    # argument 2-D.
    predicted_status = encoders['status.name'].inverse_transform(pred)[0]
    probability = pred_proba[0][label] * 100
    if predicted_status == "Returned to Client":
        probability = 100 - probability
    # predict_proba yields NumPy scalars, which FastAPI's JSON encoder cannot
    # serialize, so convert to a built-in float before returning.
    return {"Probability": round(float(probability), 2)}