JairoDanielMT commited on
Commit
3bb9a26
1 Parent(s): 9edaeee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -1,31 +1,40 @@
1
  from fastapi import FastAPI, HTTPException
2
- from tensorflow.keras.models import model_from_json
3
  from pydantic import BaseModel
4
  import numpy as np
5
 
 
6
  class InputData(BaseModel):
7
- data: list
8
 
9
  app = FastAPI()
 
10
 
 
11
  def load_model():
12
  try:
13
- with open("model.json", 'r') as json_file:
14
- loaded_model_json = json_file.read()
 
15
  loaded_model = model_from_json(loaded_model_json)
16
  loaded_model.load_weights("model.h5")
17
  loaded_model.compile(loss='mean_squared_error', optimizer='adam', metrics=['binary_accuracy'])
18
  return loaded_model
19
  except Exception as e:
20
- print(f"Error cargando el modelo: {str(e)}")
21
- raise
22
-
23
- model = load_model()
24
 
 
25
  @app.post("/predict/")
26
  async def predict(data: InputData):
 
 
 
 
 
27
  try:
28
- input_data = np.array(data.data).reshape(1, -1)
 
29
  prediction = model.predict(input_data).round()
30
  return {"prediction": prediction.tolist()}
31
  except Exception as e:
 
1
  from fastapi import FastAPI, HTTPException
2
+ from keras.models import model_from_json
3
  from pydantic import BaseModel
4
  import numpy as np
5
 
6
+ # Definición del modelo de datos de entrada
7
  class InputData(BaseModel):
8
+ data: list # Asumiendo que la entrada es una lista de características numéricas
9
 
10
  app = FastAPI()
11
+ model = None # Inicializa el modelo como None
12
 
13
+ # Carga del modelo
14
  def load_model():
15
  try:
16
+ json_file = open("model.json", 'r')
17
+ loaded_model_json = json_file.read()
18
+ json_file.close()
19
  loaded_model = model_from_json(loaded_model_json)
20
  loaded_model.load_weights("model.h5")
21
  loaded_model.compile(loss='mean_squared_error', optimizer='adam', metrics=['binary_accuracy'])
22
  return loaded_model
23
  except Exception as e:
24
+ print(f"Error al cargar el modelo: {e}")
25
+ return None
 
 
26
 
27
+ # Ruta de predicción
28
  @app.post("/predict/")
29
  async def predict(data: InputData):
30
+ global model
31
+ if model is None:
32
+ model = load_model()
33
+ if model is None:
34
+ raise HTTPException(status_code=500, detail="Model could not be loaded")
35
  try:
36
+ # Convertir la lista de entrada a un array de NumPy para la predicción
37
+ input_data = np.array(data.data).reshape(1, -1) # Asumiendo que la entrada debe ser de forma (1, num_features)
38
  prediction = model.predict(input_data).round()
39
  return {"prediction": prediction.tolist()}
40
  except Exception as e: