import streamlit as st import torch import torch.nn as nn import torch.nn.functional as F import timm from PIL import Image from torchvision import transforms import numpy as np import matplotlib.pyplot as plt import io import requests import tempfile import os # Set page config st.set_page_config( page_title="Dog Breed Classifier", page_icon="đ", layout="wide" ) # Default model URL DEFAULT_MODEL_URL = "https://huggingface.co/Alamgirapi/dog-breed-convnext-classifier/resolve/main/model.pth" # Device setup @st.cache_resource def setup_device_and_model(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize model model = timm.create_model('convnext_base', pretrained=True) # Define label names label_names = ['beagle', 'bulldog', 'dalmatian', 'german-shepherd', 'husky', 'poodle', 'rottweiler'] # Replace head with proper flattening model.head = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(model.head.in_features, len(label_names)) ) model = model.to(device) # Define transform transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return device, model, label_names, transform @st.cache_resource def download_and_load_model(_model, device): """Download and load model weights from Hugging Face""" try: with st.spinner("Downloading model from Hugging Face..."): # Download the model file response = requests.get(DEFAULT_MODEL_URL) response.raise_for_status() # Save to temporary file with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file: tmp_file.write(response.content) tmp_model_path = tmp_file.name # Load the model weights _model.load_state_dict(torch.load(tmp_model_path, map_location=device)) _model.eval() # Clean up temporary file os.unlink(tmp_model_path) return True except Exception as e: st.error(f"Error downloading/loading model: {str(e)}") return False def predict_image(image, model, transform, label_names, device, topk=3): """Make predictions on uploaded image""" # Transform image if image.mode != 'RGB': image = image.convert('RGB') img_tensor = transform(image).unsqueeze(0).to(device) # Predict model.eval() with torch.no_grad(): outputs = model(img_tensor) probs = F.softmax(outputs, dim=1) top_probs, top_idxs = torch.topk(probs, k=topk) # Convert to CPU for display top_probs = top_probs[0].cpu().numpy() top_idxs = top_idxs[0].cpu().numpy() # Build prediction results predictions = [] for idx, prob in zip(top_idxs, top_probs): predictions.append({ 'breed': label_names[idx], 'confidence': prob * 100 }) return predictions def create_prediction_chart(predictions): """Create a horizontal bar chart for predictions""" breeds = [pred['breed'].replace('-', ' ').title() for pred in predictions] confidences = [float(pred['confidence']) for pred in predictions] # Convert to Python float fig, ax = plt.subplots(figsize=(10, 6)) bars = ax.barh(breeds, confidences, color=['#1f77b4', '#ff7f0e', '#2ca02c']) ax.set_xlabel('Confidence (%)') ax.set_title('Top 3 Breed Predictions') ax.set_xlim(0, 100) # Add percentage labels on bars for i, (bar, conf) in enumerate(zip(bars, confidences)): ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2, f'{conf:.1f}%', va='center') plt.tight_layout() return fig # Main app def main(): st.title("đ Dog Breed Classifier") st.write("Upload an image of a dog to identify its breed!") # Initialize model and device device, model, label_names, transform = setup_device_and_model() # Download and load the model automatically model_loaded = download_and_load_model(model, device) if model_loaded: st.success("â Model loaded successfully from Hugging Face!") else: st.error("â Failed to load model. Please refresh the page and try again.") return # Main content col1, col2 = st.columns([1, 1]) with col1: st.header("Upload Image") uploaded_file = st.file_uploader( "Choose an image file", type=['jpg', 'jpeg', 'png'], help="Upload a clear image of a dog for best results" ) if uploaded_file is not None: # Display uploaded image image = Image.open(uploaded_file) st.image(image, caption="Uploaded Image", use_container_width=True) # Show image details st.write(f"**Image Size:** {image.size}") st.write(f"**Image Mode:** {image.mode}") with col2: st.header("Predictions") if uploaded_file is not None: try: with st.spinner("Analyzing image..."): # Make predictions predictions = predict_image(image, model, transform, label_names, device) # Display results st.success("đ Analysis Complete!") # Show top prediction prominently top_breed = predictions[0]['breed'].replace('-', ' ').title() top_confidence = float(predictions[0]['confidence']) # Convert to Python float st.markdown(f"""