Spaces:
Sleeping
Sleeping
from keras import ( | |
saving, | |
preprocessing, | |
applications | |
) | |
import fastapi | |
from fastapi import UploadFile, File, HTTPException | |
from PIL import Image | |
import io | |
import time | |
import numpy as np | |
# ASGI application object.
app = fastapi.FastAPI()

# The classifier is loaded once, at import time, from the "hf://" location below.
model = saving.load_model("hf://Yumeng-Liu/trash-classifier")

# Labels looked up by the index of the highest-scoring model output.
CLASSES = [
    'biological',
    'cardboard',
    'glass',
    'metal',
    'paper',
    'plastic',
    'trash',
]

# The top score must be strictly greater than this to be reported as a class.
THRESHOLD = 0.5
def get_prediction(img: Image.Image) -> str:
    """Classify a single image and return its label.

    The image is converted to RGB, resized to 224x224, turned into a
    one-element batch, run through the MobileNetV2 input preprocessing and
    passed to the module-level ``model``.

    Args:
        img: The PIL image to classify. Any mode is accepted; it is
            converted to RGB before inference.

    Returns:
        The entry of ``CLASSES`` with the highest score when that score is
        strictly greater than ``THRESHOLD``; otherwise the string ``"none"``.
    """
    # Uploads may be RGBA, grayscale or palette images, which would yield an
    # array with a channel count other than three.
    # NOTE(review): assumes the model takes 3-channel RGB input - confirm
    # against the model's input shape.
    rgb_image = img.convert("RGB").resize((224, 224))

    batch = preprocessing.image.img_to_array(rgb_image)
    batch = np.expand_dims(batch, axis=0)  # add the leading batch dimension
    batch = applications.mobilenet_v2.preprocess_input(batch)

    scores = model.predict(batch)[0]
    best_idx = int(np.argmax(scores))

    if scores[best_idx] > THRESHOLD:
        return CLASSES[best_idx]
    return "none"
def read_root():
    """Return the fixed greeting payload ``{"Hello": "World"}``."""
    greeting = {"Hello": "World"}
    return greeting
async def predict(received_image: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label.

    Args:
        received_image: The uploaded file, expected to contain image data.

    Returns:
        A dict of the form ``{"result": <label>}`` where the label is
        whatever ``get_prediction`` returns.

    Raises:
        HTTPException: With status 400 when the upload cannot be opened as
            an image, and with status 500 for any other failure.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view - confirm it is registered with `app`.
    try:
        # Use the async read so the event loop is not blocked by file I/O.
        contents = await received_image.read()

        try:
            # Open the binary data as an image
            image = Image.open(io.BytesIO(contents))
        except OSError as e:
            # Pillow's "cannot identify image file" error is an OSError; that
            # is bad client input, not a server fault.
            print(e)
            raise HTTPException(
                status_code=400, detail='Uploaded file is not a valid image'
            ) from e

        print("Image received")
        print(image.format, image.size, image.mode)  # Example: JPEG (1920, 1080) RGB
        print("")

        prediction_result = get_prediction(image)
        print(prediction_result)

        return {
            "result": prediction_result
        }
    except HTTPException:
        # Already a well-formed HTTP error (the 400 above); pass it through
        # instead of masking it as a 500.
        raise
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail='Something went wrong') from e
    finally:
        await received_image.close()
if __name__ == "__main__":
    # When run as a script this only announces start-up and then idles
    # forever; no web server is started from this branch.
    print("Starting app")
    pause_seconds = 10
    while True:
        time.sleep(pause_seconds)