import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

st.set_page_config(page_title="Garbage Classification")


# CNN Model Definition
class SimpleCNN(nn.Module):
    def __init__(self, num_classes, input_channels=3):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=0)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # Dense layers (a 224x224 input shrinks to 12x12 after the four conv/pool blocks)
        self.fc1 = nn.Linear(256 * 12 * 12, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 512)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        # Conv blocks
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        x = F.relu(self.conv4(x))
        x = self.pool4(x)
        # Dense layers
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


# Class names
CLASS_NAMES = [
    "battery",
    "biological",
    "cardboard",
    "clothes",
    "glass",
    "metal",
    "paper",
    "plastic",
    "shoes",
    "trash",
]


# Cache the model loading
@st.cache_resource
def load_model():
    """Load the trained model"""
    device = torch.device("cpu")
    model = SimpleCNN(num_classes=10)
    # The checkpoint was saved from a DataParallel-wrapped model, so wrap before loading
    model = nn.DataParallel(model)
    try:
        model.load_state_dict(torch.load("best_model.pth", map_location=device))
        model.eval()
        return model, device
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, device


def preprocess_image(image):
    """Preprocess uploaded image"""
    transform = T.Compose(
        [
            T.Resize(224),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    image_tensor = transform(image).unsqueeze(0)
    return image_tensor


def predict_image(image, model, device):
    """Make prediction on image"""
    # Preprocess image
    input_tensor = preprocess_image(image).to(device)

    # Make prediction
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)

    predicted_class = CLASS_NAMES[predicted_idx.item()]
    confidence_score = confidence.item()
    all_probabilities = probabilities.cpu().numpy().flatten()

    return predicted_class, confidence_score, all_probabilities


def get_confidence_color(confidence):
    """Get color class based on confidence score"""
    if confidence >= 0.7:
        return "confidence-high"
    elif confidence >= 0.4:
        return "confidence-medium"
    else:
        return "confidence-low"


def main():
    # Load model
    model, device = load_model()
    if model is None:
        # Stop rendering the page if the checkpoint failed to load
        st.stop()

    # File uploader
    st.header("Garbage Classification")
    uploaded_file = st.file_uploader(
        "Choose an image file",
        type=["jpg", "jpeg", "png"],
    )

    if uploaded_file is not None:
        # Display uploaded image
        image = Image.open(uploaded_file).convert("RGB")
        col1, col2 = st.columns([1, 1])
        with col1:
            st.image(image, caption="Uploaded Image", use_container_width=True)

        # Make prediction
        with st.spinner("🔍 Analyzing image..."):
            predicted_class, confidence, probabilities = predict_image(
                image, model, device
            )

        # Show per-class probabilities, highest first
        sorted_indices = np.argsort(probabilities)[::-1]
        container = col2.container(border=True)
        for idx in sorted_indices:
            class_name = CLASS_NAMES[idx]
            prob = probabilities[idx]
            container.write(f"{class_name.title()}: {prob:.1%}")


if __name__ == "__main__":
    main()
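

# Usage note (assumption: the script filename is app.py; the source only fixes the
# checkpoint name best_model.pth). To try the app locally, place the trained
# checkpoint next to the script and launch it with Streamlit:
#
#     streamlit run app.py
#
# Streamlit serves the page on localhost and reruns the script on each interaction;
# @st.cache_resource keeps the loaded model in memory across those reruns.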