|
from sklearn import tree |
|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
import numpy as np |
|
from typing import List |
|
from joblib import load |
|
|
|
class InputData(BaseModel):
    """Request body accepted by ``POST /predict/``.

    The endpoint treats ``data`` as the feature values of a single sample.
    """

    # Flat list of feature values for one sample; presumably in the same
    # column order the model was trained with -- TODO confirm.
    data: List[float]
|
|
|
|
|
# Module-level ASGI application; routes are attached to it with decorators
# such as ``@app.post``.
app = FastAPI()
|
|
|
|
|
def build_decision_tree(path="miarbol.pkl"):
    """Load the pre-trained model served by the prediction endpoint.

    Args:
        path: Location of the joblib-serialised model file. Defaults to
            ``"miarbol.pkl"``; a relative path is resolved against the
            process's current working directory.

    Returns:
        The object stored in the file. This module only calls its
        ``predict`` method; presumably it is a fitted scikit-learn decision
        tree -- TODO confirm.

    Note:
        The model's hyperparameters are whatever was saved in the file. No
        estimator is constructed here. (The previous version built a
        throwaway ``DecisionTreeClassifier`` and discarded it unused.)
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point `path` at model files from a trusted source.
    return load(path)
|
|
|
|
|
# Load the model once, at import time, so all requests share the same object
# instead of re-reading the file per call. A missing or unreadable model file
# therefore raises here, when the module is imported.
model = build_decision_tree()
|
|
|
|
|
|
|
@app.post("/predict/")
async def predict(data: InputData):
    """Run the loaded model on one feature vector and return its prediction.

    Args:
        data: Request body. ``data.data`` is used as the features of a single
            sample.

    Returns:
        ``{"prediction": [...]}``: the model's output for the single input
        row, as a JSON-serialisable list. Numeric outputs are rounded;
        non-numeric labels (e.g. strings) are returned unchanged.

    Raises:
        HTTPException: 422 if ``data.data`` is empty; 500 if the model (or
            post-processing of its output) raises.
    """
    import logging

    logger = logging.getLogger(__name__)
    # Lazy %-formatting: the payload is only rendered if DEBUG is enabled.
    logger.debug("Data: %s", data)

    # An empty list would become a (1, 0) array that the model rejects with an
    # opaque error; report it as a client-side validation problem instead.
    if not data.data:
        raise HTTPException(
            status_code=422,
            detail="'data' must contain at least one feature value",
        )

    try:
        # Single sample -> 2-D array of shape (1, n_features).
        input_data = np.array(data.data).reshape(1, -1)

        prediction = np.asarray(model.predict(input_data))

        # Round only numeric outputs. Calling .round() unconditionally raised
        # on non-numeric (e.g. string) class labels, turning valid
        # predictions into server errors.
        if np.issubdtype(prediction.dtype, np.number):
            prediction = prediction.round()

        return {"prediction": prediction.tolist()}

    except Exception as e:
        # Top-level request boundary: record the traceback, then surface the
        # error message to the client as a 500.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|