# Hugging Face Spaces status: Runtime error
import functools
import json
import os

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
# Download the trained Keras model file from the Hugging Face Hub.
# The access token is read from the HF_TOKEN environment variable (e.g. a Space
# secret) instead of being hard-coded: a token committed to source is exposed to
# anyone who can read the repo, and once it is revoked every download fails.
# When HF_TOKEN is unset, None is passed and huggingface_hub falls back to its
# default credential lookup / anonymous access.
model_path = hf_hub_download(
    repo_id="tancnle/smart-recycling",
    filename="model_baseline.h5",
    token=os.environ.get("HF_TOKEN"),
)

# Target (height, width) every input image is resized to before inference.
dim = (299, 299)
def read_image(image):
    """Convert a raw image into the rescaled float tensor the model consumes.

    Args:
        image: Array-like image whose last axis has 3 channels; the static
            shape check below rejects anything that is not (H, W, 3).

    Returns:
        A tf.Tensor resized to ``dim`` and rescaled with ``x / 125.0 - 1``.
    """
    tensor = tf.convert_to_tensor(image)
    # Declare the rank/channel count so resize operates on a known RGB shape.
    tensor.set_shape([None, None, 3])
    resized = tf.image.resize(tensor, dim)
    # NOTE(review): Inception-style preprocessing is usually x / 127.5 - 1,
    # which maps [0, 255] onto [-1, 1]; 125.0 maps 255 to ~1.04. Kept as-is
    # because the model may have been trained with this exact constant —
    # confirm against the training code before changing it.
    return resized / 125.0 - 1
# Class names in the order of the model's output units (assumed to match the
# label order used at training time — confirm against the training pipeline).
LABELS = ("cardboard", "glass", "metal", "paper", "plastic", "trash")


def infer(model, image_tensor):
    """Run the classifier on a single preprocessed image.

    Args:
        model: A Keras model, e.g. the one returned by
            ``tf.keras.models.load_model``.
        image_tensor: One image without a batch axis, as produced by
            ``read_image``.

    Returns:
        dict mapping each name in ``LABELS`` to the model's score for that
        class, as a Python float.

    Raises:
        ValueError: If the model emits a different number of scores than
            there are labels.
    """
    batch = tf.expand_dims(image_tensor, axis=0)
    # Model.predict is designed for large batched inputs and adds per-call
    # overhead; Keras recommends calling the model directly for one batch.
    # training=False keeps layers like BatchNormalization/Dropout in
    # inference mode, matching what predict() does.
    scores = model(batch, training=False)[0].numpy()
    # zip() would silently drop unmatched labels or scores; fail loudly instead.
    if len(scores) != len(LABELS):
        raise ValueError(
            f"model produced {len(scores)} scores but "
            f"{len(LABELS)} labels are defined"
        )
    return {label: float(score) for label, score in zip(LABELS, scores)}
def top_3_accuracy(ytrue, ypred):
    """Sparse top-3 categorical accuracy.

    Passed to ``tf.keras.models.load_model`` as a custom object so the metric
    name saved with the model can be resolved (presumably the model was
    compiled with this metric).
    """
    top_k = 3
    metric_fn = tf.keras.metrics.sparse_top_k_categorical_accuracy
    return metric_fn(ytrue, ypred, k=top_k)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the Keras classifier from ``model_path`` once per process."""
    # custom_objects lets Keras resolve the "top_3_accuracy" name while
    # deserialising the file.
    return tf.keras.models.load_model(
        model_path,
        custom_objects={"top_3_accuracy": top_3_accuracy},
    )


def classify(input_image):
    """Gradio handler: classify one image into the trash categories.

    Args:
        input_image: The image supplied by the Gradio ``Image`` input, or
            ``None`` when the user submits without an image.

    Returns:
        dict of label -> score from ``infer``, or ``None`` when no image was
        given (which Gradio's ``label`` output should treat as no prediction).
    """
    if input_image is None:
        return None
    # Loading the .h5 model from disk is expensive; the cached loader reads
    # it only on the first request instead of on every call.
    model = _load_model()
    image_tensor = read_image(input_image)
    return infer(model, image_tensor)
# Static text shown on the Gradio page.
title = "Classify Trash"
description = "Upload an image or select from examples to classify trash."
article = "<div style='text-align: center;'>Space by Tan Le</div>"

# One-click example images; these are relative paths, so they are expected
# to exist under ./images relative to where the app is launched.
examples = [
    "images/cardboard.jpeg",
    "images/cigarette_butt.jpeg",
    "images/masks.jpeg",
    "images/metal_objects.jpeg",
    "images/paper.jpeg",
    "images/plastic.jpeg",
    "images/spray_cans.jpeg",
    "images/syringe.jpeg",
]

# gr.inputs.Image was deprecated in Gradio 3 and removed in Gradio 4; the
# top-level gr.Image component replaces it and, like the old default, hands
# the upload to `classify` as a numpy array.
demo = gr.Interface(
    classify,
    inputs=gr.Image(),
    outputs="label",
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch()