# trash_sort/src/predict.py
# "Initial commit for TrashSort" (3f98cbd) by chiichann.
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
IMG_SIZE = 128  # images are resized to IMG_SIZE x IMG_SIZE pixels before inference

# Label for each model output position: CLASS_NAMES[i] names output unit i.
# The 'trash' category was deliberately dropped from this list. Presumably the
# model's output layer was trained with these same five classes in this same
# order -- confirm against the training code, because a mismatch in length or
# order silently mislabels predictions (or raises IndexError on lookup).
CLASS_NAMES = [
    'cardboard',
    'glass',
    'metal',
    'paper',
    'plastic',
]
def preprocess_image(img):
    """Turn a PIL image into a single-image batch for ``model.predict``.

    Steps: normalize the colour mode to RGB, resize to ``IMG_SIZE`` x
    ``IMG_SIZE`` pixels, convert to a float array, scale pixel values from
    [0, 255] to [0, 1], and prepend a batch axis of length 1.

    Args:
        img: A PIL ``Image`` in any mode (e.g. "RGB", "RGBA", "L", "P").

    Returns:
        A NumPy array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``. This assumes
        Keras' default channels_last data format. The dtype is Keras'
        configured float type (float32 unless reconfigured).
    """
    # Without this, RGBA images (e.g. PNGs with transparency) would produce 4
    # channels, and grayscale/palette images would produce 1 channel. Either
    # would mismatch the model input. NOTE(review): assumes the model was
    # trained on 3-channel RGB images -- confirm against the training
    # pipeline. Alpha is simply dropped, not composited onto a background.
    if img.mode != "RGB":
        img = img.convert("RGB")

    resized = img.resize((IMG_SIZE, IMG_SIZE))
    pixels = img_to_array(resized)
    pixels = pixels / 255.0  # scale from [0, 255] to [0, 1]
    return np.expand_dims(pixels, axis=0)  # add batch dimension
def predict_image(model, image):
    """Classify one image and report the winning class.

    Args:
        model: Object with a ``predict`` method that takes the batch built by
            ``preprocess_image`` (presumably a Keras model).
        image: The image to classify, in any form ``preprocess_image``
            accepts.

    Returns:
        A 4-tuple ``(label, confidence, class_idx, processed)``:

        - ``label``: the ``CLASS_NAMES`` entry at ``class_idx``.
        - ``confidence``: the largest value in the model output, as a Python
          float. Whether it is a probability depends on the model's final
          activation (presumably softmax -- confirm).
        - ``class_idx``: the NumPy integer from ``np.argmax`` over the
          flattened output. For a single-image batch this is the class index.
        - ``processed``: the preprocessed batch that was fed to the model.
    """
    batch = preprocess_image(image)
    scores = model.predict(batch)

    best = np.argmax(scores)
    top_score = float(np.max(scores))
    return CLASS_NAMES[best], top_score, best, batch