# app.py — uploaded to Hugging Face by Alamgirapi (commit b7c34ef, verified).
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import requests
import io
from timm import create_model
# Set page config: browser-tab title/icon and a wide layout for the two-column UI.
# (The icon was previously a mis-encoded byte sequence instead of the 🏀 emoji.)
st.set_page_config(
    page_title="Sports Ball Classifier",
    page_icon="🏀",
    layout="wide",
)
# Custom ConvNeXt model definition (in case the saved model uses a different architecture)
class ConvNeXtBlock(nn.Module):
    """ConvNeXt residual block: 7x7 depthwise conv -> LayerNorm -> 4x MLP (GELU) -> layer scale.

    The input is added back to the branch output, so the block preserves the
    (N, C, H, W) shape of its input.

    Args:
        dim: Number of input/output channels.
        drop_path: Stochastic-depth rate. While the module is in training mode
            the residual branch is zeroed per sample with this probability, and
            surviving samples are rescaled by ``1 / (1 - drop_path)``. It has no
            effect in eval mode or when 0 (the default). Previously this
            argument was accepted but ignored.
        layer_scale_init_value: Initial value of the learnable per-channel scale
            ``gamma``; when <= 0 no scale parameter is created.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        # Plain float, not a parameter/buffer, so the state_dict is unchanged.
        self.drop_path_rate = float(drop_path)

    def _drop_path(self, x):
        """Apply stochastic depth to the residual-branch output *x*.

        Identity unless the module is training and the rate is positive.
        """
        if self.drop_path_rate <= 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_path_rate
        # One Bernoulli draw per sample, broadcast over all non-batch dims.
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        if keep_prob > 0.:
            mask.div_(keep_prob)
        return x * mask

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C) so LayerNorm/Linear act on channels
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self._drop_path(x)
class CustomConvNeXt(nn.Module):
    """Hand-written ConvNeXt-style classifier: stage depths 3/3/9/3, widths 96/192/384/768.

    Layout: 4x4 stride-4 patchify stem, four stages of ``ConvNeXtBlock``s with
    2x2 stride-2 downsampling convs between them, then global average pooling,
    LayerNorm and a linear head with ``num_classes`` outputs.

    NOTE(review): the stem and downsample LayerNorms normalize over
    (C, H, W) with hard-coded spatial sizes (56/28/14), so the network only
    accepts inputs whose stem output is 56x56 (i.e. ~224x224 images), and its
    parameter shapes differ from the reference ConvNeXt. This class is not
    used by load_model().
    """

    def __init__(self, num_classes=15):
        super().__init__()
        widths = (96, 192, 384, 768)

        # Stem: patchify 224x224 -> 56x56.
        self.stem = nn.Sequential(
            nn.Conv2d(3, widths[0], kernel_size=4, stride=4),
            nn.LayerNorm([widths[0], 56, 56], eps=1e-6),
        )
        # Construction order is kept stage1, ds1, stage2, ... so random
        # initialisation and state_dict key order stay the same.
        self.stage1 = self._make_stage(widths[0], 3)
        self.downsample1 = self._make_downsample(widths[0], widths[1], 56)
        self.stage2 = self._make_stage(widths[1], 3)
        self.downsample2 = self._make_downsample(widths[1], widths[2], 28)
        self.stage3 = self._make_stage(widths[2], 9)
        self.downsample3 = self._make_downsample(widths[2], widths[3], 14)
        self.stage4 = self._make_stage(widths[3], 3)

        # Head: pool to 1x1, normalise the pooled features, classify.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.norm = nn.LayerNorm(widths[3], eps=1e-6)
        self.head = nn.Linear(widths[3], num_classes)

    @staticmethod
    def _make_stage(dim, depth):
        """Stack *depth* default-configured ConvNeXtBlocks of width *dim*."""
        return nn.Sequential(*(ConvNeXtBlock(dim) for _ in range(depth)))

    @staticmethod
    def _make_downsample(in_dim, out_dim, size):
        """LayerNorm over (in_dim, size, size) followed by a 2x2 stride-2 conv to out_dim."""
        return nn.Sequential(
            nn.LayerNorm([in_dim, size, size], eps=1e-6),
            nn.Conv2d(in_dim, out_dim, kernel_size=2, stride=2),
        )

    def forward(self, x):
        pipeline = (
            self.stem,
            self.stage1, self.downsample1,
            self.stage2, self.downsample2,
            self.stage3, self.downsample3,
            self.stage4,
            self.avgpool,
        )
        for layer in pipeline:
            x = layer(x)
        features = torch.flatten(x, 1)  # (N, 768, 1, 1) -> (N, 768)
        return self.head(self.norm(features))
def _load_first_matching(model_names, model_state, device, num_classes):
    """Return the first timm architecture in *model_names* that strictly accepts *model_state*.

    Each name is instantiated with random weights (``pretrained=False``) and
    ``num_classes`` outputs, then receives the checkpoint through a strict
    ``load_state_dict``. Any exception while loading or moving the model
    (missing/unexpected keys, shape mismatch, device error) is reported as a
    UI warning, and the next name is tried.

    Returns:
        The loaded model in eval mode on *device*, or ``None`` if no name matched.
    """
    for model_name in model_names:
        try:
            model = create_model(model_name, pretrained=False, num_classes=num_classes)
            model.load_state_dict(model_state)
            model.eval()
            model.to(device)
        except Exception as e:
            st.warning(f"❌ Failed to load with {model_name}: {str(e)[:100]}...")
            continue
        st.success(f"✅ Successfully loaded model using: {model_name}")
        return model
    return None


# Cache the model loading to avoid reloading on every interaction
@st.cache_resource
def load_model():
    """Download the sports-ball checkpoint and load it into a matching timm model.

    Candidate architectures (ViT variants, then ConvNeXt, then a few other
    common backbones) are tried with strict state-dict loading. If none
    matches, ``vit_base_patch16_224`` is loaded with ``strict=False`` as a
    last resort, and the UI warns that predictions will likely be unreliable.

    Returns:
        ``(model, device)``. *model* is an eval-mode ``nn.Module`` on *device*,
        or ``None`` if downloading or every load attempt failed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 15  # must equal len(label_names) in main()
    try:
        # Download model weights from Hugging Face. requests never times out
        # by default, so bound the wait to keep a stalled connection from
        # hanging the app.
        model_url = "https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier/resolve/main/model.pth"
        response = requests.get(model_url, timeout=60)
        if response.status_code != 200:
            raise Exception(f"Failed to download model: HTTP {response.status_code}")
        # NOTE(review): torch.load unpickles the downloaded bytes, which can
        # execute code if the remote file is ever replaced. Pass
        # weights_only=True once the minimum supported torch version is known
        # to accept it.
        model_state = torch.load(io.BytesIO(response.content), map_location=device)
        if not isinstance(model_state, dict):
            raise TypeError(
                f"Expected a state dict in model.pth, got {type(model_state).__name__}"
            )

        # (UI message, candidate timm architectures), tried in this order.
        search_plan = [
            ("Trying Vision Transformer (ViT) models...", [
                'vit_base_patch16_224', 'vit_small_patch16_224', 'vit_tiny_patch16_224',
                'vit_large_patch16_224', 'vit_base_patch32_224',
            ]),
            ("Trying ConvNeXt models as fallback...", [
                'convnext_tiny', 'convnext_small', 'convnext_base',
            ]),
            ("Trying other model architectures...", [
                'resnet50', 'efficientnet_b0', 'mobilenetv3_large_100',
            ]),
        ]
        for message, model_names in search_plan:
            st.info(message)
            model = _load_first_matching(model_names, model_state, device, num_classes)
            if model is not None:
                return model, device

        # If all fail, try loading with strict=False and show detailed info.
        # strict=False tolerates missing/unexpected keys, but still raises on
        # tensor shape mismatches.
        st.info("Attempting to load with strict=False...")
        try:
            model = create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
            missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
            if missing_keys:
                st.warning(f"⚠️ Missing keys ({len(missing_keys)}): {missing_keys[:3]}...")
            if unexpected_keys:
                st.warning(f"⚠️ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}...")
            model.eval()
            model.to(device)
            if missing_keys or unexpected_keys:
                st.error("⚠️ Model loaded with mismatched weights - predictions will likely be unreliable!")
                st.info("💡 The saved model might have been trained with a different architecture.")
            else:
                st.success("✅ Model loaded successfully with strict=False")
            return model, device
        except Exception as e:
            st.error(f"❌ Failed to load model with all methods. Error: {str(e)}")
            st.info("💡 Try checking the model file or re-training with a compatible architecture.")
            return None, device
    except Exception as e:
        st.error(f"❌ Error downloading or loading model: {str(e)}")
        return None, device
