Si2469 commited on
Commit
5e914b4
1 Parent(s): 9f6ad7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -31
app.py CHANGED
@@ -1,48 +1,64 @@
1
- # app.py
 
 
 
 
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
- import joblib
5
  import numpy as np
6
- from sklearn.tree import DecisionTreeClassifier
7
-
8
- app = FastAPI()
9
 
10
- # Cargar el modelo desde el archivo joblib
11
- try:
12
- with open('miarbol.pkl', 'rb') as f:
13
- miarbol = joblib.load(f)
14
- except Exception as e:
15
- raise RuntimeError(f"Error al cargar el modelo: {str(e)}")
16
 
 
17
  class InputData(BaseModel):
18
  Gender: int
19
  Age: float
20
- PlayTimeHours: float
21
  InGamePurchases: int
22
  SessionsPerWeek: int
23
  AvgSessionDurationMinutes: int
24
  PlayerLevel: int
25
  AchievementsUnlocked: int
26
 
27
- @app.post("/predict")
28
- def predict(input_data: InputData):
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  try:
30
- # Convertir los datos de entrada a un array numpy para hacer la predicción
31
- input_array = np.array([[
32
- input_data.Gender,
33
- input_data.Age,
34
- input_data.PlayTimeHours,
35
- input_data.InGamePurchases,
36
- input_data.SessionsPerWeek,
37
- input_data.AvgSessionDurationMinutes,
38
- input_data.PlayerLevel,
39
- input_data.AchievementsUnlocked
40
- ]])
41
-
42
- # Realizar la predicción
43
- prediction = miarbol.predict(input_array)
44
-
45
- # Devolver la predicción como JSON
46
- return {"EngagementLevel": prediction[0]}
47
  except Exception as e:
48
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sklearn.tree import DecisionTreeClassifier
3
+ from sklearn.metrics import confusion_matrix, classification_report
4
+ import matplotlib.pyplot as plt
5
+ import pickle
6
  from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel
8
+ from typing import List
9
  import numpy as np
 
 
 
10
 
11
+ # Cargar el modelo desde el archivo .pkl usando joblib
12
+ import joblib
 
 
 
 
13
 
14
+ # Define la estructura de entrada esperada para FastAPI
15
  class InputData(BaseModel):
16
  Gender: int
17
  Age: float
18
+ PlayTimeHours: int
19
  InGamePurchases: int
20
  SessionsPerWeek: int
21
  AvgSessionDurationMinutes: int
22
  PlayerLevel: int
23
  AchievementsUnlocked: int
24
 
25
+ # Iniciar la aplicación FastAPI
26
+ app = FastAPI()
27
+
28
+ # Función para cargar el modelo guardado con joblib
29
+ def load_model():
30
+ with open('miarbol.pkl', 'rb') as f:
31
+ model = joblib.load(f)
32
+ return model
33
+
34
+ # Cargar el modelo al iniciar la aplicación
35
+ miarbol = load_model()
36
+
37
+ # Ruta de predicción
38
+ @app.post("/predict/")
39
+ async def predict(data: InputData):
40
  try:
41
+ # Convertir los datos de entrada en un array de NumPy para la predicción
42
+ X_input = np.array([[data.Gender, data.Age, data.PlayTimeHours, data.InGamePurchases,
43
+ data.SessionsPerWeek, data.AvgSessionDurationMinutes,
44
+ data.PlayerLevel, data.AchievementsUnlocked]])
45
+
46
+ # Realizar la predicción utilizando el modelo cargado
47
+ prediction = miarbol.predict(X_input)
48
+
49
+ # Retornar el resultado de la predicción
50
+ return {"prediction": int(prediction[0])} # Convertir a int para asegurar que sea JSON serializable
51
+
 
 
 
 
 
 
52
  except Exception as e:
53
  raise HTTPException(status_code=500, detail=str(e))
54
+
55
+ # Ruta para verificar el estado de la API
56
+ @app.get("/status/")
57
+ async def status():
58
+ return {"status": "Modelo cargado y API en funcionamiento"}
59
+
60
+ # Ejemplo de cómo utilizar la API
61
+ if __name__ == "__main__":
62
+ import uvicorn
63
+ uvicorn.run(app, host="0.0.0.0", port=8000)
64
+