Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# encoding: utf-8 | |
import pickle

from fastapi import Depends, FastAPI, Form, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
app = FastAPI()

# Cross-origin policy: accept requests from any origin, using any method and
# any header, and allow credentials.
# TODO(production): replace the "*" origin with an explicit list of allowed
# origins.
# NOTE(review): the FastAPI CORS docs discourage wildcard values together
# with allow_credentials=True; confirm credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Load the pre-trained insurance-cost model once, at import time.
# The `with` block guarantees the file handle is closed after loading
# (previously it was left open for the life of the process).
# SECURITY: pickle.load can execute arbitrary code; only ever point this at a
# trusted model file. The relative path is resolved against the process's
# current working directory.
with open('insurance_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file, encoding='bytes')
class Msg(BaseModel):
    # Request body for demo_post: a single text message that is echoed back
    # upper-cased.
    msg: str
class Req(BaseModel):
    # Parsed inputs for an insurance-cost prediction; instances are built by
    # form_req from the submitted form fields.
    age: int
    sex: int  # 1 is reported as "Male" in the response, any other value as "Female"
    smoker: int  # 1 is reported as "Yes" in the response, any other value as "No"
    bmi: float
    children: int
    region: int  # region code; get_region_name maps 0-3 to a name, else "Unknown"
class Resp(BaseModel):
    # JSON payload returned by the POST predict handler: the request inputs
    # echoed back with human-readable labels, plus the predicted cost.
    age: int
    sex: str  # "Male" or "Female"
    smoker: str  # "Yes" or "No"
    bmi: float
    children: int
    region: str  # name from get_region_name, or "Unknown"
    insurance_cost: float  # model prediction rounded to 2 decimal places
async def root():
    """Return the static welcome message for the API root."""
    welcome = "Hello World. Welcome to FastAPI!"
    return dict(message=welcome)
def _parse_form_number(field_name, raw_value, parse):
    # Convert one raw form string with `parse` (int or float); bad input becomes HTTP 422.
    try:
        return parse(raw_value)
    except ValueError as err:
        raise HTTPException(
            status_code=422,
            detail=f"Form field '{field_name}' must be a number, got {raw_value!r}",
        ) from err


def form_req(age: str = Form(...), sex: str = Form(...), smoker: str = Form(...),
             bmi: str = Form(...), children: str = Form(...), region: str = Form(...)):
    """Parse the prediction form fields into a ``Req``.

    Every field arrives as a string. ``bmi`` may use a comma as the decimal
    separator (e.g. ``"27,5"``); it is normalised to a dot before parsing.

    Returns:
        Req: the parsed request values.

    Raises:
        HTTPException: status 422 naming the offending field when a value is
            not a valid number. Previously a bare ValueError escaped from this
            dependency, which surfaces to the client as a server error.
    """
    return Req(
        age=_parse_form_number("age", age, int),
        sex=_parse_form_number("sex", sex, int),
        smoker=_parse_form_number("smoker", smoker, int),
        # Accept a decimal comma as well as a decimal point.
        bmi=_parse_form_number("bmi", bmi.replace(",", "."), float),
        children=_parse_form_number("children", children, int),
        region=_parse_form_number("region", region, int),
    )
async def demo_get():
    """Explain the /path demo endpoint and direct the caller to POST instead."""
    hint = "This is /path endpoint, use a post request to transform the text to uppercase"
    return dict(message=hint)
async def demo_post(inp: Msg):
    """Echo the posted message back converted to upper case."""
    text = inp.msg
    shouted = text.upper()
    return {"message": shouted}
async def demo_get_path_id(path_id: int):
    """Describe the ``/path/{path_id}`` endpoint for the requested id."""
    text = "This is /path/{} endpoint, use post request to retrieve result".format(path_id)
    return dict(message=text)
async def predict(path_id: int):
    """Describe the ``/predict/{path_id}`` endpoint for the requested id.

    NOTE(review): a second ``predict`` (the form-based prediction handler) is
    defined later in this file and rebinds this module-level name. No route
    decorator is visible here; confirm how this handler is registered.
    """
    text = "This is /predict/{} endpoint, use post request to retrieve result".format(path_id)
    return dict(message=text)
def get_region_name(region_code):
    """Translate a numeric region code into its region label.

    Codes 0, 1, 2 and 3 map to "Northeast", "Northwest", "Southeast" and
    "Southwest" respectively; any other value yields "Unknown".
    """
    labels = dict(enumerate(("Northeast", "Northwest", "Southeast", "Southwest")))
    return labels.get(region_code, "Unknown")
async def predict(request: Request, requess: Req = Depends(form_req)):
    """Predict the insurance cost for the submitted form data.

    Args:
        request: the incoming request (not read by this handler).
        requess: the parsed form fields, supplied by the ``form_req``
            dependency.

    Returns:
        JSONResponse: a ``Resp`` payload (JSON, not an HTML page) echoing the
        inputs with readable labels: sex "Male"/"Female", smoker "Yes"/"No",
        and the region name. It also carries ``insurance_cost``, the model's
        prediction rounded to 2 decimal places.
    """
    # The feature order is kept exactly as in the original code:
    # age, sex, bmi, children, smoker, region. This is NOT Req's field order.
    # Presumably it matches the columns the model was trained on, so do not
    # reorder.
    features = [
        int(requess.age),
        int(requess.sex),
        float(requess.bmi),
        int(requess.children),
        int(requess.smoker),
        int(requess.region),
    ]
    # Assumes a scikit-learn-style estimator that takes a batch of rows: send
    # one row and keep the first result.
    prediction = model.predict([features])
    insurance_cost = round(prediction[0], 2)

    resp = Resp(
        age=requess.age,
        sex="Male" if requess.sex == 1 else "Female",
        smoker="Yes" if requess.smoker == 1 else "No",
        bmi=requess.bmi,
        children=requess.children,
        region=get_region_name(requess.region),
        insurance_cost=insurance_cost,
    )
    # Encode to JSON-compatible primitives before wrapping in the response.
    return JSONResponse(content=jsonable_encoder(resp))