# stroke-flask-docker/app.py
# Last commit: "Update stroke-flask-docker/app.py" (6d2df9a, verified), author: WissMah
from flask import Flask, render_template, request, jsonify
import joblib
import numpy as np
import os
import pandas as pd
# Port passed to app.run() when this file is executed directly; override with $PORT.
APP_PORT = int(os.getenv("PORT", "8080"))

app = Flask(__name__)

# Location of the serialized model pipeline; override with $MODEL_PATH.
MODEL_PATH = os.getenv("MODEL_PATH", "model/stroke_pipeline.joblib")

# Load the model pipeline once at import time so every request reuses it.
# If loading fails the module import fails too, so the service never starts
# without a model.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code. Only point MODEL_PATH at trusted artifacts.
try:
    pipeline = joblib.load(MODEL_PATH)
except Exception as e:
    # Chain the original error so its traceback is kept for debugging.
    raise RuntimeError(f"Failed to load model at {MODEL_PATH}: {e}") from e
# Column names, in the order the DataFrame sent to the pipeline is built.
# Kept as a list (not a tuple) because pandas interprets df[tuple] as a single
# key rather than a column selection.
FEATURE_ORDER = (
    # demographics
    ["gender", "age"]
    # medical-history flags (coerced to int in /predict)
    + ["hypertension", "heart_disease"]
    # social / living situation
    + ["ever_married", "work_type", "Residence_type"]
    # numeric measurements (coerced to float in /predict)
    + ["avg_glucose_level", "bmi"]
    # lifestyle
    + ["smoking_status"]
)
# Simple healthcheck
@app.route("/health", methods=["GET"])
def health():
return jsonify({"status": "ok"}), 200
@app.route("/", methods=["GET"])
def index():
    """Render the prediction form pre-filled with sample values.

    The sample values let someone testing the service submit the form
    immediately without typing anything.
    """
    sample_patient = dict(
        gender="Female",
        age=45,
        hypertension=0,
        heart_disease=0,
        ever_married="Yes",
        work_type="Private",
        Residence_type="Urban",
        avg_glucose_level=95.0,
        bmi=28.0,
        smoking_status="never smoked",
    )
    return render_template("index.html", defaults=sample_patient)
# Probability at or above which a record is labelled 1 (stroke).
# NOTE(review): 0.3 is below the conventional 0.5 cut-off, presumably to favour
# recall on an imbalanced target. Confirm with the model owner.
DECISION_THRESHOLD = 0.3


def _coerce_types(payload):
    """Convert the numeric fields of *payload* in place.

    ``age``, ``avg_glucose_level`` and ``bmi`` become floats, and
    ``hypertension`` and ``heart_disease`` become ints. Absent keys and
    explicit ``None`` (JSON ``null``) values are left alone, so both reach the
    pipeline as ``None``, the same way ``predict`` fills in omitted features.

    Raises:
        ValueError / TypeError: if a present value cannot be converted.
    """
    for field in ("age", "avg_glucose_level", "bmi"):
        if payload.get(field) is not None:
            payload[field] = float(payload[field])
    for field in ("hypertension", "heart_disease"):
        if payload.get(field) is not None:
            payload[field] = int(payload[field])


@app.route("/predict", methods=["POST"])
def predict():
    """Score one record and return the predicted stroke probability.

    Accepts either a JSON object body or a form-encoded post. Features from
    ``FEATURE_ORDER`` that are missing from the request are passed to the
    pipeline as ``None``.

    Returns:
        For JSON requests: ``{"stroke_probability": float, "predicted_label": int}``
        on success, or ``({"error": str}, 400)`` on any failure.
        For form requests: ``index.html`` rendered with ``result`` on success,
        or with ``error`` and status 400 on failure; the submitted values are
        passed back as ``defaults`` either way.
    """
    wants_json = request.is_json
    # Bound before the try so the error handler can always use it, even when
    # reading the body itself fails (e.g. malformed JSON makes get_json raise).
    payload = {}
    try:
        payload = request.get_json() if wants_json else request.form.to_dict()
        if not isinstance(payload, dict):
            raise ValueError("JSON body must be an object")
        _coerce_types(payload)
        # Always send a DataFrame with named columns in FEATURE_ORDER.
        X = pd.DataFrame(
            [{f: payload.get(f) for f in FEATURE_ORDER}], columns=FEATURE_ORDER
        )
        prob = float(pipeline.predict_proba(X)[0][1])
    except Exception as e:
        # Request boundary: log the failure, then report it to the client.
        app.logger.exception("Prediction failed")
        if wants_json:
            return jsonify({"error": str(e)}), 400
        return render_template("index.html", error=str(e), defaults=payload), 400

    result = {
        "stroke_probability": prob,
        "predicted_label": int(prob >= DECISION_THRESHOLD),
    }
    if wants_json:
        return jsonify(result)
    return render_template("index.html", result=result, defaults=payload)
def main():
    """Serve the app with Flask's built-in server on all interfaces at APP_PORT, debug off."""
    app.run(host="0.0.0.0", port=APP_PORT, debug=False)


if __name__ == "__main__":
    main()