Sonny4Sonnix's picture
Create main.py
f349339
raw
history blame
1.87 kB
# main.py
import logging
from pathlib import Path

import joblib
import pandas as pd
import xgboost as xgb
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
app = FastAPI()

# Resolve bundled resources relative to this file instead of the process's
# current working directory, so the app starts no matter where the server
# process is launched from.
_BASE_DIR = Path(__file__).resolve().parent

templates = Jinja2Templates(directory=str(_BASE_DIR / "templates"))

# Load the pickled XGBoost model once at import time; every request reuses it.
# NOTE: joblib.load unpickles arbitrary Python objects — only ever point this at
# a trusted model artifact, never at user-supplied files.
xgb_model = joblib.load(_BASE_DIR / "xgb_model.joblib")
class InputFeatures(BaseModel):
    """Request body for the prediction endpoint: one record of model features.

    The abbreviated field names look like the columns of the Pima Indians
    diabetes dataset; the meanings noted below are assumptions — confirm them
    against the data the model was trained on.

    Keep the declaration order: the prediction endpoint builds a one-row
    DataFrame from ``.dict()``, so this order becomes the column order that is
    handed to the model.
    """

    prg: float  # presumably number of pregnancies — TODO confirm
    pl: float  # presumably plasma glucose concentration — TODO confirm
    pr: float  # presumably diastolic blood pressure — TODO confirm
    sk: float  # presumably triceps skin-fold thickness — TODO confirm
    ts: float  # presumably 2-hour serum insulin — TODO confirm
    m11: float  # presumably body-mass index — TODO confirm
    bd2: float  # presumably diabetes pedigree function — TODO confirm
    age: int  # presumably age in years — TODO confirm; the only int-typed field
@app.get("/")
async def read_root():
    """Return a static greeting identifying this API."""
    greeting = "Welcome to the XGBoost Diabetes Prediction API"
    return dict(message=greeting)
@app.get("/form/")
async def show_form(request: Request):
    """Render the HTML input form (``input_form.html``).

    Starlette's ``Jinja2Templates`` requires the incoming request under the
    ``"request"`` context key; passing ``None`` breaks template helpers such as
    ``url_for`` and the response itself. FastAPI injects ``request``
    automatically, so the HTTP interface of this route is unchanged.
    """
    return templates.TemplateResponse("input_form.html", {"request": request})
@app.post("/predict/")
async def predict_diabetes(
    request: Request,
    input_data: InputFeatures,  # JSON request body, validated by Pydantic
):
    """Predict class probabilities for one record and render them as HTML.

    The validated features are converted to a one-row DataFrame and passed to
    the loaded model's ``predict_proba``. The result is rendered with the
    ``display_params.html`` template (an HTML response, not JSON).

    Args:
        request: The incoming request; Starlette templates need it in context.
        input_data: The validated feature values from the request body.

    Returns:
        A template response whose context holds ``input_features`` (the
        validated model) and ``prediction`` (a dict with
        ``class_0_probability`` and ``class_1_probability`` as floats).

    Raises:
        HTTPException: status 500 if prediction or rendering fails. The
            underlying error is logged with its traceback and chained as the
            exception's cause.
    """
    try:
        # One row; columns follow the InputFeatures field order.
        # NOTE(review): if the model was trained on a DataFrame with different
        # column names (e.g. upper-case), XGBoost's feature-name validation will
        # reject these names — confirm against the training data.
        input_df = pd.DataFrame([input_data.dict()])

        # predict_proba is the scikit-learn-style API (e.g. XGBClassifier): it
        # takes the DataFrame directly. Wrapping the input in xgb.DMatrix is
        # not a supported input there and made every request fail.
        # The result has shape (n_rows, n_classes), so take the single row.
        probabilities = xgb_model.predict_proba(input_df)
        row = probabilities[0]

        prediction = {
            # float() turns NumPy scalars into plain Python floats.
            "class_0_probability": float(row[0]),
            "class_1_probability": float(row[1]),
        }

        return templates.TemplateResponse(
            "display_params.html",
            {
                "request": request,
                "input_features": input_data,
                "prediction": prediction,
            },
        )
    except Exception as exc:
        # Top-level request boundary: keep the client-facing message generic,
        # but log the full traceback server-side and chain the original error.
        logging.getLogger(__name__).exception("Prediction request failed")
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing the request.",
        ) from exc