# Yumeng Liu
# new model
# d6e48e1
from keras import (
saving,
preprocessing,
applications
)
import fastapi
from fastapi import UploadFile, File, HTTPException
from PIL import Image
import io
import time
import numpy as np
# ASGI application object; the route decorators below register handlers on it.
app = fastapi.FastAPI()

# Labels the classifier can return. get_prediction indexes this list with the
# argmax of the model output, so the order presumably mirrors the model's
# output units -- verify against the training setup before reordering.
CLASSES = [
    'biological',
    'cardboard',
    'glass',
    'metal',
    'paper',
    'plastic',
    'trash',
]

# The top score must be strictly greater than this value, otherwise
# get_prediction answers "none".
THRESHOLD = 0.5

# Loaded once, at import time, from the hf:// location below.
model = saving.load_model("hf://Yumeng-Liu/trash-classifier")
def get_prediction(img: Image.Image) -> str:
    """Classify one image as an entry of ``CLASSES`` or as ``"none"``.

    The image is converted to RGB if needed, resized to 224x224, turned into
    a batch of one, passed through MobileNetV2's ``preprocess_input`` and fed
    to the module-level ``model``.

    Args:
        img: A PIL image in any mode.

    Returns:
        The ``CLASSES`` entry with the highest score, or ``"none"`` when that
        score is not strictly greater than ``THRESHOLD``.
    """
    # Uploads can be RGBA, palette or grayscale; img_to_array would then
    # produce 4 or 1 channels. The MobileNetV2 preprocessing used below
    # suggests the model takes 3-channel input (assumption -- confirm against
    # the model's input shape), so normalise the mode first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    img = img.resize((224, 224))
    img_array = preprocessing.image.img_to_array(img)
    # Add a leading batch dimension to match the model's input shape.
    img_array = np.expand_dims(img_array, axis=0)
    img_array = applications.mobilenet_v2.preprocess_input(img_array)

    scores = model.predict(img_array)[0]
    predicted_class_idx = np.argmax(scores)

    # Low-confidence predictions are reported as "none" rather than a label.
    if scores[predicted_class_idx] > THRESHOLD:
        return CLASSES[predicted_class_idx]
    return "none"
@app.get("/")
def read_root():
    """Answer GET / with a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
@app.post("/predict-image")
async def predict(received_image: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label.

    Args:
        received_image: The uploaded file, expected to contain image data.

    Returns:
        ``{"result": <label>}`` where the label comes from ``get_prediction``.

    Raises:
        HTTPException: 400 when the upload cannot be identified as an image;
            500 for any other failure while reading or classifying it.
    """
    # Function-scope import: PIL is already a dependency of this file, and
    # this keeps the new name local to the handler that needs it.
    from PIL import UnidentifiedImageError

    try:
        # UploadFile.read() is awaitable, unlike the raw .file.read() call.
        contents = await received_image.read()

        try:
            # Open the binary data as an image
            image = Image.open(io.BytesIO(contents))
        except UnidentifiedImageError as e:
            # A payload that is not an image is the client's mistake, not a
            # server fault, so report it as 400 instead of 500.
            print(e)
            raise HTTPException(
                status_code=400,
                detail='Uploaded file is not a valid image',
            ) from e

        print("Image received")
        print(image.format, image.size, image.mode)  # Example: JPEG (1920, 1080) RGB
        print("")

        prediction_result = get_prediction(image)
        print(prediction_result)

        return {
            "result": prediction_result
        }
    except HTTPException:
        # Let the deliberate 400 above pass through the catch-all below.
        raise
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail='Something went wrong') from e
    finally:
        await received_image.close()
if __name__ == "__main__":
    # Running this file directly only announces startup and then keeps the
    # process alive; no server is started from here.
    print("Starting app")
    nap_seconds = 10
    while True:
        time.sleep(nap_seconds)