Spaces:
Sleeping
Sleeping
""" | |
IM 417 Classification | |
""" | |
import cv2 | |
import numpy as np | |
import streamlit as st | |
import tempfile | |
from streamlit_drawable_canvas import st_canvas | |
from huggingface_hub import hf_hub_download | |
from tensorflow import keras | |
from huggingface_hub import from_pretrained_keras | |
from urllib.request import urlretrieve | |
# Load the model once per server process. st.cache_resource keeps the returned
# model object alive across reruns and sessions, so it is not re-downloaded and
# re-deserialized on every prediction.
@st.cache_resource
def load_model():
    """Download (once) and load the IM-417 128x128 Keras classifier.

    The file ``IM_417_128.keras`` is fetched from the Hugging Face repository
    ``tomeheya/IM-417-128x128-classification``.

    Returns:
        The model object produced by ``keras.models.load_model``.
    """
    # hf_hub_download resolves Git-LFS objects and caches the file on local
    # disk. The previous ".../raw/main/..." URL serves the LFS pointer text
    # (not the binary model) for LFS-tracked files. It also avoids re-opening
    # a still-open NamedTemporaryFile by name, which fails on Windows.
    model_path = hf_hub_download(
        repo_id="tomeheya/IM-417-128x128-classification",
        filename="IM_417_128.keras",
    )
    return keras.models.load_model(model_path)
# Function to preprocess the input image
def preprocess_image(image):
    """Turn a colour drawing into an inverted, normalised single-channel array.

    Args:
        image: A colour image array. It is converted with
            ``cv2.COLOR_BGR2GRAY``, i.e. its channels are read in B, G, R order.

    Returns:
        The grayscale image, inverted (dark strokes on a light background
        become light strokes on a dark background), with a trailing channel
        axis added and divided by 255.0. For uint8 input the values therefore
        lie in [0, 1].
    """
    # NOTE(review): the canvas presumably yields RGBA data, so BGR2GRAY swaps
    # the R/B weights. This is harmless for black-on-white strokes — confirm
    # if coloured strokes are ever used.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    inverted = cv2.bitwise_not(gray)
    with_channel = inverted[..., np.newaxis]
    return with_channel / 255.0
# Function to perform inference
def predict(image):
    """Classify a raw (not yet preprocessed) drawing.

    Args:
        image: The unprocessed image array. It is run through
            ``preprocess_image`` here, so callers must not preprocess it first.

    Returns:
        The label string built by ``decode_prediction``.
    """
    classifier = load_model()
    # The model expects a batch; wrap the single image as a batch of one.
    batch = np.expand_dims(preprocess_image(image), axis=0)
    return decode_prediction(classifier.predict(batch))
# Function to decode the model prediction
def decode_prediction(prediction):
    """Map the model's output scores to a human-readable sign label.

    Args:
        prediction: Array-like of scores. ``np.argmax`` is taken over the
            flattened array, so a (1, n_classes) batch of one works directly.

    Returns:
        ``"IM-417 sign number :: <k>"``, where ``k`` is the 1-based index of
        the highest score.
    """
    # Class indices are 0-based, while sign numbers start at 1.
    sign_number = np.argmax(prediction) + 1
    return "IM-417 sign number :: " + str(sign_number)
# Streamlit renders elements in script order, so the title must come first to
# appear as the page heading above the canvas.
st.title("IM 417 Classification")

# Initialize the 128x128 drawing canvas (black strokes on a white background).
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=3,
    stroke_color="rgba(0, 0, 0, 1)",
    background_color="rgba(255, 255, 255, 1)",
    height=128,
    width=128,
    update_streamlit=True,
    key="canvas",
)

# Add a submit button
if st.button("Submit"):
    json_data = canvas_result.json_data
    # json_data is a fabric.js scene dict even for an untouched canvas, so a
    # plain truthiness check never detects "nothing drawn"; require at least
    # one drawn object and actual pixel data.
    has_drawing = bool(json_data and json_data.get("objects"))
    if has_drawing and canvas_result.image_data is not None:
        image_from_canvas = canvas_result.image_data
        st.text(f"Original data: {image_from_canvas}")
        # predict() runs preprocess_image itself. Preprocessing here as well
        # would pass a 1-channel float image back into cv2.cvtColor and raise.
        predicted_label = predict(image_from_canvas)
        st.text(f"Predicted label: {predicted_label}")
    else:
        st.text("Please draw an image on the canvas before submitting.")