# Page-header residue from the hosting site (not part of the program):
#   notrito's picture / Update app.py / 45ee7ed verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet34_Weights
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
import os
import random
import glob
# Import LoRA code
from model import LoRALayer, apply_lora_to_model
# Load model
# Checkpoint file (kept with the app's local Space files) and the number of
# outputs of the replacement classification head.
CHECKPOINT_PATH = 'best_model.pth'
NUM_CLASSES = 2

# Build the network: ResNet34 with ImageNet weights, a new NUM_CLASSES-way
# final layer, then the project's LoRA helper applied with rank 8.
print("Loading model...")
model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=NUM_CLASSES,
)
model = apply_lora_to_model(model, rank=8)

# Restore the trained weights onto the CPU and put the model in eval mode.
# NOTE(review): torch.load unpickles the file. That is fine for a checkpoint
# bundled with the app; consider weights_only=True if the path ever becomes
# externally controlled.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model.eval()
print("Model loaded successfully!")
# Preprocessing
# Per-channel normalisation constants (the standard ImageNet statistics used
# with torchvision's pretrained models).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Resize to 224x224, convert to a tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)

# Class names; position i is the label reported for model output i.
class_names = ['Non-Smoker', 'Smoker']
def predict(image):
    """
    Predict if person in image is smoking

    Args:
        image: PIL Image, or None when no image has been provided

    Returns:
        dict: Prediction probabilities (Python floats) for each class,
        or None when ``image`` is None
    """
    if image is None:
        return None

    # The normalisation step is configured for exactly three channels, so
    # grayscale, palette or RGBA images must be converted to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess and add the batch dimension the model expects.
    img_tensor = transform(image).unsqueeze(0)

    # Predict; no gradients are needed for inference.
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]

    # Pair each class name with the probability at the same index.
    return dict(zip(class_names, probabilities.tolist()))
# Get all example images. glob returns paths in arbitrary, filesystem-dependent
# order, so sort them to keep the examples gallery stable across restarts.
example_images = sorted(glob.glob("All/*"))
# Show at most the first 12 images in the examples gallery.
examples = [[path] for path in example_images[:12]]
# Function to get random sample
def get_random_sample():
    """
    Load a random example image.

    Returns:
        PIL Image opened from a randomly chosen path in ``example_images``.

    Raises:
        gr.Error: If no example images were found, so the UI shows a clear
            message instead of a bare IndexError from ``random.choice``.
    """
    if not example_images:
        raise gr.Error("No example images are available.")
    return Image.open(random.choice(example_images))
# Create Gradio interface (Soft theme)

# Intro text shown at the top of the page.
INTRO_MARKDOWN = """
# 🚬 Smoker Detection
Upload an image or try a random sample to detect if a person is smoking.
This model uses **ResNet34 with LoRA fine-tuning** (only 2.14% of parameters trained)
and achieves **89.73% test accuracy**.
**Model:** [notrito/smoker-detection](https://huggingface.co/notrito/smoker-detection)
"""

# Model details shown below the examples gallery.
ABOUT_MARKDOWN = """
===================================================================================================
### About this model
- **Architecture:** ResNet34 + LoRA adapters (rank=8)
- **Training:** Fine-tuned on 1,120 images
- **Performance:** 89.73% test accuracy, 89.96% F1-score
- **Efficiency:** Only 465K trainable parameters (2.14% of model)
### How it works
LoRA (Low-Rank Adaptation) freezes the pretrained ImageNet weights and adds small trainable
matrices to specific layers. This prevents overfitting on small datasets while maintaining
the model's powerful feature extraction capabilities.
### Limitations
- Trained on limited dataset (1,120 images)
- Best for frontal/profile views with visible cigarettes
- May not generalize to all smoking scenarios
### Links
- [Model Card](https://huggingface.co/notrito/smoker-detection)
- [Training Notebook](https://www.kaggle.com/code/notrito/smoker-detection-with-lora)
**Author:** Noel Triguero
"""

# NOTE(review): the layout nesting below (two side-by-side columns, buttons in
# their own row under the image) is inferred; confirm against the deployed UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        # Left column: image input and the two action buttons.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            with gr.Row():
                predict_btn = gr.Button("🔍 Predict", variant="primary")
                random_btn = gr.Button("🎲 Random Sample", variant="secondary")
        # Right column: class probabilities.
        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="Prediction")

    # Clickable example gallery; predictions for the examples are cached.
    gr.Markdown("### 📸 Try these examples:")
    gr.Examples(
        examples=examples,
        inputs=input_image,
        outputs=output_label,
        fn=predict,
        cache_examples=True,
    )

    gr.Markdown(ABOUT_MARKDOWN)

    # Connect buttons: "Predict" classifies the current image, "Random Sample"
    # loads one of the example images into the input component.
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_label)
    random_btn.click(fn=get_random_sample, inputs=None, outputs=input_image)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()