# Hugging Face Spaces page header captured with this file (status: Sleeping); not part of the program.
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import timm | |
from PIL import Image | |
from torchvision import transforms | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import io | |
import requests | |
import tempfile | |
import os | |
# Browser-tab title, favicon and wide layout for the app.
# The icon literal had been reduced to a mis-decoded "π"; it is restored to
# the dog emoji that matches the app's subject.
st.set_page_config(
    page_title="Dog Breed Classifier",
    page_icon="🐕",
    layout="wide",
)

# Location of the weights file (model.pth) in the
# Alamgirapi/dog-breed-convnext-classifier repository on Hugging Face.
DEFAULT_MODEL_URL = "https://huggingface.co/Alamgirapi/dog-breed-convnext-classifier/resolve/main/model.pth"
# Device setup | |
def setup_device_and_model(pretrained=False):
    """Build the classifier architecture and its preprocessing pipeline.

    The returned model only has the right *shape*; its weights are expected to
    be filled in afterwards by ``download_and_load_model``.

    Args:
        pretrained: When True, timm initialises the backbone from its published
            pretrained weights (a large extra download). Defaults to False
            because the app overwrites every parameter via ``load_state_dict``
            right after construction, so that download would be discarded.

    Returns:
        tuple: ``(device, model, label_names, transform)`` where ``device`` is
        the torch device in use (CUDA when available, otherwise CPU), ``model``
        is the ConvNeXt-Base network already moved to that device,
        ``label_names`` is the ordered list of class names, and ``transform``
        turns a PIL image into a normalised 224x224 tensor.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Order matters: output index i of the classifier maps to label_names[i].
    label_names = ['beagle', 'bulldog', 'dalmatian', 'german-shepherd', 'husky', 'poodle', 'rottweiler']

    model = timm.create_model('convnext_base', pretrained=pretrained)

    # Replace timm's head with pool -> flatten -> linear. The layer layout
    # (and therefore the state-dict key names) must stay exactly like this so
    # the downloaded checkpoint loads.
    in_features = model.head.in_features
    model.head = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(in_features, len(label_names)),
    )
    model = model.to(device)

    # Fixed-size resize followed by the commonly used ImageNet channel
    # mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return device, model, label_names, transform
def download_and_load_model(_model, device):
    """Download the checkpoint from Hugging Face and load it into ``_model``.

    The checkpoint is streamed to a temporary file in chunks, loaded with
    ``torch.load``, applied with ``load_state_dict`` and the model is switched
    to eval mode. The temporary file is removed whether or not loading
    succeeds.

    Args:
        _model: The ``torch.nn.Module`` that receives the weights (mutated in
            place).
        device: Device passed to ``torch.load`` as ``map_location``.

    Returns:
        bool: True when the weights were loaded; False on any failure, in
        which case the error has already been reported with ``st.error``.
    """
    tmp_model_path = None
    try:
        with st.spinner("Downloading model from Hugging Face..."):
            # Stream instead of buffering the whole body in memory, and bound
            # the connect/read waits so a stalled connection cannot hang the
            # app forever.
            with requests.get(DEFAULT_MODEL_URL, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file:
                    tmp_model_path = tmp_file.name
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            tmp_file.write(chunk)

            # NOTE(security): torch.load unpickles the file. This is only
            # acceptable because the URL is a fixed, trusted repository;
            # consider weights_only=True where the installed PyTorch
            # supports it.
            state_dict = torch.load(tmp_model_path, map_location=device)
            _model.load_state_dict(state_dict)
            _model.eval()
        return True
    except Exception as e:
        # UI boundary: report the problem and let the caller decide what to do.
        st.error(f"Error downloading/loading model: {str(e)}")
        return False
    finally:
        # Always remove the temporary checkpoint, including on failure paths.
        if tmp_model_path is not None:
            try:
                os.unlink(tmp_model_path)
            except OSError:
                pass
def predict_image(image, model, transform, label_names, device, topk=3):
    """Classify one image and return the most probable labels.

    Args:
        image: Image object exposing ``mode`` and ``convert`` (e.g. a PIL
            image). Non-RGB images are converted to RGB first.
        model: Classifier called on a single-item batch; its output is
            soft-maxed over dimension 1.
        transform: Callable turning the image into a tensor without a batch
            dimension.
        label_names: Sequence mapping class index to label.
        device: Device the input tensor is moved to.
        topk: Maximum number of predictions to return. Values larger than the
            number of classes are clamped instead of raising.

    Returns:
        list[dict]: Up to ``topk`` entries ordered from most to least likely,
        each ``{'breed': <label>, 'confidence': <percentage as float>}``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    img_tensor = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # torch.topk raises when k exceeds the class dimension, so clamp it.
        k = max(0, min(int(topk), probs.shape[1]))
        top_probs, top_idxs = torch.topk(probs, k=k)

    # .tolist() yields plain Python ints/floats, so callers do not have to
    # deal with tensor or numpy scalar types.
    return [
        {'breed': label_names[idx], 'confidence': prob * 100}
        for idx, prob in zip(top_idxs[0].tolist(), top_probs[0].tolist())
    ]
def create_prediction_chart(predictions):
    """Create a horizontal bar chart of the predicted breeds.

    Args:
        predictions: List of ``{'breed': str, 'confidence': number}`` dicts,
            ordered from most to least likely, with confidence in percent.

    Returns:
        The matplotlib figure. The caller owns it and should close it once it
        has been rendered.
    """
    breeds = [pred['breed'].replace('-', ' ').title() for pred in predictions]
    confidences = [float(pred['confidence']) for pred in predictions]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.barh(breeds, confidences, color=['#1f77b4', '#ff7f0e', '#2ca02c'])
    # barh places the first entry at the bottom; flip the axis so the most
    # likely breed is on top, matching the order of the input list.
    ax.invert_yaxis()
    ax.set_xlabel('Confidence (%)')
    # Derive the count from the data so the title stays correct for any top-k.
    ax.set_title(f'Top {len(predictions)} Breed Predictions')
    ax.set_xlim(0, 100)

    # Percentage label just past the end of each bar.
    for bar, conf in zip(bars, confidences):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{conf:.1f}%', va='center')

    fig.tight_layout()
    return fig
