O S I H commited on
Commit
b7b10e6
·
1 Parent(s): 6eec5de
Files changed (1) hide show
  1. api.py +2 -4
api.py CHANGED
@@ -5,15 +5,12 @@ from PIL import Image
5
  import io
6
  import tensorflow as tf
7
 
8
- app = FastAPI()
9
 
10
- # Load the model
11
- # model = from_pretrained_keras("MissingBreath/recycle-garbage-model")
12
- # model = from_pretrained_keras("./recycle-garbage-model")
13
  model = tf.keras.models.load_model('_9217')
14
  # Class labels
15
  # class_labels = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
16
 
 
17
  @app.post("/classify")
18
  async def classify(image: UploadFile = File(...)):
19
  if image is not None:
@@ -23,6 +20,7 @@ async def classify(image: UploadFile = File(...)):
23
  img_array = np.expand_dims(img_array, axis=0)
24
  predictions = model.predict(img_array)
25
  predicted_class_idx = np.argmax(predictions)
 
26
  # predicted_class = class_labels[predicted_class_idx]
27
  # return {"prediction": predicted_class}
28
  return {"prediction": predicted_class_idx}
 
5
  import io
6
  import tensorflow as tf
7
 
 
8
 
 
 
 
9
  model = tf.keras.models.load_model('_9217')
10
  # Class labels
11
  # class_labels = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
12
 
13
+ app = FastAPI()
14
  @app.post("/classify")
15
  async def classify(image: UploadFile = File(...)):
16
  if image is not None:
 
20
  img_array = np.expand_dims(img_array, axis=0)
21
  predictions = model.predict(img_array)
22
  predicted_class_idx = np.argmax(predictions)
23
+ predicted_class_idx = int(predicted_class_idx)
24
  # predicted_class = class_labels[predicted_class_idx]
25
  # return {"prediction": predicted_class}
26
  return {"prediction": predicted_class_idx}