Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import requests | |
import io | |
from timm import create_model | |
# Set page config (title, browser-tab icon and wide layout).
st.set_page_config(
    page_title="Sports Ball Classifier",
    # NOTE: the icon was a mis-encoded glyph ("π"), which is not a usable
    # favicon; the original emoji could not be recovered, so 🏀 stands in.
    page_icon="🏀",
    layout="wide",
)
# Custom ConvNeXt model definition (in case the saved model uses a different architecture)
class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt block: depthwise 7x7 conv -> LayerNorm -> MLP -> layer scale.

    Args:
        dim: number of input and output channels.
        drop_path: accepted for signature compatibility; this implementation
            never uses it (no stochastic depth is applied).
        layer_scale_init_value: initial value of every entry of the learnable
            per-channel scale ``gamma``; when it is not > 0, no scale is
            created and ``gamma`` is ``None``.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        hidden_dim = 4 * dim
        # Attribute names and their assignment order define the state-dict
        # keys, so they must stay exactly as they are.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(torch.ones((dim)) * layer_scale_init_value,
                                      requires_grad=True)
        else:
            self.gamma = None

    def forward(self, x):
        """Apply the block to ``x`` (expected shape (N, dim, H, W)) and add the residual."""
        shortcut = x
        # (N, C, H, W) -> (N, H, W, C) so LayerNorm and the Linear layers
        # operate on the channel dimension.
        y = self.dwconv(x).permute(0, 2, 3, 1)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        # Back to (N, C, H, W) before the residual sum.
        return shortcut + y.permute(0, 3, 1, 2)