def get_transform():
    """Build the preprocessing pipeline applied to each uploaded image.

    Steps: resize to exactly 224x224 (the aspect ratio is not preserved),
    convert to a float tensor, then normalise each channel with the standard
    ImageNet mean/std.

    Returns:
        A ``transforms.Compose`` callable mapping an image to a normalised tensor.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
def predict_image(image, model, device, transform, label_names, topk=5):
    """Run *model* on a single image and return its most probable classes.

    Args:
        image: Input accepted by *transform* (a PIL RGB image in this app).
        model: Classifier returning logits for a batch of one; presumably
            shaped (1, num_classes) and already in eval mode. TODO confirm for
            other callers.
        device: Device the model lives on; the input tensor is moved there.
        transform: Callable producing a single, unbatched input tensor.
        label_names: Unused here; kept for API compatibility (callers map
            indices to names themselves).
        topk: Number of predictions requested. It is clamped to the number of
            classes, so asking for more than the model outputs no longer makes
            ``torch.topk`` raise.

    Returns:
        ``(top_probs, top_idxs)``: 1-D numpy arrays holding the softmax
        probabilities in descending order and the matching class indices.
    """
    batch = transform(image).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        probs = F.softmax(model(batch), dim=1)
        k = min(topk, probs.size(-1))
        top_probs, top_idxs = torch.topk(probs, k=k)
    # Drop the batch dimension and move to CPU for display.
    return top_probs[0].cpu().numpy(), top_idxs[0].cpu().numpy()
def _pretty_label(name):
    """Turn a snake_case class name such as 'golf_ball' into 'Golf Ball'."""
    return name.replace('_', ' ').title()


def _confidence_marker(confidence):
    """Traffic-light emoji for a confidence in percent: >70 green, >40 yellow, otherwise red."""
    if confidence > 70:
        return "🟢"
    if confidence > 40:
        return "🟡"
    return "🔴"


def _bar_color(confidence):
    """Chart bar colour for a confidence in percent (same thresholds as the emoji marker)."""
    if confidence > 70:
        return '#4CAF50'  # Green
    if confidence > 40:
        return '#FF9800'  # Orange
    return '#F44336'  # Red


def _write_ranked_prediction(rank, label, confidence):
    """Write one '<rank>. <marker> **label**: xx.xx%' line followed by a progress bar."""
    st.write(f"{rank}. {_confidence_marker(confidence)} **{label}**: {confidence:.2f}%")
    st.progress(float(confidence / 100))


def _render_confidence_chart(labels, confidences, topk):
    """Draw a horizontal bar chart of the predictions, with the highest-ranked on top."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # barh stacks bars bottom-up, so reverse to put the top prediction first.
        rev_labels = labels[::-1]
        rev_conf = confidences[::-1]
        bars = ax.barh(rev_labels, rev_conf)
        ax.set_xlabel('Confidence (%)')
        ax.set_title(f'Top {topk} Predictions')
        ax.set_xlim(0, 100)
        for bar, conf in zip(bars, rev_conf):
            bar.set_color(_bar_color(conf))
            ax.text(conf + 1, bar.get_y() + bar.get_height() / 2,
                    f'{conf:.1f}%', va='center')
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure open until closed; without this each
        # Streamlit rerun would leak one figure.
        plt.close(fig)


