Spaces:
Sleeping
Sleeping
File size: 4,327 Bytes
9e6a96a 4efcfc7 9e6a96a 6abf0ea b1205cc ff2fe1c 9e6a96a 3ff5d42 4d6d4c8 9e6a96a 99d4d0e 9e6a96a 9a18613 9e6a96a b68bda0 9e6a96a ff2fe1c 9e6a96a d068ca1 9e6a96a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
import sys
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from resnet_model import MonkeyResNet
from data_loader import get_data_loaders
import io
# Make the parent directory of this script importable.
# NOTE: this runs after the import statements at the top of the file, so it
# only influences imports executed later, not the ones above it.
_PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Streamlit re-executes the script on every interaction; guard the append so
# sys.path does not accumulate one duplicate entry per rerun.
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)
# Configure the browser tab title and request the full-width page layout
st.set_page_config(
    page_title="ApexID: Monkey Species Classifier",
    layout="wide",
)
# Trained weights are expected to sit next to this script
_APP_DIR = os.path.dirname(__file__)
MODEL_PATH = os.path.join(_APP_DIR, "monkey_resnet.pth")

# Class identifiers "n0" .. "n9"; indexed by the predicted class index
CLASS_NAMES = [f"n{index}" for index in range(10)]

# Species name displayed for each class identifier
LABEL_MAP = dict(
    n0='Alouatta Palliata',
    n1='Erythrocebus Patas',
    n2='Cacajao Calvus',
    n3='Macaca Fuscata',
    n4='Cebuella Pygmea',
    n5='Cebus Capucinus',
    n6='Mico Argentatus',
    n7='Saimiri Sciureus',
    n8='Aotus Nigriceps',
    n9='Trachypithecus Johnii',
)
@st.cache_resource
def load_model():
    """Build the classifier, restore its saved weights, and return it.

    The weights are read from ``MODEL_PATH`` and mapped onto the CPU. The
    returned model has been switched to evaluation mode. The function is
    wrapped in ``st.cache_resource``.
    """
    network = MonkeyResNet(num_classes=len(CLASS_NAMES))
    cpu = torch.device("cpu")
    state_dict = torch.load(MODEL_PATH, map_location=cpu)
    network.load_state_dict(state_dict)
    network.eval()
    return network


# Module-level model handle used by the classification tab
model = load_model()
# Preprocessing pipeline, built once at import time rather than on every call.
# NOTE(review): the mean/std values look like the standard ImageNet
# statistics -- confirm they match what was used during training.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert a PIL image into a normalized, batched input tensor.

    Args:
        image: A PIL image in any mode. Non-RGB images (grayscale, RGBA,
            palette, ...) are converted to RGB first, because the
            normalization step is defined for exactly three channels.

    Returns:
        The image resized to 224x224, converted to a tensor, normalized per
        channel, and given a leading batch dimension of size 1.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Add the batch dimension after transforming
    return _PREPROCESS(image).unsqueeze(0)
# Page heading
st.title("ApexID: Monkey Species Classifier")

# Three tabs: project overview, the classifier itself, and model details
tab_labels = ["About the Project", "Image Classification", "Model Details"]
tab1, tab2, tab3 = st.tabs(tab_labels)
# Tab 1: static description of the project
about_text = """
This project uses a deep learning model based on ResNet18 with transfer learning to classify images of ten monkey species. It applies convolutional neural networks (CNNs) for accurate image recognition and is designed for tasks like education, wildlife monitoring, and zoo record management.
Key Points:
- Image classification using CNN and transfer learning
- Built with PyTorch for model training
- Streamlit used for a user-friendly interface
"""
with tab1:
    st.header("About This Project")
    st.write(about_text)
# Tab 2: upload an image, run the model on it, and show the predicted species
with tab2:
    st.header("Classify a Monkey Image")
    st.markdown("""
Steps to classify:
1. Upload a clear monkey image.
2. Supported file types: jpg, png, jpeg.
3. See prediction below after uploading.
""")
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is not None:
        raw_bytes = uploaded_file.read()
        try:
            # Decode from memory and force three channels for the model
            image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except OSError:
            # Pillow reports unreadable or truncated image data as OSError
            # (UnidentifiedImageError is a subclass); show a message instead
            # of letting the traceback reach the page.
            st.error("Could not read the uploaded file as an image. "
                     "Please upload a valid jpg, jpeg, or png file.")
        else:
            st.image(image, caption="Uploaded Image", width=300)
            # Preprocess exactly once per upload
            input_tensor = preprocess_image(image)
            # Spinner gives feedback while inference runs; no_grad disables
            # gradient tracking, which is not needed for prediction
            with st.spinner("Classifying... Please wait."), torch.no_grad():
                outputs = model(input_tensor)
            # Index of the highest-scoring class along dimension 1
            _, predicted = torch.max(outputs, 1)
            predicted_label = CLASS_NAMES[predicted.item()]
            species_name = LABEL_MAP[predicted_label]
            st.success(f"Predicted Monkey Species: {species_name}")
# Tab 3: model summary and any available training/evaluation figures
with tab3:
    st.header("Model Information")
    st.markdown("""
- Model architecture: ResNet18 using transfer learning
- Framework: PyTorch
- Final training accuracy: 90.88%
- Final validation loss: 2.44
- Test accuracy: 92%
""")
    # (path, caption) pairs in display order. The paths are relative, so they
    # resolve against the current working directory; a figure is shown only
    # when its file exists.
    figures = (
        ("plots/accuracy_plot.png", "Training and Validation Accuracy"),
        ("plots/loss_plot.png", "Training and Validation Loss"),
        ("confusion_matrix.png", "Confusion Matrix"),
    )
    for figure_path, figure_caption in figures:
        if os.path.exists(figure_path):
            st.image(figure_path, caption=figure_caption)
|