# API_titanic1 / api.py
# Hugging Face file-viewer residue (not part of the program):
# uploaded by Emmanuel26, commit "Update api.py" (296b2e8, verified), 863 bytes.
from typing import Union
from fastapi import FastAPI,Request
from pydantic import BaseModel
import pickle
import sklearn
import joblib
import numpy as np
# FastAPI application instance; the /predict route is registered on it via @app.post.
app = FastAPI()
class InputData(BaseModel):
    """Request body for the ``/predict`` endpoint.

    The fields are handed to the model as one feature row, in the order
    ``stay_class``, ``sex``, ``ticket_price``.
    """

    # Passenger class; expected values are 1, 2 or 3.
    stay_class: int
    # Integer-encoded sex. NOTE(review): the encoding is not visible in this
    # file -- confirm it against the code that trained the model.
    sex: int
    # Ticket price as a float.
    ticket_price: float
def load_model(model_path: str = "RandomForestClassifier_model.pkl"):
    """Unpickle and return the object stored at *model_path*.

    Args:
        model_path: Location of the pickled model. Defaults to
            ``"RandomForestClassifier_model.pkl"``; a relative path is
            resolved against the current working directory.

    Returns:
        The object that was stored in the pickle file.

    Raises:
        FileNotFoundError: If *model_path* does not exist.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file,
    # so only point this at model files that come from a trusted source.
    with open(model_path, "rb") as model_file:
        return pickle.load(model_file)
# Load the model once at import time; the /predict handler reads this
# module-level object on every request.
model = load_model()
@app.post("/predict")
def predict(input_data: InputData):
    """Run the loaded model on one passenger and return its prediction.

    Args:
        input_data: Validated request body carrying ``stay_class``,
            ``sex`` and ``ticket_price``.

    Returns:
        A dict with the single key ``"Survival Prediction"`` whose value is
        the model's predicted label converted to ``int``.
    """
    # The model receives a 2-D array holding exactly one row of features.
    feature_row = [
        input_data.stay_class,
        input_data.sex,
        input_data.ticket_price,
    ]
    batch = np.array([feature_row])

    first_label = model.predict(batch)[0]

    # Presumably a NumPy scalar; int() turns it into a plain Python int
    # before it goes into the response body.
    return {"Survival Prediction": int(first_label)}
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly, so it is
    # imported here; without this import uvicorn.run raises NameError.
    import uvicorn

    # No reload=True: uvicorn supports auto-reload only when the app is given
    # as an import string ("module:app"). With an app object it logs a
    # warning and exits instead of serving.
    uvicorn.run(app, host="0.0.0.0", port=7860)