"""Flask front end serving house-price predictions.

Predictions come either from a locally pickled model (``/predict/``) or from
the lowest-MSE model listed by a remote model registry API
(``/predict_from_api/``).
"""

import os
import pickle

import requests
from dotenv import load_dotenv
from flask import Flask, render_template, request
from pydantic import BaseModel, Field, PositiveFloat

app = Flask(__name__)

# NOTE: unpickling can execute arbitrary code, so "model.sav" must come from a
# trusted source. The ``with`` block guarantees the file handle is closed.
with open("model.sav", "rb") as model_file:
    MODEL = pickle.load(model_file)

# Factor applied to the raw model output before it is rendered as a price
# (presumably the target is expressed in units of 100,000 — confirm).
PRICE_BASE = 10**5

# Seconds to wait on the model API before giving up, so a hung API cannot
# block a request-handling worker indefinitely.
API_TIMEOUT = 10

load_dotenv(override=False)
API_HOST = os.environ.get("API_HOST", "localhost")
API_PORT = os.environ.get("API_PORT", "8000")
API_BASE_URL = f"http://{API_HOST}:{API_PORT}"


class FormQuery(BaseModel):
    """Validated prediction inputs parsed from the submitted HTML form.

    Incoming keys use the CamelCase names given as ``validation_alias``; the
    parsed values are exposed as the snake_case attributes below.
    """

    med_inc: PositiveFloat = Field(..., validation_alias="MedInc")
    house_age: PositiveFloat = Field(..., validation_alias="HouseAge")
    ave_rooms: PositiveFloat = Field(..., validation_alias="AveRooms")
    ave_bedrms: PositiveFloat = Field(..., validation_alias="AveBedrms")
    population: PositiveFloat = Field(..., validation_alias="Population")
    ave_occup: PositiveFloat = Field(..., validation_alias="AveOccup")
    latitude: float = Field(..., validation_alias="Latitude")
    longitude: float = Field(..., validation_alias="Longitude")


def _parse_form_query():
    """Validate the current request's form fields into a ``FormQuery``.

    Raises:
        pydantic.ValidationError: if a field is missing or fails validation.
    """
    return FormQuery(**request.form.to_dict(flat=True))


@app.route("/", methods=["GET"])
def california_index():
    """Render the input form."""
    return render_template("index.html")


@app.route("/predict/", methods=["POST"])
def local_model_result():
    """Predict a price from the submitted form using the local pickled model."""
    form_query = _parse_form_query()
    # Feature order must match the order the model was trained with.
    feature_row = [
        form_query.med_inc,
        form_query.house_age,
        form_query.ave_rooms,
        form_query.ave_bedrms,
        form_query.population,
        form_query.ave_occup,
        form_query.latitude,
        form_query.longitude,
    ]
    reg = MODEL.predict([feature_row])[0]
    return render_template("prediction.html", price=reg * PRICE_BASE)


@app.route("/predict_from_api/", methods=["POST"])
def api_result():
    """Predict a price with the lowest-MSE model from the model registry API.

    Raises:
        pydantic.ValidationError: if the submitted form is invalid.
        RuntimeError: if the registry lists no models.
        requests.HTTPError: if either API call returns an error status.
        requests.Timeout: if the API does not answer within ``API_TIMEOUT``.
    """
    # Validate user input first so bad forms fail before any network call.
    form_payload = _parse_form_query().model_dump()

    list_response = requests.get(f"{API_BASE_URL}/model/list/", timeout=API_TIMEOUT)
    list_response.raise_for_status()
    model_list = list_response.json()
    if not model_list:
        raise RuntimeError("No model could be retrieved from the model registry")

    # Lower mean squared error means a better model, so take the minimum.
    best_model = min(model_list, key=lambda model: model["mse"])
    app.logger.debug("Best model retrieved : %s", best_model)

    predict_response = requests.post(
        f"{API_BASE_URL}/model/predict/",
        json={"train_id": best_model["train_id"], **form_payload},
        timeout=API_TIMEOUT,
    )
    predict_response.raise_for_status()
    response = predict_response.json()
    app.logger.debug(response)
    return render_template("prediction.html", price=response["reg"] * PRICE_BASE)


if __name__ == "__main__":
    # NOTE(review): debug=True enables the Werkzeug interactive debugger, which
    # permits arbitrary code execution; together with host="0.0.0.0" this must
    # only ever be used for local development.
    app.run(host="0.0.0.0", port=5000, debug=True)