def _render_predictions(uploaded_file, model, device, transform, label_names, topk):
    """Render the "Predictions" column for the current upload.

    With no upload it shows a prompt. For a file that cannot be read as an
    image it shows an error. Otherwise it shows the primary prediction, a
    ranked list, and a bar chart. Errors raised while predicting or rendering
    are reported in the page.
    """
    st.subheader("Predictions")
    if uploaded_file is None:
        st.info("👆 Please upload an image to get started!")
        return

    try:
        # Corrupt or non-image uploads raise here; report them rather than
        # letting the exception abort the whole script run.
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        return
    st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("Analyzing image..."):
        try:
            top_probs, top_idxs = predict_image(
                image, model, device, transform, label_names, topk
            )
            confidences = [float(prob * 100) for prob in top_probs]
            labels = [_pretty_label(label_names[idx]) for idx in top_idxs]

            # Primary prediction, shown prominently.
            st.success(
                f"{_confidence_marker(confidences[0])} "
                f"**Primary Prediction: {labels[0]}** ({confidences[0]:.2f}%)"
            )
            st.progress(float(confidences[0] / 100))

            # Up to three ranked predictions; the header reflects how many are shown.
            shown = min(3, len(confidences))
            st.subheader("Top Prediction:" if shown == 1 else f"Top {shown} Predictions:")
            for rank in range(shown):
                _write_ranked_prediction(rank + 1, labels[rank], confidences[rank])

            # Remaining predictions, collapsed by default.
            if topk > 3:
                with st.expander(f"See all {topk} predictions"):
                    for rank in range(3, len(confidences)):
                        _write_ranked_prediction(rank + 1, labels[rank], confidences[rank])

            with st.expander("Detailed Results"):
                _render_confidence_chart(labels, confidences, topk)
        except Exception as e:
            st.error(f"Error during prediction: {str(e)}")


