storesource's picture
Update app.py
f1f8687 verified
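"""
IM 417 Classification

Streamlit app that lets the user draw a sign on a 128x128 canvas and classifies
it with the Keras model ``tomeheya/IM-417-128x128-classification`` loaded from
the Hugging Face Hub.
"""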
import cv2
import numpy as np
import os
import streamlit as st
import tempfile
from streamlit_drawable_canvas import st_canvas
from tensorflow import keras
from huggingface_hub import from_pretrained_keras
from transformers import AutoModel
# Load the model once per server process. st.cache_resource keeps the same
# model object across Streamlit reruns and user sessions, so the Hub download
# below happens once instead of on every prediction.
@st.cache_resource
def _load_model_cached():
    """Download and build the Keras model, raising on any failure.

    Exceptions are deliberately NOT caught here: Streamlit does not store a
    result for a call that raises, so a failed load is retried on the next
    call instead of a ``None`` being cached forever.
    """
    # force_download=True bypasses the local Hugging Face cache; thanks to
    # st.cache_resource this now costs one download per process, not per click.
    model = from_pretrained_keras(
        "tomeheya/IM-417-128x128-classification", force_download=True
    )
    print("Model loaded successfully!")
    return model


def load_model():
    """Return the IM-417 classification model, or ``None`` if loading fails.

    Returns:
        The Keras model from the Hugging Face Hub (cached after the first
        successful load), or ``None`` when it could not be loaded. The error
        is printed to stdout in that case.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # app boundary: report, and let the UI show a message
        print(f"Error loading model: {e}")
        return None
# Function to preprocess the input image
def preprocess_image(image):
    """Turn a drawn image into a normalised, inverted single-channel array.

    Accepted inputs (assumed uint8 with values 0-255 — TODO confirm for any
    new caller):
      * (H, W, 4) RGBA — the layout of ``st_canvas(...).image_data``;
      * (H, W, 3) BGR  — OpenCV's convention, converted exactly as before;
      * (H, W) or (H, W, 1) grayscale — used as-is.

    The gray image is inverted so that dark ink on a light background becomes
    high values, then scaled to [0, 1].

    Returns:
        Float array of shape (H, W, 1) with values in [0, 1].
    """
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 1:
        image = np.squeeze(image, axis=-1)

    if image.ndim == 3 and image.shape[-1] == 4:
        # RGBA (canvas output): composite onto white so pixels the canvas
        # leaves transparent count as the white background rather than as
        # black (RGB 0) ink. Fully opaque pixels are unchanged by this.
        rgb = image[..., :3].astype(np.float32)
        alpha = image[..., 3:].astype(np.float32) / 255.0
        rgb = rgb * alpha + 255.0 * (1.0 - alpha)
        image = cv2.cvtColor(np.rint(rgb).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    elif image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    image = cv2.bitwise_not(image)  # ink -> bright, background -> dark
    image = np.expand_dims(image, axis=-1)  # Add channel dimension
    image = image / 255.0  # Normalize pixel values
    return image
# Function to perform inference
def predict(image):
    """Classify a raw drawn image and return a human-readable label.

    Args:
        image: The raw image, e.g. ``st_canvas(...).image_data``. It is run
            through ``preprocess_image`` here, so callers must NOT
            preprocess it themselves.

    Returns:
        The label string produced by ``decode_prediction``, or an error
        message string if the model could not be loaded.
    """
    model = load_model()
    if model is None:  # load_model signals failure with None
        return "Model couldn't be loaded !!!"
    batch = np.expand_dims(preprocess_image(image), axis=0)  # Add batch dimension
    # verbose=0 keeps Keras' per-call progress bar out of the server log.
    prediction = model.predict(batch, verbose=0)
    return decode_prediction(prediction)
# Function to decode the model prediction
def decode_prediction(prediction):
    """Map the model's class scores to a human-readable IM-417 sign label.

    ``np.argmax`` is taken over the whole (flattened) array, so this is meant
    for a single-sample prediction. The winning 0-based index is reported
    as a 1-based sign number.
    """
    best_index = int(np.argmax(prediction))
    sign_number = best_index + 1
    return f"IM-417 sign number :: {sign_number}"
# Streamlit app layout. Streamlit renders elements in creation order, so the
# title is created first to appear above the canvas.
st.title("IM 417 Classification")

# Initialize the drawing canvas: 128x128, black strokes on a white background.
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=3,
    stroke_color="rgba(0, 0, 0, 1)",
    background_color="rgba(255, 255, 255, 1)",
    height=128,
    width=128,
    update_streamlit=True,
    key="canvas",
)

# Add a submit button
if st.button("Submit"):
    # json_data may be None before the canvas has reported back, and its
    # "objects" list stays empty until something has actually been drawn.
    drawn_objects = (canvas_result.json_data or {}).get("objects") or []
    if canvas_result.image_data is not None and drawn_objects:
        image_from_canvas = canvas_result.image_data
        st.text(f"Original data: {len(image_from_canvas)}")
        # predict() runs preprocess_image itself; preprocessing here as well
        # would push an already-normalised gray array back through cvtColor.
        predicted_label = predict(image_from_canvas)
        st.text(f"Predicted label: {predicted_label}")
    else:
        st.text("Please draw an image on the canvas before submitting.")