Spaces:
Paused
Paused
import os | |
import pickle | |
import requests | |
from dotenv import load_dotenv | |
from flask import Flask, render_template, request | |
from pydantic import BaseModel, Field, PositiveFloat | |
# Flask application object; also provides the logger used by the views.
app = Flask(__name__)

# Load the serialized model once at import time. The context manager closes
# the file handle as soon as unpickling finishes instead of leaving it open
# until the garbage collector happens to reclaim it.
# NOTE(security): unpickling can execute arbitrary code, so "model.sav" must
# only ever come from a trusted source.
with open("model.sav", "rb") as _model_file:
    MODEL = pickle.load(_model_file)

# Multiplier applied to a raw prediction before it is rendered as a price.
PRICE_BASE = 10**5

# Read a .env file if present; variables already set in the environment win.
load_dotenv(override=False)

# Host name of the remote model API; defaults to the local machine.
API_HOST = os.environ.get("API_HOST", "localhost")
class FormQuery(BaseModel):
    """Validated feature values taken from the submitted prediction form.

    Every attribute is required and is read from the form key named by its
    ``validation_alias``. All values except the coordinates must be strictly
    positive floats.
    """

    # ``Field`` without a default declares a required field, exactly like
    # passing ``...`` as the default.
    # NOTE(review): the aliases look like the scikit-learn California housing
    # feature names -- confirm against the model's training data.
    med_inc: PositiveFloat = Field(validation_alias="MedInc")
    house_age: PositiveFloat = Field(validation_alias="HouseAge")
    ave_rooms: PositiveFloat = Field(validation_alias="AveRooms")
    ave_bedrms: PositiveFloat = Field(validation_alias="AveBedrms")
    population: PositiveFloat = Field(validation_alias="Population")
    ave_occup: PositiveFloat = Field(validation_alias="AveOccup")

    # Coordinates are plain floats because they may legitimately be negative.
    latitude: float = Field(validation_alias="Latitude")
    longitude: float = Field(validation_alias="Longitude")
def california_index():
    """Render and return the ``index.html`` template."""
    # NOTE(review): no route decorator is visible on this view in the reviewed
    # source -- confirm it is registered with ``app`` somewhere.
    page = render_template("index.html")
    return page
def local_model_result():
    """Predict a price with the locally loaded model and render it.

    The submitted form is validated through ``FormQuery``. The validated
    values are handed to ``MODEL.predict`` as a single-row batch, and the
    first prediction, scaled by ``PRICE_BASE``, is rendered with the
    ``prediction.html`` template.
    """
    # NOTE(review): no route decorator is visible on this view in the reviewed
    # source -- confirm it is registered with ``app`` somewhere.
    submitted = request.form.to_dict(flat=True)
    query = FormQuery(**submitted)

    # The feature order is fixed here; presumably it has to match the order
    # the model was trained with -- TODO confirm.
    feature_row = [
        query.med_inc,
        query.house_age,
        query.ave_rooms,
        query.ave_bedrms,
        query.population,
        query.ave_occup,
        query.latitude,
        query.longitude,
    ]

    predictions = MODEL.predict([feature_row])
    scaled_price = predictions[0] * PRICE_BASE
    return render_template("prediction.html", price=scaled_price)
def api_result():
    """Predict a price through the remote model API and render it.

    The submitted form is validated through ``FormQuery``, the model registry
    at ``API_HOST`` is asked for its models, the one with the lowest ``mse``
    is selected, and its prediction (the ``reg`` field of the response) is
    scaled by ``PRICE_BASE`` and rendered with ``prediction.html``.

    Raises:
        RuntimeError: If the model registry returns no models.
        requests.HTTPError: If either API call answers with an error status.
        requests.Timeout: If the API does not answer within the timeout.
    """
    # NOTE(review): no route decorator is visible on this view in the reviewed
    # source -- confirm it is registered with ``app`` somewhere.

    # Validate the form first so malformed input fails before any network call.
    features = FormQuery(**request.form.to_dict(flat=True)).model_dump()

    base_url = f"http://{API_HOST}:8000"
    # Without an explicit timeout, requests waits indefinitely on a stalled
    # server, which would block this worker.
    timeout_seconds = 30

    list_response = requests.get(f"{base_url}/model/list/", timeout=timeout_seconds)
    list_response.raise_for_status()
    model_list = list_response.json()
    if not model_list:
        raise RuntimeError("No model could be retrieved from the model registry")

    # A lower mean squared error is better, so the best model is the minimum
    # (sorting in descending order would select the worst one).
    best_model = min(model_list, key=lambda entry: entry["mse"])
    app.logger.debug("Best model retrieved : %s", best_model)

    api_response = requests.post(
        f"{base_url}/model/predict/",
        json={"train_id": best_model["train_id"], **features},
        timeout=timeout_seconds,
    )
    # Surface API failures as HTTP errors instead of a KeyError on "reg" below.
    api_response.raise_for_status()
    response = api_response.json()
    app.logger.debug(response)
    return render_template("prediction.html", price=response["reg"] * PRICE_BASE)
if __name__ == "__main__":
    # Debug mode enables the interactive debugger, which allows arbitrary code
    # execution from the browser. Because the server binds to every interface,
    # debug is opt-in through FLASK_DEBUG instead of being always on. The
    # accepted values mirror Flask's own handling: anything except
    # "0"/"false"/"no" (or unset/empty) turns it on.
    debug_value = os.environ.get("FLASK_DEBUG", "")
    debug_enabled = bool(debug_value) and debug_value.lower() not in {"0", "false", "no"}
    app.run(host="0.0.0.0", port=5000, debug=debug_enabled)