def _render_footer(label_names):
    """Render the supported-class grid and the 'About this model' / tips section."""
    st.markdown("---")
    st.subheader("Supported Sports Balls")
    # Class names are distributed round-robin across five columns.
    cols = st.columns(5)
    for i, label in enumerate(label_names):
        with cols[i % 5]:
            st.write(f"• {_pretty_label(label)}")
    st.markdown("---")
    st.markdown(
        "**About this model:**\n"
        "- Built using ConvNeXt architecture\n"
        f"- Trained to classify {len(label_names)} different types of sports balls\n"
        "- Model weights from: [Alamgirapi/sports-ball-convnext-classifier]"
        "(https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier)\n"
        "\n"  # blank line ends the list so the next heading is not merged into its last item
        "**Tips for best results:**\n"
        "- Use clear, well-lit images\n"
        "- Ensure the ball is the main subject\n"
        "- Avoid cluttered backgrounds when possible\n"
    )


def main():
    """Render the Sports Ball Classifier page.

    Loads the (cached) model, lays out an upload column and a predictions
    column, then renders the list of supported classes and the model notes.
    If the model cannot be loaded, only an error is shown.
    """
    st.title("🏀 Sports Ball Classifier")
    st.markdown("Upload an image of a sports ball and get AI-powered predictions!")

    # Class names, indexed by the model's output position (presumably the
    # training label order; verify against the checkpoint).
    label_names = [
        'american_football', 'baseball', 'basketball', 'billiard_ball',
        'bowling_ball', 'cricket_ball', 'football', 'golf_ball',
        'hockey_ball', 'hockey_puck', 'rugby_ball', 'shuttlecock',
        'table_tennis_ball', 'tennis_ball', 'volleyball',
    ]

    with st.spinner("Loading model..."):
        model, device = load_model()
    if model is None:
        # load_model is wrapped in st.cache_resource, which would otherwise
        # keep returning this cached failure for the life of the server
        # process. Clearing it lets the next rerun actually retry.
        load_model.clear()
        st.error("Failed to load model. Please try again later.")
        return
    st.success(f"Model loaded successfully! Using device: {device}")

    transform = get_transform()

    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=['png', 'jpg', 'jpeg'],
            help="Upload a clear image of a sports ball for best results",
        )
        # Number of top predictions to show
        topk = st.slider("Number of predictions to show:", 1, 10, 5)
    with col2:
        _render_predictions(uploaded_file, model, device, transform, label_names, topk)

    _render_footer(label_names)
# Script entry point: `streamlit run` executes this module as "__main__",
# so the page is built by calling main() here.
if __name__ == "__main__":
    main()