|
|
import os
import pickle

import requests
from dotenv import load_dotenv
from flask import Flask, abort, render_template, request
from pydantic import BaseModel, Field, PositiveFloat, ValidationError
|
|
|
|
|
# WSGI application object; the route decorators in this module register
# their handlers on it.
app = Flask(__name__)

# Load the locally stored model once, at import time. The context manager
# closes the file handle as soon as unpickling has finished.
# NOTE(security): unpickling can execute arbitrary code, so "model.sav" must
# only ever come from a trusted source.
with open("model.sav", "rb") as model_file:
    MODEL = pickle.load(model_file)

# Factor applied to raw model outputs before they are rendered as a price.
# NOTE(review): presumably the model target is expressed in units of
# 100,000 -- confirm against the training code.
PRICE_BASE = 10**5

# Read variables from a .env file, if one exists; override=False keeps any
# value that is already set in the real environment.
load_dotenv(override=False)

# Host name of the remote prediction API; falls back to "localhost".
API_HOST = os.environ.get("API_HOST", "localhost")
|
|
|
|
|
|
|
|
class FormQuery(BaseModel):
    """Validated numeric payload of the prediction form.

    Every attribute is required and is populated from the input key named by
    its ``validation_alias``. The first six features are declared as
    ``PositiveFloat``; latitude and longitude are plain floats.
    """

    # Features constrained to positive values.
    med_inc: PositiveFloat = Field(validation_alias="MedInc")
    house_age: PositiveFloat = Field(validation_alias="HouseAge")
    ave_rooms: PositiveFloat = Field(validation_alias="AveRooms")
    ave_bedrms: PositiveFloat = Field(validation_alias="AveBedrms")
    population: PositiveFloat = Field(validation_alias="Population")
    ave_occup: PositiveFloat = Field(validation_alias="AveOccup")

    # Coordinates carry no sign constraint.
    latitude: float = Field(validation_alias="Latitude")
    longitude: float = Field(validation_alias="Longitude")
|
|
|
|
|
|
|
|
@app.route("/", methods=["GET"])
def california_index():
    """Serve the landing page by rendering the ``index.html`` template."""
    landing_page = render_template("index.html")
    return landing_page
|
|
|
|
|
|
|
|
@app.route("/predict/", methods=["POST"])
def local_model_result():
    """Predict a price with the locally loaded model and render it.

    The submitted form is validated through ``FormQuery``. The validated
    features are handed to ``MODEL.predict`` as a single row, and the first
    prediction, multiplied by ``PRICE_BASE``, is rendered with the
    ``prediction.html`` template.

    Returns:
        The rendered ``prediction.html`` page.

    Raises:
        werkzeug.exceptions.BadRequest: HTTP 400 when a form field is missing
            or fails validation.
    """
    try:
        form_query = FormQuery(**request.form.to_dict(flat=True))
    except ValidationError as err:
        # Malformed user input is a client error, not a server fault, so
        # answer with 400 instead of letting the exception become a 500.
        problems = "; ".join(
            f"{'.'.join(str(part) for part in issue['loc'])}: {issue['msg']}"
            for issue in err.errors()
        )
        abort(400, description=f"Invalid form input: {problems}")

    # One-row feature matrix.
    # NOTE(review): the column order presumably mirrors the order used at
    # training time -- confirm before reordering.
    features = [
        [
            form_query.med_inc,
            form_query.house_age,
            form_query.ave_rooms,
            form_query.ave_bedrms,
            form_query.population,
            form_query.ave_occup,
            form_query.latitude,
            form_query.longitude,
        ]
    ]
    reg = MODEL.predict(features)[0]
    return render_template("prediction.html", price=reg * PRICE_BASE)
|
|
|
|
|
|
|
|
@app.route("/predict_from_api/", methods=["POST"])
def api_result():
    """Predict a price through the remote model API and render it.

    The submitted form is validated through ``FormQuery``, the list of
    registered models is fetched from the API, the entry with the lowest
    ``"rse"`` value is selected, and its ``train_id`` is posted together with
    the validated features to the prediction endpoint. The ``"reg"`` value of
    the reply, multiplied by ``PRICE_BASE``, is rendered with the
    ``prediction.html`` template.

    Returns:
        The rendered ``prediction.html`` page.

    Raises:
        werkzeug.exceptions.BadRequest: HTTP 400 when a form field is missing
            or fails validation.
        requests.exceptions.RequestException: when the API cannot be reached,
            does not answer within the timeout, or replies with an HTTP error
            status.
        RuntimeError: when the model registry returns no model.
    """
    # Validate before any network traffic so that bad input fails fast.
    try:
        form_query = FormQuery(**request.form.to_dict(flat=True))
    except ValidationError as err:
        # Malformed user input is a client error, not a server fault.
        problems = "; ".join(
            f"{'.'.join(str(part) for part in issue['loc'])}: {issue['msg']}"
            for issue in err.errors()
        )
        abort(400, description=f"Invalid form input: {problems}")

    base_url = f"http://{API_HOST}:8000"
    # Without a timeout, requests waits indefinitely on a stalled server.
    timeout_seconds = 10

    list_response = requests.get(f"{base_url}/model/list/", timeout=timeout_seconds)
    # Surface HTTP errors here rather than as a KeyError further down.
    list_response.raise_for_status()
    model_list = list_response.json()
    if not model_list:
        raise RuntimeError("No model could be retrieved from the model registry")

    # The entry with the smallest "rse" is treated as the best model.
    best_model = min(model_list, key=lambda d: d["rse"])
    app.logger.debug("Best model retrieved : %s", best_model)

    api_response = requests.post(
        f"{base_url}/model/predict/",
        json={"train_id": best_model["train_id"], **form_query.model_dump()},
        timeout=timeout_seconds,
    )
    api_response.raise_for_status()
    response = api_response.json()
    app.logger.debug(response)

    return render_template("prediction.html", price=response["reg"] * PRICE_BASE)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point. Passing debug=True to run() already switches
    # app.debug on, so no separate assignment is needed.
    # NOTE(security): host="0.0.0.0" listens on every network interface while
    # debug mode enables the interactive debugger -- only start the app this
    # way on a trusted network, never in production.
    app.run(host="0.0.0.0", port=5000, debug=True)
|
|
|