|
import numpy as np |
|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as T |
|
from PIL import Image |
|
|
|
# Page metadata (browser-tab title). Older Streamlit versions require
# set_page_config to be the first Streamlit command in the script, so keep it
# above any other st.* call.
st.set_page_config(page_title="Garbage Classification")
|
|
|
|
|
|
|
class SimpleCNN(nn.Module):
    """Four-stage convolutional classifier with a three-layer fully connected head.

    Each stage is an unpadded 3x3 convolution, a ReLU, and a 2x2 max-pool with
    stride 2. A stage therefore maps a square spatial size ``s`` to
    ``(s - 2) // 2``. The head flattens the final 256-channel feature map and
    applies Linear -> ReLU -> Dropout(0.5) -> Linear(512, 512) -> ReLU ->
    Dropout(0.5) -> Linear(512, num_classes). It returns raw logits (no softmax).

    Args:
        num_classes: Number of output logits.
        input_channels: Number of channels in the input images (3 for RGB).
        input_size: Side length of the square input images. It determines the
            input width of ``fc1``. The default of 224 reproduces the original
            ``256 * 12 * 12`` head, so existing checkpoints load unchanged.

    Raises:
        ValueError: If ``input_size`` is too small to survive the four
            conv/pool stages.
    """

    # Number of conv + pool stages; used to derive the flattened feature size.
    _NUM_STAGES = 4

    def __init__(self, num_classes, input_channels=3, input_size=224):
        super().__init__()

        # Validate before allocating any layers so a bad size fails fast.
        feature_size = self._feature_map_size(input_size)

        # The attribute names below are the state_dict keys of saved checkpoints.
        # Their creation order also fixes the default random initialisation.
        # Keep both stable.
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=0)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()

        # 256 channels * feature_size * feature_size (12 * 12 for 224 inputs).
        self.fc1 = nn.Linear(256 * feature_size * feature_size, 512)
        self.dropout1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(512, 512)
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(512, num_classes)

    @classmethod
    def _feature_map_size(cls, input_size):
        """Return the side length of the last feature map for a square input."""
        size = input_size
        for _ in range(cls._NUM_STAGES):
            # An unpadded 3x3 conv removes 2 pixels; a 2x2/stride-2 pool halves (floor).
            size = (size - 2) // 2
        if size < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {cls._NUM_STAGES} "
                "conv/pool stages"
            )
        return size

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for a batch of images ``x``."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = self.pool3(x)

        x = F.relu(self.conv4(x))
        x = self.pool4(x)

        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)

        x = F.relu(self.fc2(x))
        x = self.dropout2(x)

        x = self.fc3(x)
        return x
