# Spaces: Sleeping
import math
import pickle

from fastapi import Depends, FastAPI, Form, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
app = FastAPI()

# Add CORS middleware so browser front-ends on any origin can call this API.
# Credentials (cookies / HTTP auth) are deliberately NOT enabled:
# - FastAPI's CORS docs state that allow_origins=["*"] must not be combined
#   with allow_credentials=True.
# - With that combination Starlette echoes back whatever Origin the browser
#   sends, letting any site make credentialed requests.
# - None of the handlers in this file read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the pre-trained model once, at import time; every request reuses it.
# NOTE(review): unpickling can execute arbitrary code stored in the file, so
# only deploy an 'insurance_model.pkl' that comes from a trusted source.
# NOTE(review): the path is resolved against the process's current working
# directory, not against this file's location.
# encoding='bytes' tells pickle to read 8-bit strings pickled by Python 2 as
# bytes objects.
with open('insurance_model.pkl', 'rb') as model_file:
    model = pickle.Unpickler(model_file, encoding='bytes').load()
class Req(BaseModel):
    """Parsed prediction inputs, as produced from the submitted form.

    The categorical fields are integer-coded. The codes interpreted elsewhere
    in this module are:

    - ``sex``: 1 is labelled "Male", any other value "Female".
    - ``smoker``: 1 is labelled "Yes", any other value "No".
    - ``region``: 0-3 are named by ``get_region_name``.
    """

    # Field order is kept as-is; pydantic preserves declaration order.
    age: int
    sex: int  # integer code, see class docstring
    smoker: int  # integer code, see class docstring
    bmi: float
    children: int
    region: int  # integer code, see class docstring
class Resp(BaseModel):
    """Response body built by ``predict``.

    Echoes the submitted inputs, with the integer-coded ``sex``, ``smoker``
    and ``region`` replaced by human-readable labels, plus the predicted cost.
    Declaration order here determines the key order of the JSON output.
    """

    age: int
    sex: str  # "Male" or "Female"
    smoker: str  # "Yes" or "No"
    bmi: float
    children: int
    region: str  # region name, or "Unknown" for an unmapped code
    insurance_cost: float  # model prediction rounded to 2 decimal places
async def root():
    """Return a fixed greeting message as a JSON-serializable dict.

    NOTE(review): no route decorator is visible for this handler in this copy
    of the file; presumably it is registered with something like
    ``@app.get("/")`` -- confirm.
    """
    greeting = "Hello World. Welcome to FastAPI!"
    return dict(message=greeting)
def _parse_form_number(field, raw, convert):
    """Apply ``convert`` (``int`` or ``float``) to the raw form string ``raw``.

    Raises:
        HTTPException: 422 naming ``field`` when ``convert`` raises
            ``ValueError``, so malformed client input is reported as a client
            error instead of escaping as an unhandled exception (HTTP 500).
    """
    try:
        return convert(raw)
    except ValueError:
        raise HTTPException(
            status_code=422,
            detail=f"Form field '{field}' must be a number, got {raw!r}",
        ) from None


def form_req(age: str = Form(...), sex: str = Form(...), smoker: str = Form(...),
             bmi: str = Form(...), children: str = Form(...), region: str = Form(...)):
    """Build a ``Req`` from the submitted form fields (used as a FastAPI dependency).

    Every field arrives as a string. ``bmi`` may use a comma as its decimal
    separator: ``"27,5"`` is read as ``27.5``.

    Returns:
        A ``Req`` holding the converted int/float values.

    Raises:
        HTTPException: 422 when a field is not a valid number, or when ``bmi``
            is not a finite number (e.g. ``"nan"`` or ``"inf"``).
    """
    bmi_value = _parse_form_number("bmi", bmi.replace(",", "."), float)
    # float() accepts "nan"/"inf". The response echoes bmi back, and
    # Starlette's JSONResponse refuses to serialize non-finite floats, so
    # reject them here with a clear client error.
    if not math.isfinite(bmi_value):
        raise HTTPException(
            status_code=422,
            detail=f"Form field 'bmi' must be a finite number, got {bmi!r}",
        )
    return Req(
        age=_parse_form_number("age", age, int),
        sex=_parse_form_number("sex", sex, int),
        smoker=_parse_form_number("smoker", smoker, int),
        bmi=bmi_value,
        children=_parse_form_number("children", children, int),
        region=_parse_form_number("region", region, int),
    )
def get_region_name(region_code):
    """Translate an integer region code into its display name.

    Codes 0, 1, 2 and 3 map to "Northeast", "Northwest", "Southeast" and
    "Southwest" respectively; any value not equal to one of those codes
    yields "Unknown".
    """
    # Position in the tuple is the region code.
    names_by_code = dict(enumerate(("Northeast", "Northwest", "Southeast", "Southwest")))
    return names_by_code.get(region_code, "Unknown")
async def predict(request: Request, requess: Req = Depends(form_req)):
    """Predict the insurance cost for the submitted form and return it as JSON.

    NOTE(review): no route decorator is visible for this handler in this copy
    of the file; presumably it is registered with something like
    ``@app.post(...)`` -- confirm the path.

    Args:
        request: The incoming request. Not read here; kept so the signature
            stays unchanged.
        requess: Parsed form inputs, produced by the ``form_req`` dependency.

    Returns:
        A ``JSONResponse`` whose body echoes the inputs (with ``sex``,
        ``smoker`` and ``region`` translated to labels) plus
        ``insurance_cost``, the model prediction rounded to 2 decimal places.
    """
    # Feature order is significant and differs from Req's field order:
    # age, sex, bmi, children, smoker, region. This is presumably the column
    # order the model was trained with -- do not reorder.
    features = [
        int(requess.age),
        int(requess.sex),
        float(requess.bmi),
        int(requess.children),
        int(requess.smoker),
        int(requess.region),
    ]
    prediction = model.predict([features])
    # One input row, so take the first output value. float() turns a NumPy
    # scalar (assuming an array-like result) into a plain Python float first.
    insurance_cost = round(float(prediction[0]), 2)

    response_body = Resp(
        age=requess.age,
        sex="Male" if requess.sex == 1 else "Female",
        smoker="Yes" if requess.smoker == 1 else "No",
        bmi=requess.bmi,
        children=requess.children,
        region=get_region_name(requess.region),
        insurance_cost=insurance_cost,
    )
    return JSONResponse(content=jsonable_encoder(response_body))