california-app / app /main.py
Figea's picture
first commit
f1c5423
raw
history blame contribute delete
No virus
2.5 kB
import os
import pickle
import requests
from dotenv import load_dotenv
from flask import Flask, render_template, request
from pydantic import BaseModel, Field, PositiveFloat
app = Flask(__name__)

# Load the pickled model once at import time. The context manager guarantees
# the file handle is closed as soon as unpickling finishes (or fails).
# NOTE(review): unpickling executes code embedded in the file, so "model.sav"
# must only ever come from a trusted source.
with open("model.sav", "rb") as model_file:
    MODEL = pickle.load(model_file)

# Multiplier applied to raw model outputs before they are rendered as a price.
PRICE_BASE = 10**5

# Read variables from a .env file, if any; override=False keeps values that are
# already set in the process environment.
load_dotenv(override=False)

# Host name of the remote model API; falls back to "localhost" when unset.
API_HOST = os.environ.get("API_HOST", "localhost")
class FormQuery(BaseModel):
    """Validated payload of the prediction form.

    Every attribute is required (no default is declared) and is populated from
    the input key given by its ``validation_alias``; for example the ``MedInc``
    key fills ``med_inc``.
    """

    # Fields that must be strictly positive floats.
    med_inc: PositiveFloat = Field(validation_alias="MedInc")
    house_age: PositiveFloat = Field(validation_alias="HouseAge")
    ave_rooms: PositiveFloat = Field(validation_alias="AveRooms")
    ave_bedrms: PositiveFloat = Field(validation_alias="AveBedrms")
    population: PositiveFloat = Field(validation_alias="Population")
    ave_occup: PositiveFloat = Field(validation_alias="AveOccup")

    # Coordinates: any float is accepted, no range constraint is applied.
    latitude: float = Field(validation_alias="Latitude")
    longitude: float = Field(validation_alias="Longitude")
@app.route("/", methods=["GET"])
def california_index():
    """Serve the landing page by rendering the ``index.html`` template."""
    landing_page = render_template("index.html")
    return landing_page
@app.route("/predict/", methods=["POST"])
def local_model_result():
    """Predict a price with the locally loaded model and render the result.

    The posted form is validated through ``FormQuery``; the validated values
    are passed to ``MODEL.predict`` as a single row, and the first prediction,
    multiplied by ``PRICE_BASE``, is handed to ``prediction.html`` as ``price``.
    """
    submitted = request.form.to_dict(flat=True)
    query = FormQuery(**submitted)

    # One sample; the feature order below is the order given to the model.
    feature_row = [
        query.med_inc,
        query.house_age,
        query.ave_rooms,
        query.ave_bedrms,
        query.population,
        query.ave_occup,
        query.latitude,
        query.longitude,
    ]
    predictions = MODEL.predict([feature_row])
    estimate = predictions[0]
    return render_template("prediction.html", price=estimate * PRICE_BASE)
# Upper bound, in seconds, on each call to the model API so that an
# unresponsive backend cannot block a request handler indefinitely.
_API_TIMEOUT_SECONDS = 10


@app.route("/predict_from_api/", methods=["POST"])
def api_result():
    """Predict a price through the remote model API and render the result.

    Steps:
      1. Validate the posted form with ``FormQuery``.
      2. Fetch the list of registered models from ``/model/list/``.
      3. Pick the model with the lowest ``mse`` and post the validated form
         values, together with that model's ``train_id``, to
         ``/model/predict/``.
      4. Render ``prediction.html`` with the returned ``reg`` value multiplied
         by ``PRICE_BASE``.

    Raises:
        Exception: if the model registry returns no model.
        requests.HTTPError: if either API call answers with an error status.
        requests.Timeout: if either API call exceeds ``_API_TIMEOUT_SECONDS``.
    """
    # Validate first so invalid input fails before any network round-trip.
    form_query = FormQuery(**request.form.to_dict(flat=True))

    list_response = requests.get(
        f"http://{API_HOST}:8000/model/list/",
        timeout=_API_TIMEOUT_SECONDS,
    )
    list_response.raise_for_status()
    model_list = list_response.json()
    if not model_list:
        raise Exception("No model could be retrieved from the model registry")

    # MSE is an error metric, so the best model is the one with the LOWEST
    # value (the previous descending sort selected the worst one).
    # NOTE(review): assumes the registry reports a plain, non-negated MSE.
    best_model = min(model_list, key=lambda d: d["mse"])
    app.logger.debug(f"Best model retrieved : {best_model}")

    api_response = requests.post(
        f"http://{API_HOST}:8000/model/predict/",
        json={"train_id": best_model["train_id"], **form_query.model_dump()},
        timeout=_API_TIMEOUT_SECONDS,
    )
    # Surface HTTP errors explicitly rather than failing later on a missing key.
    api_response.raise_for_status()
    response = api_response.json()
    app.logger.debug(response)
    return render_template("prediction.html", price=response["reg"] * PRICE_BASE)
if __name__ == "__main__":
    # NOTE(review): debug=True together with host="0.0.0.0" makes the
    # interactive debugger reachable from other machines; confirm this entry
    # point is only used for local development.
    # run(debug=True) sets app.debug itself, so no separate assignment is made.
    app.run(debug=True, host="0.0.0.0", port=5000)