# Alamgirapi's picture
# Upload app.py
# 72c86b8 verified
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from PIL import Image
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import io
import requests
import tempfile
import os
# Set page config
st.set_page_config(
    page_title="Dog Breed Classifier",
    page_icon="🐕",
    layout="wide",
)

# Default model URL: the weights file fetched by download_and_load_model()
DEFAULT_MODEL_URL = "https://huggingface.co/Alamgirapi/dog-breed-convnext-classifier/resolve/main/model.pth"
# Device setup
@st.cache_resource
def setup_device_and_model():
    """Create the device, the untrained classifier, its labels and preprocessing.

    The result is cached by Streamlit, so the model object is built once and
    shared between reruns.

    Returns:
        tuple: ``(device, model, label_names, transform)`` where ``device`` is
        CUDA when available and CPU otherwise, ``model`` is a ConvNeXt-Base
        network whose head outputs one logit per entry of ``label_names``, and
        ``transform`` turns a PIL image into a normalized 224x224 tensor.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define label names (their order defines the meaning of each output index)
    label_names = ['beagle', 'bulldog', 'dalmatian', 'german-shepherd', 'husky', 'poodle', 'rottweiler']

    # Initialize model.
    # pretrained=False: the fine-tuned checkpoint is loaded afterwards with a
    # strict load_state_dict(), which overwrites every parameter, so fetching
    # the ImageNet weights here would only add a large, unused download.
    model = timm.create_model('convnext_base', pretrained=False)

    # Replace head with proper flattening: pool the feature map to 1x1,
    # flatten it, and project to one logit per breed.
    model.head = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(model.head.in_features, len(label_names)),
    )
    model = model.to(device)

    # Define transform (resize, tensor conversion, ImageNet mean/std normalization)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return device, model, label_names, transform
@st.cache_resource
def _download_and_load_weights(_model, device):
    """Fetch the checkpoint at DEFAULT_MODEL_URL and load it into ``_model``.

    Failures are raised instead of being returned: Streamlit caches return
    values but not exceptions, so a failed attempt is retried on the next
    rerun instead of being remembered forever.

    Args:
        _model: Network to receive the weights (leading underscore keeps
            Streamlit from trying to hash it for the cache key).
        device: Device the checkpoint tensors are mapped to.

    Returns:
        bool: Always ``True``; reaching the return means the load succeeded.
    """
    tmp_model_path = None
    try:
        with st.spinner("Downloading model from Hugging Face..."):
            # Stream to disk in chunks so the whole checkpoint is never held in
            # memory, and use a timeout so a stalled connection cannot hang the app.
            with requests.get(DEFAULT_MODEL_URL, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                # Save to temporary file
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file:
                    tmp_model_path = tmp_file.name
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        tmp_file.write(chunk)

        # Load the model weights.
        # NOTE(security): torch.load unpickles the downloaded file. Only point
        # DEFAULT_MODEL_URL at a trusted source; consider weights_only=True if
        # the installed torch version supports it.
        state_dict = torch.load(tmp_model_path, map_location=device)
        _model.load_state_dict(state_dict)
        _model.eval()
    finally:
        # Clean up temporary file on success *and* on failure
        if tmp_model_path is not None and os.path.exists(tmp_model_path):
            os.unlink(tmp_model_path)
    return True


def download_and_load_model(_model, device):
    """Download and load model weights from Hugging Face.

    Args:
        _model: Network to load the weights into (modified in place and put
            in eval mode).
        device: Device the checkpoint tensors are mapped to.

    Returns:
        bool: ``True`` when the weights were loaded, ``False`` when anything
        went wrong (the error is shown in the UI rather than raised).
    """
    try:
        return _download_and_load_weights(_model, device)
    except Exception as e:
        st.error(f"Error downloading/loading model: {str(e)}")
        return False
def predict_image(image, model, transform, label_names, device, topk=3):
    """Make predictions on uploaded image.

    Args:
        image: PIL-style image exposing ``mode`` and ``convert``; it is
            converted to RGB when it is in any other mode.
        model: Callable network returning logits of shape ``(1, num_classes)``
            for a single-image batch.
        transform: Callable turning the image into a tensor without a batch
            dimension.
        label_names: Sequence mapping class index to breed name.
        device: Device the input tensor is moved to.
        topk: Maximum number of predictions to return. It is clamped to the
            number of classes the model outputs.

    Returns:
        list[dict]: Dicts with keys ``'breed'`` (label) and ``'confidence'``
        (percentage as a NumPy scalar), ordered from most to least likely.
    """
    # Transform image
    if image.mode != 'RGB':
        image = image.convert('RGB')
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # torch.topk raises when k is larger than the class dimension
        k = min(topk, probs.shape[1])
        top_probs, top_idxs = torch.topk(probs, k=k)

    # Convert to CPU for display
    top_probs = top_probs[0].cpu().numpy()
    top_idxs = top_idxs[0].cpu().numpy()

    # Build prediction results
    return [
        {'breed': label_names[idx], 'confidence': prob * 100}
        for idx, prob in zip(top_idxs, top_probs)
    ]
def create_prediction_chart(predictions):
    """Create a horizontal bar chart for predictions.

    Args:
        predictions: List of dicts with ``'breed'`` and ``'confidence'``
            (percentage) keys, as returned by ``predict_image``.

    Returns:
        matplotlib.figure.Figure: Figure with one bar per prediction, each
        annotated with its confidence.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates in a global registry until it is explicitly closed,
    # which leaks one figure per Streamlit rerun.
    from matplotlib.figure import Figure

    breeds = [pred['breed'].replace('-', ' ').title() for pred in predictions]
    confidences = [float(pred['confidence']) for pred in predictions]  # Convert to Python float

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    bars = ax.barh(breeds, confidences, color=['#1f77b4', '#ff7f0e', '#2ca02c'])
    ax.set_xlabel('Confidence (%)')
    ax.set_title(f'Top {len(predictions)} Breed Predictions')
    ax.set_xlim(0, 100)

    # Add percentage labels on bars
    for bar, conf in zip(bars, confidences):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{conf:.1f}%', va='center')

    fig.tight_layout()
    return fig