# Main app | |
class _ModelLoadError(RuntimeError):
    """Raised when the classifier weights cannot be downloaded or loaded."""


@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the classifier and load its weights once, then reuse the result.

    Streamlit re-runs the whole script on every interaction; without caching
    the network would be rebuilt and the checkpoint downloaded again each
    time. A failure raises instead of returning, so it is not cached and the
    next run retries.

    Returns:
        tuple: ``(device, model, label_names, transform)`` as produced by
        ``setup_device_and_model``, with weights loaded.

    Raises:
        _ModelLoadError: If ``download_and_load_model`` reports failure.
    """
    device, model, label_names, transform = setup_device_and_model()
    if not download_and_load_model(model, device):
        raise _ModelLoadError("model weights could not be downloaded or loaded")
    return device, model, label_names, transform


def _render_predictions(image, model, transform, label_names, device):
    """Classify ``image`` and render the result card, ranked list and chart."""
    try:
        with st.spinner("Analyzing image..."):
            predictions = predict_image(image, model, transform, label_names, device)

        st.success("🎉 Analysis Complete!")

        # Highlight card for the best match. The breed text comes from the
        # fixed label list, not from user input, so it is safe to inline.
        top_breed = predictions[0]['breed'].replace('-', ' ').title()
        top_confidence = float(predictions[0]['confidence'])
        card_html = (
            '<div style="background-color: #f0f8ff; padding: 20px; border-radius: 10px; border-left: 5px solid #1f77b4;">\n'
            '<h3 style="color: #1f77b4; margin: 0;">🏆 Most Likely Breed</h3>\n'
            f'<h2 style="margin: 5px 0;">{top_breed}</h2>\n'
            f'<h4 style="color: #666; margin: 0;">Confidence: {top_confidence:.1f}%</h4>\n'
            '</div>'
        )
        st.markdown(card_html, unsafe_allow_html=True)
        st.write("")  # vertical spacing

        st.subheader("All Predictions:")
        for rank, pred in enumerate(predictions, start=1):
            breed = pred['breed'].replace('-', ' ').title()
            confidence = float(pred['confidence'])
            st.write(f"**{rank}. {breed}**")
            # st.progress only accepts values within [0.0, 1.0].
            st.progress(min(1.0, max(0.0, confidence / 100)))
            st.write(f"Confidence: {confidence:.2f}%")
            st.write("")

        st.subheader("Prediction Chart:")
        fig = create_prediction_chart(predictions)
        try:
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each rerun would leak one figure.
            plt.close(fig)
    except Exception as e:
        # UI boundary: show the problem instead of crashing the page.
        st.error(f"Error during prediction: {str(e)}")


def main():
    """Render the app: model loading, image upload, predictions and help text."""
    st.title("🐕 Dog Breed Classifier")
    st.write("Upload an image of a dog to identify its breed!")

    try:
        device, model, label_names, transform = _load_classifier()
    except _ModelLoadError:
        st.error("❌ Failed to load model. Please refresh the page and try again.")
        return
    st.success("✅ Model loaded successfully from Hugging Face!")

    col1, col2 = st.columns([1, 1])
    image = None

    with col1:
        st.header("Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=['jpg', 'jpeg', 'png'],
            help="Upload a clear image of a dog for best results",
        )
        if uploaded_file is not None:
            try:
                image = Image.open(uploaded_file)
                # Decode now so a corrupt or truncated file fails here rather
                # than in the middle of rendering.
                image.load()
            except (OSError, ValueError) as e:
                image = None
                st.error(f"Could not read the uploaded file as an image: {e}")
            else:
                st.image(image, caption="Uploaded Image", use_container_width=True)
                st.write(f"**Image Size:** {image.size}")
                st.write(f"**Image Mode:** {image.mode}")

    with col2:
        st.header("Predictions")
        if uploaded_file is None:
            st.info("📤 Please upload an image to start classification.")
        elif image is not None:
            _render_predictions(image, model, transform, label_names, device)

    # Built from label_names so the help text cannot drift from the model's
    # actual classes.
    breed_list = "\n".join(
        f"- {name.replace('-', ' ').title()}" for name in label_names
    )
    with st.expander("ℹ️ About this App"):
        st.write(
            "This app uses a ConvNeXt-Base model trained to classify dog breeds among:\n"
            "\n"
            f"{breed_list}\n"
            "\n"
            "**How to use:**\n"
            "\n"
            "1. The model is automatically loaded from Hugging Face\n"
            "2. Upload a clear image of a dog\n"
            "3. View the top 3 breed predictions with confidence scores\n"
            "\n"
            "**Tips for better results:**\n"
            "\n"
            "- Use high-quality, well-lit images\n"
            "- Ensure the dog is clearly visible in the image\n"
            "- Avoid images with multiple dogs\n"
        )

    with st.expander("🔧 Technical Details"):
        st.write(
            f"- **Device:** {device}\n"
            "- **Model:** ConvNeXt-Base\n"
            "- **Model Source:** Hugging Face (Alamgirapi/dog-breed-convnext-classifier)\n"
            "- **Input Size:** 224x224 pixels\n"
            f"- **Classes:** {len(label_names)}\n"
            "- **Framework:** PyTorch + Streamlit\n"
        )
# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
    main()