File size: 4,327 Bytes
9e6a96a
4efcfc7
9e6a96a
 
 
 
6abf0ea
b1205cc
ff2fe1c
9e6a96a
3ff5d42
4d6d4c8
 
 
9e6a96a
 
 
 
99d4d0e
9e6a96a
 
9a18613
 
 
 
 
 
 
 
 
 
9e6a96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b68bda0
 
9e6a96a
ff2fe1c
9e6a96a
 
d068ca1
 
 
 
 
 
 
 
 
 
9e6a96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import sys
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from resnet_model import MonkeyResNet
from data_loader import get_data_loaders
import io


# Ensure the parent directory is in the system path for module imports.
# Streamlit re-executes this whole script on every widget interaction while
# sys.path persists for the life of the process, so only append the entry
# when it is not already present (otherwise sys.path grows on every rerun).
# NOTE(review): this runs *after* the `resnet_model` / `data_loader` imports
# at the top of the file, so it cannot help those imports resolve; move it
# above them if those modules actually live in the parent directory.
_PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

# Streamlit page setup: browser-tab title plus a full-width layout.
st.set_page_config(layout="wide", page_title="ApexID: Monkey Species Classifier")

# Constants for model path and class labels.
# The weights file is resolved next to this script, independent of the CWD.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "monkey_resnet.pth")

# Class IDs in model-output order: index i of the logits maps to CLASS_NAMES[i].
# NOTE(review): presumably this matches the sorted folder order used at
# training time (e.g. torchvision ImageFolder) — confirm against data_loader.
CLASS_NAMES = ['n0', 'n1', 'n2', 'n3', 'n4', 'n5', 'n6', 'n7', 'n8', 'n9']

# Display name (binomial scientific name) for each class ID. Following
# zoological convention, the genus is capitalized and the specific epithet
# is lowercase.
LABEL_MAP = {
    'n0': 'Alouatta palliata',
    'n1': 'Erythrocebus patas',
    'n2': 'Cacajao calvus',
    'n3': 'Macaca fuscata',
    'n4': 'Cebuella pygmaea',
    'n5': 'Cebus capucinus',
    'n6': 'Mico argentatus',
    'n7': 'Saimiri sciureus',
    'n8': 'Aotus nigriceps',
    'n9': 'Trachypithecus johnii',
}

# Model construction is wrapped in st.cache_resource so it is not repeated
# on every script rerun.
@st.cache_resource
def load_model():
    """Build MonkeyResNet, restore its trained weights, and return it in eval mode.

    The state dict at MODEL_PATH is mapped onto the CPU. The number of output
    classes is taken from CLASS_NAMES.

    NOTE(review): torch.load unpickles the file, which is only safe for a
    trusted file; presumably MODEL_PATH is a bundled local artifact. Consider
    passing weights_only=True if the installed torch version supports it.
    """
    net = MonkeyResNet(num_classes=len(CLASS_NAMES))
    cpu = torch.device('cpu')
    weights = torch.load(MODEL_PATH, map_location=cpu)
    net.load_state_dict(weights)
    net.eval()
    return net


# Single shared model instance used by the classification tab.
model = load_model()

# Image preprocessing to match model input requirements
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched float tensor for the model.

    Args:
        image: A PIL image. Non-RGB images (e.g. RGBA PNGs or grayscale) are
            converted to RGB first. This ensures the tensor always has the
            3 channels expected by the 3-element Normalize mean/std below.

    Returns:
        A tensor of shape (1, 3, 224, 224): the image resized to 224x224,
        scaled to [0, 1] by ToTensor, normalized per channel, and given a
        leading batch dimension.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Add batch dimension after transforming
    return transform(image).unsqueeze(0)

# Page heading rendered above the tab strip.
st.title("ApexID: Monkey Species Classifier")

# Three tabs: project overview, the interactive classifier, and model metrics.
tab_labels = ["About the Project", "Image Classification", "Model Details"]
tab1, tab2, tab3 = st.tabs(tab_labels)

# Tab 1: static description of the project's goals and tech stack.
with tab1:
    about_text = """
    This project uses a deep learning model based on ResNet18 with transfer learning to classify images of ten monkey species. It applies convolutional neural networks (CNNs) for accurate image recognition and is designed for tasks like education, wildlife monitoring, and zoo record management.

    Key Points:
    - Image classification using CNN and transfer learning
    - Built with PyTorch for model training
    - Streamlit used for a user-friendly interface
    """
    st.header("About This Project")
    st.write(about_text)

# Tab 2: Image classification interface
with tab2:
    st.header("Classify a Monkey Image")
    st.markdown("""
    Steps to classify:
    1. Upload a clear monkey image.
    2. Supported file types: jpg, png, jpeg.
    3. See prediction below after uploading.
    """)

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])

    if uploaded_file is not None:
        # getvalue() returns the whole buffer regardless of the current read
        # position, unlike read().
        raw_bytes = uploaded_file.getvalue()
        try:
            # convert() forces a full decode, so corrupt or truncated files
            # fail here instead of inside the model call.
            image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # unrecognized data and OSError for truncated images. Show an
            # error instead of crashing the script; st.stop() is avoided so
            # the other tabs still render.
            image = None
            st.error("Could not read the uploaded file as an image. Please upload a valid jpg or png.")

        if image is not None:
            st.image(image, caption="Uploaded Image", width=300)
            input_tensor = preprocess_image(image)

            # Add spinner to indicate loading
            with st.spinner("Classifying... Please wait."):
                with torch.no_grad():
                    outputs = model(input_tensor)
                    # Highest-scoring class along the class dimension (batch of 1).
                    predicted_index = outputs.argmax(dim=1).item()
                    species_name = LABEL_MAP[CLASS_NAMES[predicted_index]]

            st.success(f"Predicted Monkey Species: {species_name}")

# Tab 3: summary metrics plus any training/evaluation figures found on disk.
with tab3:
    st.header("Model Information")
    st.markdown("""
    - Model architecture: ResNet18 using transfer learning  
    - Framework: PyTorch  
    - Final training accuracy: 90.88%  
    - Final validation loss: 2.44  
    - Test accuracy: 92%
    """)

    # Each figure is optional and is rendered only if its file exists.
    # NOTE(review): unlike MODEL_PATH, these paths are relative to the current
    # working directory, not this script's directory — confirm that is intended.
    optional_figures = (
        ("plots/accuracy_plot.png", "Training and Validation Accuracy"),
        ("plots/loss_plot.png", "Training and Validation Loss"),
        ("confusion_matrix.png", "Confusion Matrix"),
    )
    for figure_path, figure_caption in optional_figures:
        if os.path.exists(figure_path):
            st.image(figure_path, caption=figure_caption)