class CustomConvNeXt(nn.Module):
    """Hand-written ConvNeXt-style image classifier sized for 224x224 inputs.

    Layout: a 4x4 / stride-4 patchify stem, four stages of ``ConvNeXtBlock``
    (depths 3, 3, 9, 3; widths 96, 192, 384, 768) separated by 2x2 / stride-2
    downsampling convolutions, then global average pooling, LayerNorm and a
    linear head. The stem and downsampling LayerNorms normalize over the whole
    (C, H, W) map with hard-coded spatial sizes (56, 28, 14), which is why the
    input resolution is fixed.

    Args:
        num_classes: number of outputs of the classification head.
    """

    def __init__(self, num_classes=15):
        super().__init__()

        def make_stage(dim, depth):
            # ``depth`` residual blocks at constant width ``dim``.
            return nn.Sequential(*(ConvNeXtBlock(dim) for _ in range(depth)))

        def make_downsample(in_dim, out_dim, side):
            # LayerNorm over the full (in_dim, side, side) map, then halve H and W.
            return nn.Sequential(
                nn.LayerNorm([in_dim, side, side], eps=1e-6),
                nn.Conv2d(in_dim, out_dim, kernel_size=2, stride=2),
            )

        # Submodules are created in the same order as before so parameter
        # initialisation and state-dict key order are unchanged.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=4, stride=4),
            nn.LayerNorm([96, 56, 56], eps=1e-6),
        )
        self.stage1 = make_stage(96, 3)
        self.downsample1 = make_downsample(96, 192, 56)
        self.stage2 = make_stage(192, 3)
        self.downsample2 = make_downsample(192, 384, 28)
        self.stage3 = make_stage(384, 9)
        self.downsample3 = make_downsample(384, 768, 14)
        self.stage4 = make_stage(768, 3)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.norm = nn.LayerNorm(768, eps=1e-6)
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for images ``x``."""
        trunk = (
            self.stem,
            self.stage1, self.downsample1,
            self.stage2, self.downsample2,
            self.stage3, self.downsample3,
            self.stage4,
        )
        for layer in trunk:
            x = layer(x)
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.head(self.norm(pooled))
# Cache the model loading to avoid reloading on every interaction.
#
# Streamlit re-executes this script on every widget interaction, so without
# ``st.cache_resource`` the checkpoint would be downloaded and every candidate
# architecture re-probed each time. Only successful loads are cached: every
# failure raises inside ``_load_model_cached`` (Streamlit does not cache
# exceptions) and ``load_model`` turns it into ``(None, device)``, so a later
# rerun retries instead of replaying a stale failure.

_MODEL_URL = (
    "https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier/resolve/main/model.pth"
)
_NUM_CLASSES = 15  # must match len(label_names) in main()

# Architectures tried in order with strict loading. Each entry is
# (status message shown before the group, timm model names in the group).
_CANDIDATE_GROUPS = (
    (
        "Trying Vision Transformer (ViT) models...",
        (
            "vit_base_patch16_224",
            "vit_small_patch16_224",
            "vit_tiny_patch16_224",
            "vit_large_patch16_224",
            "vit_base_patch32_224",
        ),
    ),
    (
        "Trying ConvNeXt models as fallback...",
        ("convnext_tiny", "convnext_small", "convnext_base"),
    ),
    (
        "Trying other model architectures...",
        ("resnet50", "efficientnet_b0", "mobilenetv3_large_100"),
    ),
)


class _NoArchitectureMatched(Exception):
    """Raised when even the non-strict ViT fallback could not be built."""


def _fetch_state_dict():
    """Download the checkpoint and deserialize it onto the CPU."""
    # (connect, read) timeouts in seconds; the read timeout bounds stalls
    # between received chunks, not the total download time.
    response = requests.get(_MODEL_URL, timeout=(10, 60))
    if response.status_code != 200:
        raise Exception(f"Failed to download model: HTTP {response.status_code}")
    # The file comes from the network: weights_only=True restricts unpickling
    # to tensors and plain containers, which is all a state dict needs.
    # Loading onto the CPU avoids a second copy on the GPU, because
    # load_state_dict copies into the CPU-built model before it is moved.
    return torch.load(io.BytesIO(response.content), map_location="cpu", weights_only=True)


def _build_timm_model(model_name, state_dict, device, strict=True):
    """Create timm ``model_name``, load ``state_dict``, and return it in eval mode on ``device``.

    Returns:
        ``(model, incompatible)``, where ``incompatible`` is the
        ``(missing_keys, unexpected_keys)`` result of ``load_state_dict``.

    Raises:
        Whatever ``create_model`` / ``load_state_dict`` raise, e.g. on key or
        shape mismatches.
    """
    model = create_model(model_name, pretrained=False, num_classes=_NUM_CLASSES)
    incompatible = model.load_state_dict(state_dict, strict=strict)
    model.eval()
    model.to(device)
    return model, incompatible


@st.cache_resource(show_spinner=False)
def _load_model_cached():
    """Download the weights and return ``(model, device)`` for the first architecture that fits.

    Raises:
        Exception: the download or deserialization failed.
        _NoArchitectureMatched: no candidate loaded strictly and the
            non-strict ViT fallback could not be built either.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = _fetch_state_dict()

    # The first architecture whose parameter names and shapes match the
    # checkpoint exactly wins.
    for banner, model_names in _CANDIDATE_GROUPS:
        st.info(banner)
        for model_name in model_names:
            try:
                model, _ = _build_timm_model(model_name, state_dict, device)
            except Exception as e:
                st.warning(f"❌ Failed to load with {model_name}: {str(e)[:100]}...")
                continue
            st.success(f"✅ Successfully loaded model using: {model_name}")
            return model, device

    # Last resort: load whatever matches into ViT-B/16. Shape mismatches still
    # raise here; missing/unexpected keys leave those weights randomly initialised.
    st.info("Attempting to load with strict=False...")
    try:
        model, (missing_keys, unexpected_keys) = _build_timm_model(
            'vit_base_patch16_224', state_dict, device, strict=False
        )
    except Exception as e:
        raise _NoArchitectureMatched(str(e)) from e

    if missing_keys:
        st.warning(f"⚠️ Missing keys ({len(missing_keys)}): {missing_keys[:3]}...")
    if unexpected_keys:
        st.warning(f"⚠️ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}...")
    if missing_keys or unexpected_keys:
        st.error("⚠️ Model loaded with mismatched weights - predictions will likely be unreliable!")
        st.info("💡 The saved model might have been trained with a different architecture.")
    else:
        st.success("✅ Model loaded successfully with strict=False")
    return model, device


