File size: 2,502 Bytes
f1c5423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import pickle

import requests
from dotenv import load_dotenv
from flask import Flask, render_template, request
from pydantic import BaseModel, Field, PositiveFloat

app = Flask(__name__)

# Load the locally trained regressor once, at import time. The `with` block
# closes the file handle deterministically instead of leaking it until the
# file object happens to be garbage-collected.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only point this at a model file produced by a trusted training pipeline.
with open("model.sav", "rb") as model_file:
    MODEL = pickle.load(model_file)

# Multiplier applied to every raw model output before it is displayed.
# Presumably the model's target is expressed in units of 100,000 — TODO
# confirm against the training data.
PRICE_BASE = 10**5

# Read variables from a .env file into os.environ; override=False keeps any
# value that is already set in the process environment.
load_dotenv(override=False)

# Hostname of the remote model API; falls back to "localhost" when unset.
API_HOST = os.environ.get("API_HOST", "localhost")


class FormQuery(BaseModel):
    """Validated input features for a single house-price prediction.

    Attributes use snake_case names, but each one is populated from the
    CamelCase key given by its ``validation_alias`` (e.g. ``MedInc``), which
    is how the incoming form fields are named. Every field is required.
    """

    # Strictly positive features: PositiveFloat rejects values <= 0.
    med_inc: PositiveFloat = Field(validation_alias="MedInc")
    house_age: PositiveFloat = Field(validation_alias="HouseAge")
    ave_rooms: PositiveFloat = Field(validation_alias="AveRooms")
    ave_bedrms: PositiveFloat = Field(validation_alias="AveBedrms")
    population: PositiveFloat = Field(validation_alias="Population")
    ave_occup: PositiveFloat = Field(validation_alias="AveOccup")

    # Coordinates can legitimately be zero or negative, so they are plain floats.
    latitude: float = Field(validation_alias="Latitude")
    longitude: float = Field(validation_alias="Longitude")


@app.route("/", methods=["GET"])
def california_index():
    """Render the ``index.html`` template for the site root."""
    page = "index.html"
    return render_template(page)


@app.route("/predict/", methods=["POST"])
def local_model_result():
    """Predict a price with the in-process model and render ``prediction.html``.

    The posted form is validated through ``FormQuery``; missing or invalid
    fields raise pydantic's ``ValidationError``. The model is called on a
    single row, and its one output is scaled by ``PRICE_BASE`` for display.
    """
    query = FormQuery(**request.form.to_dict(flat=True))

    # Feature order is significant: presumably it must match the column
    # order used when the model was trained.
    row = [
        query.med_inc,
        query.house_age,
        query.ave_rooms,
        query.ave_bedrms,
        query.population,
        query.ave_occup,
        query.latitude,
        query.longitude,
    ]
    # A batch of exactly one row goes in, so take the first (only) prediction.
    predicted = MODEL.predict([row])[0]
    price = predicted * PRICE_BASE
    return render_template("prediction.html", price=price)


@app.route("/predict_from_api/", methods=["POST"])
def api_result():
    """Predict a price through the remote model API and render ``prediction.html``.

    Steps:
      1. Validate the posted form through ``FormQuery``.
      2. Fetch the registered models from ``/model/list/`` and choose the one
         with the lowest ``mse``.
      3. POST that model's ``train_id`` plus the validated features to
         ``/model/predict/`` and scale the returned ``reg`` by ``PRICE_BASE``.

    Raises:
        pydantic.ValidationError: if form fields are missing or invalid.
        requests.HTTPError: if either API call returns an error status code.
        requests.Timeout: if the API does not answer within the timeout.
        Exception: if the model registry returns no models.
    """
    # Validate user input first so a bad submission never costs a network call.
    form_query = FormQuery(**request.form.to_dict(flat=True))

    base_url = f"http://{API_HOST}:8000"
    # requests has no default timeout; without one an unresponsive API would
    # block this worker indefinitely. Tune if predictions are slow.
    timeout = 10

    list_response = requests.get(f"{base_url}/model/list/", timeout=timeout)
    list_response.raise_for_status()
    model_list = list_response.json()
    if not model_list:
        raise Exception("No model could be retrieved from the model registry")

    # Lower mean squared error means a better model, so take the minimum.
    # Assumes "mse" holds the positive MSE (not a negated score) — TODO confirm.
    best_model = min(model_list, key=lambda d: d["mse"])
    app.logger.debug("Best model retrieved : %s", best_model)

    api_response = requests.post(
        f"{base_url}/model/predict/",
        json={"train_id": best_model["train_id"], **form_query.model_dump()},
        timeout=timeout,
    )
    api_response.raise_for_status()

    response = api_response.json()
    app.logger.debug(response)

    return render_template("prediction.html", price=response["reg"] * PRICE_BASE)


if __name__ == "__main__":
    # Development server only. Passing debug=True to app.run already turns on
    # app.debug, so no separate assignment is needed.
    # NOTE(review): binding to 0.0.0.0 with debug=True exposes the Werkzeug
    # interactive debugger (which can execute arbitrary code) to the network —
    # never run this entry point on an untrusted network or in production.
    app.run(host="0.0.0.0", port=5000, debug=True)