"""Streamlit app that classifies a dog photo into one of seven breeds.

A ConvNeXt-Base network (built with ``timm``) gets a small classification
head, its fine-tuned weights are downloaded from Hugging Face, and the top
predictions for an uploaded image are shown as text, progress bars and a
bar chart.
"""

import io
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Set page config (must be the first Streamlit command in the script)
st.set_page_config(
    page_title="Dog Breed Classifier",
    page_icon="🐕",
    layout="wide"
)

# Default model URL
DEFAULT_MODEL_URL = "https://huggingface.co/Alamgirapi/dog-breed-convnext-classifier/resolve/main/model.pth"

# (connect, read) timeouts in seconds so a stalled download cannot hang the app
DOWNLOAD_TIMEOUT = (10, 60)

# Size of each chunk written to disk while streaming the checkpoint
DOWNLOAD_CHUNK_SIZE = 1 << 20


# Device setup
@st.cache_resource
def setup_device_and_model():
    """Build the device, the untrained network, the labels and the transform.

    Returns:
        A ``(device, model, label_names, transform)`` tuple. ``model`` is a
        ConvNeXt-Base whose head is replaced by pooling + flatten + a linear
        layer with one output per label; it is already moved to ``device``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model. ImageNet weights are deliberately not fetched: the
    # fine-tuned checkpoint is loaded with a strict ``load_state_dict`` which
    # replaces every parameter, so downloading them would be wasted work.
    model = timm.create_model('convnext_base', pretrained=False)

    # Define label names
    label_names = ['beagle', 'bulldog', 'dalmatian', 'german-shepherd',
                   'husky', 'poodle', 'rottweiler']

    # Replace head with proper flattening
    model.head = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(model.head.in_features, len(label_names))
    )
    model = model.to(device)

    # Define transform
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    return device, model, label_names, transform


@st.cache_resource
def _load_remote_weights(_model, device):
    """Download the checkpoint and load it into ``_model``; raise on failure.

    Only a successful load is cached: an exception propagates to the caller
    and is not stored, so the next script run retries the download.

    Args:
        _model: Network to load the weights into (the leading underscore
            excludes it from Streamlit's cache key).
        device: Device passed to ``torch.load`` as ``map_location``.

    Returns:
        ``True`` once the weights are loaded and the model is in eval mode.
    """
    tmp_model_path = None
    try:
        with st.spinner("Downloading model from Hugging Face..."):
            # Stream to disk in chunks instead of holding the whole
            # checkpoint in memory as ``response.content``.
            with requests.get(DEFAULT_MODEL_URL, stream=True,
                              timeout=DOWNLOAD_TIMEOUT) as response:
                response.raise_for_status()
                with tempfile.NamedTemporaryFile(delete=False,
                                                 suffix='.pth') as tmp_file:
                    tmp_model_path = tmp_file.name
                    for chunk in response.iter_content(
                            chunk_size=DOWNLOAD_CHUNK_SIZE):
                        tmp_file.write(chunk)

            # Load the model weights.
            # NOTE(review): torch.load unpickles the downloaded file; only
            # point DEFAULT_MODEL_URL at a source you trust.
            state_dict = torch.load(tmp_model_path, map_location=device)
            _model.load_state_dict(state_dict)
            _model.eval()
    finally:
        # Clean up the temporary file whether or not loading succeeded.
        if tmp_model_path is not None and os.path.exists(tmp_model_path):
            os.unlink(tmp_model_path)

    return True


def download_and_load_model(_model, device):
    """Download and load model weights from Hugging Face.

    Args:
        _model: Network that receives the downloaded weights.
        device: Device the weights are mapped to while loading.

    Returns:
        ``True`` when the weights were loaded, ``False`` otherwise. Failures
        are reported with ``st.error`` and are not cached, so a page refresh
        tries again.
    """
    try:
        return _load_remote_weights(_model, device)
    except Exception as e:  # UI boundary: report any failure, do not crash
        st.error(f"Error downloading/loading model: {str(e)}")
        return False


def predict_image(image, model, transform, label_names, device, topk=3):
    """Make predictions on uploaded image.

    Args:
        image: PIL image; converted to RGB when it is in another mode.
        model: Classifier called on the transformed image.
        transform: Callable turning the PIL image into a tensor.
        label_names: Sequence of class names indexed by class id.
        device: Device the input tensor is moved to.
        topk: Maximum number of predictions to return.

    Returns:
        A list of ``{'breed': name, 'confidence': percent}`` dicts ordered
        from most to least likely, with at most ``topk`` entries.
    """
    # Transform image
    if image.mode != 'RGB':
        image = image.convert('RGB')

    img_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # torch.topk raises when k exceeds the number of classes
        k = min(topk, probs.shape[1])
        top_probs, top_idxs = torch.topk(probs, k=k)

    # Convert to CPU for display
    top_probs = top_probs[0].cpu().numpy()
    top_idxs = top_idxs[0].cpu().numpy()

    # Build prediction results
    return [
        {'breed': label_names[idx], 'confidence': prob * 100}
        for idx, prob in zip(top_idxs, top_probs)
    ]


def create_prediction_chart(predictions):
    """Create a horizontal bar chart for predictions.

    Args:
        predictions: List of dicts with ``'breed'`` and ``'confidence'``
            (percent) keys, as returned by ``predict_image``.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    breeds = [pred['breed'].replace('-', ' ').title() for pred in predictions]
    # Convert to Python float
    confidences = [float(pred['confidence']) for pred in predictions]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.barh(breeds, confidences,
                   color=['#1f77b4', '#ff7f0e', '#2ca02c'])

    ax.set_xlabel('Confidence (%)')
    # Derive the count from the data so the title stays correct for any topk
    ax.set_title(f'Top {len(predictions)} Breed Predictions')
    ax.set_xlim(0, 100)

    # Add percentage labels on bars
    for bar, conf in zip(bars, confidences):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{conf:.1f}%', va='center')

    plt.tight_layout()
    return fig


# Main app
def main():
    """Render the page: load the model, accept an upload, show predictions."""
    st.title("🐕 Dog Breed Classifier")
    st.write("Upload an image of a dog to identify its breed!")

    # Initialize model and device
    device, model, label_names, transform = setup_device_and_model()

    # Download and load the model automatically
    model_loaded = download_and_load_model(model, device)

    if model_loaded:
        st.success("✅ Model loaded successfully from Hugging Face!")
    else:
        st.error("❌ Failed to load model. Please refresh the page and try again.")
        return

    # Main content
    col1, col2 = st.columns([1, 1])

    # Stays None when nothing was uploaded or the upload could not be decoded
    image = None

    with col1:
        st.header("Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=['jpg', 'jpeg', 'png'],
            help="Upload a clear image of a dog for best results"
        )

        if uploaded_file is not None:
            try:
                image = Image.open(uploaded_file)
                # Image.open is lazy; decode now so a corrupt file is
                # reported here instead of crashing later.
                image.load()
            except (OSError, ValueError) as e:
                image = None
                st.error(f"Could not read the uploaded image: {e}")
            else:
                # Display uploaded image
                st.image(image, caption="Uploaded Image",
                         use_container_width=True)

                # Show image details
                st.write(f"**Image Size:** {image.size}")
                st.write(f"**Image Mode:** {image.mode}")

    with col2:
        st.header("Predictions")

        if image is not None:
            try:
                with st.spinner("Analyzing image..."):
                    # Make predictions
                    predictions = predict_image(image, model, transform,
                                                label_names, device)

                # Display results
                st.success("🎉 Analysis Complete!")

                # Show top prediction prominently
                top_breed = predictions[0]['breed'].replace('-', ' ').title()
                # Convert to Python float
                top_confidence = float(predictions[0]['confidence'])

                # NOTE(review): no HTML is visible in this text, so
                # unsafe_allow_html is presumably left over from styled
                # markup; confirm against the original file.
                st.markdown(
                    f"\n\n🏆 Most Likely Breed\n\n{top_breed}\n\n"
                    f"Confidence: {top_confidence:.1f}%\n\n",
                    unsafe_allow_html=True,
                )

                st.write("")  # Add some space

                # Show all predictions
                st.subheader("All Predictions:")
                for i, pred in enumerate(predictions):
                    breed = pred['breed'].replace('-', ' ').title()
                    # Convert numpy float32 to Python float
                    confidence = float(pred['confidence'])

                    # Create progress bar; st.progress accepts floats only
                    # within [0.0, 1.0], so guard against rounding drift.
                    st.write(f"**{i+1}. {breed}**")
                    st.progress(min(max(confidence / 100, 0.0), 1.0))
                    st.write(f"Confidence: {confidence:.2f}%")
                    st.write("")

                # Show chart
                st.subheader("Prediction Chart:")
                fig = create_prediction_chart(predictions)
                try:
                    st.pyplot(fig)
                finally:
                    # pyplot keeps every figure alive until it is closed;
                    # without this each rerun leaks one figure.
                    plt.close(fig)

            except Exception as e:  # UI boundary: show the error in the page
                st.error(f"Error during prediction: {str(e)}")
        elif uploaded_file is None:
            st.info("📤 Please upload an image to start classification.")

    # Information section
    with st.expander("ℹ️ About this App"):
        st.write("""
        This app uses a ConvNeXt-Base model trained to classify dog breeds among:
        - Beagle
        - Bulldog
        - Dalmatian
        - German Shepherd
        - Husky
        - Poodle
        - Rottweiler

        **How to use:**
        1. The model is automatically loaded from Hugging Face
        2. Upload a clear image of a dog
        3. View the top 3 breed predictions with confidence scores

        **Tips for better results:**
        - Use high-quality, well-lit images
        - Ensure the dog is clearly visible in the image
        - Avoid images with multiple dogs
        """)

    # Technical details
    with st.expander("🔧 Technical Details"):
        st.write(f"""
        - **Device:** {device}
        - **Model:** ConvNeXt-Base
        - **Model Source:** Hugging Face (Alamgirapi/dog-breed-convnext-classifier)
        - **Input Size:** 224x224 pixels
        - **Classes:** {len(label_names)}
        - **Framework:** PyTorch + Streamlit
        """)


if __name__ == "__main__":
    main()