Spaces:
Sleeping
Sleeping
import asyncio
import io
import logging
import time

import fastapi
import numpy as np
from fastapi import File, HTTPException, UploadFile
from keras import (
    applications,
    preprocessing,
    saving,
)
from PIL import Image
# ASGI application object exposed by this module.
app = fastapi.FastAPI()

# Classifier weights are fetched via Keras' ``hf://`` loader when this module
# is imported.
model = saving.load_model("hf://Yumeng-Liu/trash-classifier")

# Output labels; position i is the label for model output index i.
CLASSES = [
    'biological',
    'cardboard',
    'glass',
    'metal',
    'paper',
    'plastic',
    'trash',
]

# The top class is reported only when its score strictly exceeds this value;
# otherwise the prediction is "none".
THRESHOLD = 0.5
def get_prediction(img: Image.Image) -> str:
    """Classify a PIL image into one of ``CLASSES``.

    Args:
        img: Decoded PIL image, in any mode (RGB, RGBA, L, P, CMYK, ...).

    Returns:
        The label in ``CLASSES`` with the highest model score when that score
        is strictly greater than ``THRESHOLD``; otherwise the string ``"none"``.
    """
    # The network takes 3-channel input. Uploads may be RGBA (PNG with alpha),
    # greyscale ("L"), palette ("P") or CMYK, and those would yield the wrong
    # channel count, so normalise to RGB first (a no-op copy for RGB input).
    img = img.convert("RGB").resize((224, 224))
    img_array = preprocessing.image.img_to_array(img)
    # Add a leading batch axis of size 1 to match the model's batched input.
    img_array = np.expand_dims(img_array, axis=0)
    img_array = applications.mobilenet_v2.preprocess_input(img_array)
    # verbose=0: don't print a progress bar to the server log on every request.
    scores = model.predict(img_array, verbose=0)[0]
    predicted_class_idx = int(np.argmax(scores))
    if scores[predicted_class_idx] > THRESHOLD:
        return CLASSES[predicted_class_idx]
    return "none"
def read_root():
    """Return the fixed greeting payload ``{"Hello": "World"}``.

    NOTE(review): no ``@app.get(...)`` decorator is visible in this view, so as
    shown this function is not registered on ``app`` — confirm it is wired up
    (presumably as ``GET /``).
    """
    greeting = dict(Hello="World")
    return greeting
async def predict(received_image: UploadFile = File(...)):
    """Classify an uploaded image file.

    Args:
        received_image: Multipart upload containing the image bytes.

    Returns:
        ``{"result": label}`` where ``label`` is what ``get_prediction``
        returns for the decoded image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 if anything else fails while classifying it.

    NOTE(review): no ``@app.post(...)`` decorator is visible in this view —
    confirm this handler is registered on ``app``.
    """
    logger = logging.getLogger(__name__)
    try:
        # The async UploadFile API avoids blocking the event loop, which a
        # direct ``received_image.file.read()`` call would do.
        contents = await received_image.read()
        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open() is lazy; load() decodes now so that empty, truncated
            # or corrupt uploads are reported as a client error, not a 500.
            image.load()
        except (OSError, Image.DecompressionBombError) as exc:
            # PIL.UnidentifiedImageError is a subclass of OSError.
            raise HTTPException(status_code=400, detail="Invalid image file") from exc
        logger.info("Image received: %s %s %s", image.format, image.size, image.mode)
        # Model inference is synchronous and CPU-bound; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        prediction_result = await asyncio.to_thread(get_prediction, image)
        logger.info("Prediction: %s", prediction_result)
        return {"result": prediction_result}
    except HTTPException:
        # Already carries the intended status code (e.g. 400 above).
        raise
    except Exception as exc:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail='Something went wrong') from exc
    finally:
        await received_image.close()
if __name__ == "__main__":
    # Running the file directly starts no server: it prints a banner, then
    # idles forever in 10-second sleeps. NOTE(review): the ASGI ``app`` is
    # presumably served by an external runner (e.g. ``uvicorn``) — confirm.
    print("Starting app")
    idle_seconds = 10
    while True:
        time.sleep(idle_seconds)