# recycling-ai / app.py
# Tan Le
# Add more example images
# e3b08f1
# raw
# history blame
# 1.81 kB
import json
import os

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
# Hub access token, taken from the environment rather than committed to source
# control. SECURITY: a literal token used to be hard-coded here; it is exposed
# in the repository history and must be revoked. An unset or empty variable
# yields None, i.e. no explicit token is passed to the download call.
_hf_token = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or None
)

# Fetch the trained classifier file from the Hugging Face Hub; the call
# returns the local filesystem path of the downloaded file.
model_path = hf_hub_download(
    repo_id="tancnle/smart-recycling",
    filename="model_baseline.h5",
    use_auth_token=_hf_token,
)

# Target size handed to tf.image.resize in read_image.
dim = (299, 299)
def read_image(image):
    """Turn an input image into the tensor fed to the classifier.

    The image is converted to a tensor, declared to have three channels in
    its last axis, resized to ``dim`` and rescaled with ``x / 125.0 - 1``.

    Args:
        image: Image data accepted by ``tf.convert_to_tensor`` (presumably
            the array Gradio passes to ``classify`` -- confirm).

    Returns:
        The resized and rescaled image tensor (no batch axis is added here).
    """
    tensor = tf.convert_to_tensor(image)
    # Pin the channel count; height and width stay unconstrained.
    tensor.set_shape([None, None, 3])
    resized = tf.image.resize(images=tensor, size=dim)
    # NOTE(review): for 8-bit pixel values (assumed) dividing by 125.0 maps
    # [0, 255] to [-1, 1.04]; the usual [-1, 1] scaling divides by 127.5.
    # Left untouched because it has to match the preprocessing the model was
    # trained with -- confirm before changing.
    return resized / 125.0 - 1
def infer(model, image_tensor):
    """Run the model on one image and return a score for each label.

    Args:
        model: Object exposing ``predict``; it is called with the image
            wrapped in a leading batch axis of size one.
        image_tensor: Single preprocessed image tensor.

    Returns:
        dict mapping each label name to the corresponding model output
        converted to a Python ``float``.
    """
    batch = tf.expand_dims(image_tensor, axis=0)
    # Only one image was submitted, so keep the first (and only) result row.
    scores = model.predict(batch)[0]
    # Label order is assumed to follow the model's output index order --
    # TODO confirm against the training pipeline.
    labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
    return {label: float(score) for label, score in zip(labels, scores)}
def top_3_accuracy(ytrue, ypred):
    """Return Keras sparse top-k categorical accuracy with ``k=3``.

    This function is passed to ``load_model`` through ``custom_objects``
    under the name ``top_3_accuracy``, so the name must stay unchanged.

    Args:
        ytrue: Ground-truth labels, forwarded unchanged to Keras.
        ypred: Model predictions, forwarded unchanged to Keras.
    """
    top_k_accuracy = tf.keras.metrics.sparse_top_k_categorical_accuracy
    return top_k_accuracy(ytrue, ypred, k=3)
# Loaded model, filled in on the first call to _get_model().
_model_cache = None


def _get_model():
    """Load the Keras model on first use and return the cached instance."""
    global _model_cache
    if _model_cache is None:
        _model_cache = tf.keras.models.load_model(
            model_path,
            # The saved model refers to this metric by name.
            custom_objects={"top_3_accuracy": top_3_accuracy},
        )
    return _model_cache


def classify(input_image):
    """Classify one image and return per-label scores.

    The model used to be re-read from disk on every request; it is now
    loaded once, lazily, and reused across calls.

    Args:
        input_image: Image supplied by the Gradio input component; it is
            passed straight to ``read_image``.

    Returns:
        dict mapping label names to float scores, as produced by ``infer``.
    """
    model = _get_model()
    image_tensor = read_image(input_image)
    return infer(model, image_tensor)
# Text shown on the Gradio page.
title = "Classify Trash"
description = "Upload an image or select from examples to classify trash."
article = "<div style='text-align: center;'>Space by Tan Le</div>"

# Sample images offered as examples in the interface.
examples = [
    "images/cardboard.jpeg",
    "images/cigarette_butt.jpeg",
    "images/masks.jpeg",
    "images/metal_objects.jpeg",
    "images/paper.jpeg",
    "images/plastic.jpeg",
    "images/spray_cans.jpeg",
    "images/syringe.jpeg",
]

# Wire the classifier into a Gradio interface: image in, label scores out.
# NOTE(review): the gr.inputs namespace is deprecated in newer Gradio
# releases -- check the pinned Gradio version before upgrading.
demo = gr.Interface(
    fn=classify,
    inputs=gr.inputs.Image(),
    outputs="label",
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Start the web server for the demo.
demo.launch()