def load_model():
    """Load the sports-ball classifier by probing several timm architectures.

    The checkpoint at ``_MODEL_URL`` is tried against ViT, ConvNeXt and a few
    other timm models (strict loading), then against ViT-B/16 non-strictly.
    Progress and errors are reported on the page.

    Returns:
        ``(model, device)``. ``model`` is ``None`` when the weights could not
        be downloaded or no architecture could be built.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        return _load_model_cached()
    except _NoArchitectureMatched as e:
        st.error(f"❌ Failed to load model with all methods. Error: {e}")
        st.info("💡 Try checking the model file or re-training with a compatible architecture.")
    except Exception as e:
        st.error(f"❌ Error downloading or loading model: {str(e)}")
    return None, device
def get_transform():
    """Build the preprocessing pipeline applied to each uploaded image.

    Steps: resize to 224x224, convert to a tensor, then normalize each
    channel with the mean/std below (the commonly used ImageNet statistics).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
def predict_image(image, model, device, transform, label_names, topk=5):
    """Return the ``topk`` most probable classes for ``image``.

    Args:
        image: input accepted by ``transform`` (in this app, an RGB PIL image).
        model: callable mapping a batch of one to logits; expected shape
            (1, num_classes).
        device: torch device the input tensor is moved to.
        transform: preprocessing callable producing an un-batched tensor.
        label_names: not used here; kept so existing callers keep working
            (callers map the returned indices to names themselves).
        topk: number of predictions requested. It is capped at the number of
            classes, because ``torch.topk`` raises when k exceeds it.

    Returns:
        ``(top_probs, top_idxs)`` as 1-D NumPy arrays, sorted by descending
        probability.
    """
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        k = min(topk, probs.size(1))
        top_probs, top_idxs = torch.topk(probs, k=k)

    # Drop the batch dimension and move to the CPU for display.
    return top_probs[0].cpu().numpy(), top_idxs[0].cpu().numpy()
# Confidence bands (in percent) shared by the emoji badges and the bar chart.
_HIGH_CONFIDENCE = 70
_MEDIUM_CONFIDENCE = 40


def _confidence_badge(confidence):
    """Return the traffic-light emoji for ``confidence`` (a percentage)."""
    if confidence > _HIGH_CONFIDENCE:
        return "🟢"
    if confidence > _MEDIUM_CONFIDENCE:
        return "🟡"
    return "🔴"


def _confidence_color(confidence):
    """Return the bar colour for ``confidence`` using the same bands as the badge."""
    if confidence > _HIGH_CONFIDENCE:
        return '#4CAF50'  # Green
    if confidence > _MEDIUM_CONFIDENCE:
        return '#FF9800'  # Orange
    return '#F44336'  # Red


def _pretty_label(label):
    """Format a snake_case class id for display, e.g. 'golf_ball' -> 'Golf Ball'."""
    return label.replace('_', ' ').title()


def _write_ranked_predictions(top_probs, top_idxs, label_names, start, stop):
    """Write predictions ``start``..``stop - 1`` as numbered lines, each with a progress bar."""
    for rank in range(start, stop):
        confidence = float(top_probs[rank] * 100)
        label = _pretty_label(label_names[top_idxs[rank]])
        st.write(f"{rank + 1}. {_confidence_badge(confidence)} **{label}**: {confidence:.2f}%")
        st.progress(float(confidence / 100))