# Main app
def _render_predictions(predictions):
    """Show the highlighted top breed, the per-breed list and the bar chart."""
    st.success("🎉 Analysis Complete!")

    # Show top prediction prominently. The breed comes from the fixed label
    # list, not from user input, so it is interpolated into the HTML as-is.
    top_breed = predictions[0]['breed'].replace('-', ' ').title()
    top_confidence = float(predictions[0]['confidence'])  # Convert to Python float
    st.markdown(f"""
<div style="background-color: #f0f8ff; padding: 20px; border-radius: 10px; border-left: 5px solid #1f77b4;">
<h3 style="color: #1f77b4; margin: 0;">🏆 Most Likely Breed</h3>
<h2 style="margin: 5px 0;">{top_breed}</h2>
<h4 style="color: #666; margin: 0;">Confidence: {top_confidence:.1f}%</h4>
</div>
""", unsafe_allow_html=True)
    st.write("")  # Add some space

    # Show all predictions
    st.subheader("All Predictions:")
    for i, pred in enumerate(predictions):
        breed = pred['breed'].replace('-', ' ').title()
        confidence = float(pred['confidence'])  # Convert numpy float32 to Python float
        st.write(f"**{i+1}. {breed}**")
        # st.progress only accepts floats in [0.0, 1.0]; clamp against rounding
        st.progress(min(max(confidence / 100, 0.0), 1.0))
        st.write(f"Confidence: {confidence:.2f}%")
        st.write("")

    # Show chart
    st.subheader("Prediction Chart:")
    fig = create_prediction_chart(predictions)
    st.pyplot(fig)


def main():
    """Render the page: load the model, accept an upload and show predictions."""
    st.title("🐕 Dog Breed Classifier")
    st.write("Upload an image of a dog to identify its breed!")

    # Initialize model and device
    device, model, label_names, transform = setup_device_and_model()

    # Download and load the model automatically
    model_loaded = download_and_load_model(model, device)
    if model_loaded:
        st.success("✅ Model loaded successfully from Hugging Face!")
    else:
        st.error("❌ Failed to load model. Please refresh the page and try again.")
        return

    # Main content
    col1, col2 = st.columns([1, 1])
    image = None  # stays None when nothing usable has been uploaded

    with col1:
        st.header("Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=['jpg', 'jpeg', 'png'],
            help="Upload a clear image of a dog for best results"
        )
        if uploaded_file is not None:
            try:
                # Display uploaded image. Image.open is lazy, so a corrupt file
                # may only fail once st.image decodes it; guard both calls.
                image = Image.open(uploaded_file)
                st.image(image, caption="Uploaded Image", use_container_width=True)
            except Exception as e:
                image = None
                st.error(f"Error reading image: {str(e)}")
            else:
                # Show image details
                st.write(f"**Image Size:** {image.size}")
                st.write(f"**Image Mode:** {image.mode}")

    with col2:
        st.header("Predictions")
        if image is not None:
            try:
                with st.spinner("Analyzing image..."):
                    # Make predictions
                    predictions = predict_image(image, model, transform, label_names, device)
                # Display results
                _render_predictions(predictions)
            except Exception as e:
                st.error(f"Error during prediction: {str(e)}")
        elif uploaded_file is None:
            st.info("📤 Please upload an image to start classification.")

    # Information section. Blank lines before the bold headings keep Markdown
    # from folding them into the preceding list item.
    with st.expander("ℹ️ About this App"):
        st.write("""
This app uses a ConvNeXt-Base model trained to classify dog breeds among:
- Beagle
- Bulldog
- Dalmatian
- German Shepherd
- Husky
- Poodle
- Rottweiler

**How to use:**
1. The model is automatically loaded from Hugging Face
2. Upload a clear image of a dog
3. View the top 3 breed predictions with confidence scores

**Tips for better results:**
- Use high-quality, well-lit images
- Ensure the dog is clearly visible in the image
- Avoid images with multiple dogs
""")

    # Technical details
    with st.expander("🔧 Technical Details"):
        st.write(f"""
- **Device:** {device}
- **Model:** ConvNeXt-Base
- **Model Source:** Hugging Face (Alamgirapi/dog-breed-convnext-classifier)
- **Input Size:** 224x224 pixels
- **Classes:** {len(label_names)}
- **Framework:** PyTorch + Streamlit
""")


if __name__ == "__main__":
    main()