|
|
|
|
|
|
|
# Human-readable label for each model output index, in index order.
# The names are alphabetical, presumably matching the folder order that
# torchvision's ImageFolder assigned during training. Confirm against the
# training script before reordering or adding classes.
CLASS_NAMES = [
    "battery", "biological", "cardboard", "clothes", "glass",
    "metal", "paper", "plastic", "shoes", "trash",
]
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Build the classifier and load its trained weights from ``best_model.pth``.

    The function is wrapped in ``st.cache_resource``, so its return value is
    reused across reruns instead of reloading the weights each time. A failed
    load returns ``(None, device)``, and that value is cached too. After fixing
    the checkpoint, clear the cache or restart the app.

    Returns:
        A ``(model, device)`` tuple:

        - ``model`` is the ``nn.DataParallel``-wrapped ``SimpleCNN`` in eval
          mode, or ``None`` if loading failed. The failure is reported with
          ``st.error``.
        - ``device`` is always the CPU device.
    """
    device = torch.device("cpu")

    # Size the output layer from CLASS_NAMES so logits and labels cannot drift apart.
    model = SimpleCNN(num_classes=len(CLASS_NAMES))
    # DataParallel prefixes state_dict keys with "module.". Presumably the
    # checkpoint was saved from a DataParallel model; confirm before removing this.
    model = nn.DataParallel(model).to(device)

    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the installed torch supports it.
        state_dict = torch.load("best_model.pth", map_location=device)
        model.load_state_dict(state_dict)
    except Exception as e:  # UI boundary: report any load failure instead of crashing
        st.error(f"Error loading model: {e}")
        return None, device

    model.eval()  # disable dropout for inference
    return model, device
|
|
|
|
|
def preprocess_image(image):
    """Convert an uploaded image into a normalised single-image batch tensor.

    The pipeline is:

    1. Resize so the shorter side is 224.
    2. Center-crop to 224x224.
    3. Convert to a float tensor in [0, 1].
    4. Normalise each channel with the ImageNet mean/std constants.

    A batch dimension is then prepended, so for a PIL RGB image the result has
    shape ``(1, 3, 224, 224)``.

    Images with a ``mode`` other than ``"RGB"`` (e.g. grayscale, RGBA,
    palette) are converted to RGB first. The 3-channel ``Normalize`` step
    would otherwise fail on them.
    """
    # Only objects that expose a PIL-style ``mode`` are converted; anything else
    # is passed through to the transforms unchanged.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    transform = T.Compose(
        [
            T.Resize(224),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # unsqueeze(0) adds the batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0)
    return image_tensor
|
|
|
|
|
def predict_image(image, model, device):
    """Classify one image and return the best label together with all class probabilities.

    Args:
        image: Image handed unchanged to ``preprocess_image``.
        model: Callable that maps the preprocessed batch to logits. It is assumed
            to produce shape ``(1, len(CLASS_NAMES))``; TODO confirm.
        device: Device the input batch is moved to before inference.

    Returns:
        ``(predicted_class, confidence_score, all_probabilities)``:

        - ``predicted_class`` is the ``CLASS_NAMES`` entry with the highest
          softmax probability.
        - ``confidence_score`` is that probability as a Python float.
        - ``all_probabilities`` holds every probability as a flattened NumPy
          array.
    """
    batch = preprocess_image(image).to(device)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        best_prob, best_idx = probs.max(dim=1)

    return (
        CLASS_NAMES[best_idx.item()],
        best_prob.item(),
        probs.cpu().numpy().flatten(),
    )
|
|
|
|
|
def get_confidence_color(confidence):
    """Map a confidence score to a color class name.

    The mapping is:

    - ``confidence >= 0.7`` gives ``"confidence-high"``.
    - ``0.4 <= confidence < 0.7`` gives ``"confidence-medium"``.
    - Anything lower, including NaN, gives ``"confidence-low"``.

    The names look like CSS classes, but no stylesheet defining them is
    visible here.
    """
    # Checked from the highest lower bound down; the first bound met wins.
    tiers = (
        (0.7, "confidence-high"),
        (0.4, "confidence-medium"),
    )
    for lower_bound, css_class in tiers:
        if confidence >= lower_bound:
            return css_class
    return "confidence-low"
|
|
|
|
|
def main():
    """Render the Garbage Classification page.

    Loads the cached model and shows an image uploader. Once a JPG/PNG file is
    uploaded, the image is shown in the left column. The right column holds a
    bordered container listing every class with its predicted probability,
    highest first.
    """
    model, device = load_model()

    st.header("Garbage Classification")

    # load_model() returns model=None when loading failed; predict_image would
    # then crash on calling None, so stop here with a message instead.
    if model is None:
        st.error("The model could not be loaded, so images cannot be classified.")
        return

    uploaded_file = st.file_uploader(
        "Choose an image file",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_file is None:
        return

    try:
        # convert("RGB") drops alpha/palette data so the image has three channels.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as exc:  # includes PIL's UnidentifiedImageError for corrupt files
        st.error(f"Could not read the uploaded image: {exc}")
        return

    col1, col2 = st.columns([1, 1])
    with col1:
        st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("π Analyzing image..."):
        _, _, probabilities = predict_image(image, model, device)

    # argsort is ascending; reverse it to list the most likely class first.
    sorted_indices = np.argsort(probabilities)[::-1]

    container = col2.container(border=True)
    for idx in sorted_indices:
        container.write(f"{CLASS_NAMES[idx].title()}: {probabilities[idx]:.1%}")
|
|
|
|
|
# Script entry point. `streamlit run <this file>` executes the script as
# __main__ on every rerun, so the page is rebuilt by calling main().
if __name__ == "__main__":
    main()
|
|