storesource's picture
Update app.py
ba67a5f verified
raw
history blame
No virus
2.51 kB
"""
IM 417 Classification
"""
import cv2
import numpy as np
import streamlit as st
import tempfile
from streamlit_drawable_canvas import st_canvas
from huggingface_hub import hf_hub_download
from tensorflow import keras
from huggingface_hub import from_pretrained_keras
from urllib.request import urlretrieve
# Load the model once and share it across reruns. st.cache_resource is the
# supported API for unhashable shared objects such as models (Streamlit >= 1.18);
# fall back to the deprecated st.cache on older Streamlit versions.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model():
    """Download (if not already cached) and load the IM 417 Keras classifier.

    The file is fetched with ``hf_hub_download``, which resolves Git LFS
    content and stores it in the local Hugging Face cache. The previous
    ``/raw/main/`` URL returns only the LFS pointer text for LFS-tracked files.
    It also avoided reopening a still-open NamedTemporaryFile, which fails on
    Windows.

    Returns:
        The model object returned by ``keras.models.load_model``.
    """
    model_path = hf_hub_download(
        repo_id="tomeheya/IM-417-128x128-classification",
        filename="IM_417_128.keras",
    )
    return keras.models.load_model(model_path)
# Function to preprocess the input image
def preprocess_image(image):
    """Turn a drawn image into a normalized, ink-positive grayscale array.

    Args:
        image: Image array with pixel values in 0-255. A 4-channel image is
            read as RGBA (the layout ``st_canvas`` returns in ``image_data``),
            a 3-channel image as RGB, and a 2-D or single-channel array as
            already grayscale.

    Returns:
        Float array of shape (H, W, 1) with values in [0, 1]. Dark strokes
        map to values near 1 and a white background maps to values near 0.

    Raises:
        ValueError: if ``image`` is not grayscale and does not have 3 or 4
            channels.
    """
    image = np.asarray(image)
    # bitwise_not below assumes 8-bit pixels, and cv2.cvtColor rejects
    # e.g. int64/float64 arrays, so coerce to uint8 first.
    if image.dtype != np.uint8:
        image = np.clip(image, 0, 255).astype(np.uint8)

    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    elif image.ndim == 3 and image.shape[-1] == 4:
        # Canvas pixels are RGBA, not BGR: use the matching conversion so the
        # red and blue luminance weights are not swapped.
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    elif image.ndim == 3 and image.shape[-1] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    elif image.ndim != 2:
        raise ValueError(
            f"expected a grayscale or 3/4-channel image, got shape {image.shape}"
        )

    # Invert so strokes (dark on the white canvas) become the bright signal.
    image = cv2.bitwise_not(image)
    image = np.expand_dims(image, axis=-1)  # add channel axis -> (H, W, 1)
    return image / 255.0  # scale 0-255 to [0, 1]
# Function to perform inference
def predict(image):
    """Classify one raw image and return its human-readable label.

    ``preprocess_image`` is applied here, so callers should pass the
    unprocessed image (e.g. the canvas pixels), not an already-preprocessed
    array.
    """
    model = load_model()
    # Add a leading batch axis: (H, W, 1) -> (1, H, W, 1).
    batch = preprocess_image(image)[np.newaxis, ...]
    return decode_prediction(model.predict(batch))
# Function to decode the model prediction
def decode_prediction(prediction):
    """Map model output scores to a 1-based IM-417 sign label.

    ``np.argmax`` runs over the flattened scores, so this expects the scores
    of a single sample (e.g. a batch of one). Ties resolve to the lowest
    index.
    """
    best_index = int(np.argmax(prediction))
    sign_number = best_index + 1  # class indices start at 0, sign numbers at 1
    return f"IM-417 sign number :: {sign_number}"
# Streamlit app layout. Streamlit renders elements in the order they are
# created, so the title must come first to appear above the canvas.
st.title("IM 417 Classification")

# Drawing surface: black strokes on a white 128x128 background.
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=3,
    stroke_color="rgba(0, 0, 0, 1)",
    background_color="rgba(255, 255, 255, 1)",
    height=128,
    width=128,
    update_streamlit=True,
    key="canvas",
)

# Add a submit button
if st.button("Submit"):
    # json_data is a non-empty dict even for a blank canvas, so test for
    # drawn objects rather than for json_data itself.
    drawn_objects = (canvas_result.json_data or {}).get("objects")
    if canvas_result.image_data is not None and drawn_objects:
        # predict() already runs preprocess_image. Pass the raw canvas pixels
        # so the image is not preprocessed twice: a second pass would feed a
        # 1-channel float64 array to cv2.cvtColor, which raises.
        predicted_label = predict(canvas_result.image_data)
        st.text(f"Predicted label: {predicted_label}")
    else:
        st.text("Please draw an image on the canvas before submitting.")