shrestha-prabin's picture
Update Requirements
51a34d1
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
# Configure the browser tab at import time, before main() renders anything;
# older Streamlit releases require set_page_config to precede other st.* calls.
st.set_page_config(
    page_title="Garbage Classification",
)
# CNN Model Definition
class SimpleCNN(nn.Module):
    """Four conv/ReLU/max-pool stages followed by a three-layer MLP head.

    The first dense layer is sized for 224x224 inputs. Each unpadded 3x3
    conv trims 2 pixels and each 2x2/stride-2 pool halves (floor), so the
    spatial size goes 224 -> 222 -> 111 -> 109 -> 54 -> 52 -> 26 -> 24 -> 12,
    leaving a 256 x 12 x 12 feature map to flatten.

    Args:
        num_classes: Number of output logits produced by ``fc3``.
        input_channels: Channel count of the input images (3 for RGB).
    """

    def __init__(self, num_classes, input_channels=3):
        super().__init__()
        # Registers conv1/pool1 ... conv4/pool4 in that order. The attribute
        # names are the state_dict keys a saved checkpoint is matched against,
        # so they must not change.
        widths = (input_channels, 32, 64, 128, 256)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, kernel_size=3, padding=0))
            setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        self.flatten = nn.Flatten()

        # Dense head: 256*12*12 -> 512 -> 512 -> num_classes, with dropout
        # after each hidden layer (inactive in eval mode).
        self.fc1 = nn.Linear(256 * 12 * 12, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 512)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the batch ``x``."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )
        for conv, pool in stages:
            x = pool(F.relu(conv(x)))

        hidden = self.flatten(x)
        hidden = self.dropout1(F.relu(self.fc1(hidden)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)
# Class labels, indexed by the model's output position (predict_image looks
# up CLASS_NAMES[argmax]). Alphabetical order, presumably matching the class
# ordering used at training time (e.g. ImageFolder) -- confirm before editing.
CLASS_NAMES = (
    "battery biological cardboard clothes glass "
    "metal paper plastic shoes trash"
).split()
# Cache the model loading
@st.cache_resource
def load_model():
    """Build the classifier and load its weights from ``best_model.pth``.

    Wrapped in ``st.cache_resource`` so the checkpoint is read once and
    reused across script reruns instead of on every interaction.

    Returns:
        ``(model, device)`` with the model in eval mode. If reading or
        applying the checkpoint fails, the error is shown with ``st.error``
        and ``(None, device)`` is returned instead.
    """
    device = torch.device("cpu")
    # Wrapped in DataParallel *before* loading -- presumably because the
    # checkpoint was saved from a DataParallel model (keys prefixed
    # "module."); confirm before removing the wrapper.
    net = nn.DataParallel(SimpleCNN(num_classes=10))
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        weights = torch.load("best_model.pth", map_location=device)
        net.load_state_dict(weights)
        net.eval()
    except Exception as err:
        st.error(f"Error loading model: {err}")
        return None, device
    return net, device
def preprocess_image(image):
    """Convert a PIL image into a normalized ``(1, 3, 224, 224)`` float tensor.

    Steps: resize so the shorter side is 224, center-crop to 224x224, scale
    pixels to [0, 1] (``ToTensor``), normalize with the ImageNet channel
    mean/std, then add a leading batch dimension.

    Images that are not RGB (grayscale "L", palette "P", "RGBA", ...) are
    converted to RGB first. The 3-value ``Normalize`` and the model's three
    input channels both require exactly three channels; without this,
    1- or 4-channel images raise inside ``Normalize``.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = T.Compose(
        [
            T.Resize(224),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return transform(image).unsqueeze(0)
def predict_image(image, model, device):
    """Classify one image and summarize the model's output.

    Args:
        image: PIL image; preprocessed with ``preprocess_image``.
        model: Module called on a batch of one. It is expected to return
            logits of shape ``(1, len(CLASS_NAMES))``.
        device: Device the preprocessed batch is moved to.

    Returns:
        ``(label, confidence, probabilities)``:
        - ``label``: the top-scoring entry of ``CLASS_NAMES``;
        - ``confidence``: its softmax probability as a Python float;
        - ``probabilities``: a 1-D NumPy array holding every class's
          probability.
    """
    batch = preprocess_image(image).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch)
        class_probs = F.softmax(logits, dim=1)
        best_prob, best_idx = class_probs.max(dim=1)
        label = CLASS_NAMES[best_idx.item()]
        score = best_prob.item()
        flat_probs = class_probs.cpu().numpy().flatten()

    return label, score, flat_probs
def get_confidence_color(confidence):
    """Map a confidence score to a CSS class name.

    Scores >= 0.7 give ``"confidence-high"`` and scores >= 0.4 give
    ``"confidence-medium"``. Anything else gives ``"confidence-low"``,
    including NaN, since every comparison with NaN is false.
    """
    tiers = ((0.7, "confidence-high"), (0.4, "confidence-medium"))
    for floor, css_class in tiers:
        if confidence >= floor:
            return css_class
    return "confidence-low"
def main():
    """Render the app: load the model, accept an upload, list predictions.

    The uploaded image is shown in the left column. The right column lists
    every class with its probability, most likely first. The function
    returns early when the model failed to load, when nothing has been
    uploaded yet, or when the upload cannot be decoded as an image.
    """
    model, device = load_model()

    st.header("Garbage Classification")
    if model is None:
        # load_model reports the failure via st.error. Without weights there
        # is nothing to predict with, and calling None would raise TypeError.
        return

    uploaded_file = st.file_uploader(
        "Choose an image file",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # undecodable files and OSError for truncated ones; .convert() forces
        # the lazy decode, so both surface here.
        st.error(f"Could not read image: {err}")
        return

    col1, col2 = st.columns([1, 1])
    with col1:
        st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("🔍 Analyzing image..."):
        _, _, probabilities = predict_image(image, model, device)

    # argsort is ascending; reverse it so the most likely class comes first.
    container = col2.container(border=True)
    for idx in np.argsort(probabilities)[::-1]:
        container.write(f"{CLASS_NAMES[idx].title()}: {probabilities[idx]:.1%}")


if __name__ == "__main__":
    main()