def _plot_predictions(top_probs, top_idxs, label_names, topk):
    """Render the predictions as a horizontal bar chart with the best one on top."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # barh draws the first category at the bottom, so reverse both lists.
        labels = [_pretty_label(label_names[idx]) for idx in top_idxs][::-1]
        probabilities = [float(prob * 100) for prob in top_probs][::-1]
        bars = ax.barh(labels, probabilities)
        ax.set_xlabel('Confidence (%)')
        ax.set_title(f'Top {topk} Predictions')
        ax.set_xlim(0, 100)
        for bar, prob in zip(bars, probabilities):
            bar.set_color(_confidence_color(prob))
            ax.text(prob + 1, bar.get_y() + bar.get_height() / 2, f'{prob:.1f}%', va='center')
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each rerun of the script would leak one figure.
        plt.close(fig)


def _show_predictions(uploaded_file, model, device, transform, label_names, topk):
    """Show the uploaded image and the model's predictions in the current container."""
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        # Covers PIL.UnidentifiedImageError (not an image) and truncated files,
        # so a bad upload produces a message instead of a traceback.
        st.error(f"Could not read the uploaded image: {str(e)}")
        return
    st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("Analyzing image..."):
        try:
            top_probs, top_idxs = predict_image(
                image, model, device, transform, label_names, topk
            )

            # Top prediction, shown prominently.
            top_confidence = float(top_probs[0] * 100)
            top_label = _pretty_label(label_names[top_idxs[0]])
            badge = _confidence_badge(top_confidence)
            st.success(f"{badge} **Primary Prediction: {top_label}** ({top_confidence:.2f}%)")
            st.progress(float(top_confidence / 100))

            st.subheader("Top 3 Predictions:")
            _write_ranked_predictions(top_probs, top_idxs, label_names, 0, min(3, len(top_probs)))

            # The remaining predictions go into an expander.
            if topk > 3:
                with st.expander(f"See all {topk} predictions"):
                    _write_ranked_predictions(top_probs, top_idxs, label_names, 3, len(top_probs))

            with st.expander("Detailed Results"):
                _plot_predictions(top_probs, top_idxs, label_names, topk)
        except Exception as e:
            st.error(f"Error during prediction: {str(e)}")


def main():
    """Build the page: load the model, accept an upload, and show its predictions."""
    # NOTE: 🏀 and 👆 replace mis-encoded glyphs ("π") whose original emoji
    # could not be recovered.
    st.title("🏀 Sports Ball Classifier")
    st.markdown("Upload an image of a sports ball and get AI-powered predictions!")

    # Display names, indexed by the model's output position.
    label_names = [
        'american_football', 'baseball', 'basketball', 'billiard_ball',
        'bowling_ball', 'cricket_ball', 'football', 'golf_ball',
        'hockey_ball', 'hockey_puck', 'rugby_ball', 'shuttlecock',
        'table_tennis_ball', 'tennis_ball', 'volleyball'
    ]

    with st.spinner("Loading model..."):
        model, device = load_model()
    if model is None:
        st.error("Failed to load model. Please try again later.")
        return
    st.success(f"Model loaded successfully! Using device: {device}")

    transform = get_transform()

    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=['png', 'jpg', 'jpeg'],
            help="Upload a clear image of a sports ball for best results",
        )
        # Number of top predictions to show.
        topk = st.slider("Number of predictions to show:", 1, 10, 5)

    with col2:
        st.subheader("Predictions")
        if uploaded_file is None:
            st.info("👆 Please upload an image to get started!")
        else:
            _show_predictions(uploaded_file, model, device, transform, label_names, topk)

    # Additional information.
    st.markdown("---")
    st.subheader("Supported Sports Balls")
    # Show the supported categories in a 5-column grid.
    cols = st.columns(5)
    for i, label in enumerate(label_names):
        with cols[i % 5]:
            st.write(f"• {_pretty_label(label)}")

    st.markdown("---")
    # The blank line before the second heading keeps Markdown from folding it
    # into the last list item above.
    st.markdown("""
**About this model:**
- Built using ConvNeXt architecture
- Trained to classify 15 different types of sports balls
- Model weights from: [Alamgirapi/sports-ball-convnext-classifier](https://huggingface.co/Alamgirapi/sports-ball-convnext-classifier)

**Tips for best results:**
- Use clear, well-lit images
- Ensure the ball is the main subject
- Avoid cluttered backgrounds when possible
""")
# Script entry point: build and render the page when run as the main module.
if __name__ == "__main__":
    main()