Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| from torchvision.models import ResNet34_Weights | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import random | |
| import glob | |
| # Import LoRA code | |
| from model import LoRALayer, apply_lora_to_model | |
# Build the network with the same architecture used for fine-tuning:
# ResNet34 backbone, a 2-way classification head, wrapped with rank-8 LoRA
# adapters (via the project's apply_lora_to_model helper).
print("Loading model...")
# NOTE(review): the pretrained ImageNet download is probably redundant —
# load_state_dict below runs with strict=True (the default), so every
# parameter/buffer is overwritten by best_model.pth. weights=None would skip
# the download at cold start; kept as-is to avoid changing behaviour.
model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)
model = apply_lora_to_model(model, rank=8)

# Load the fine-tuned weights shipped with the Space. weights_only=True
# restricts unpickling to tensors and plain containers, so a tampered
# checkpoint cannot execute arbitrary code while being loaded.
model.load_state_dict(
    torch.load("best_model.pth", map_location="cpu", weights_only=True)
)
# Inference mode for layers such as BatchNorm/Dropout.
model.eval()
print("Model loaded successfully!")
# Standard ImageNet per-channel statistics, matching the IMAGENET1K_V1
# backbone the classifier is built on.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing: fixed 224x224 resize -> float tensor -> per-channel
# normalisation with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Human-readable labels, ordered to match the model's two output logits.
class_names = ["Non-Smoker", "Smoker"]
def predict(image):
    """Classify whether the person in *image* is smoking.

    Args:
        image: PIL image coming from the ``gr.Image(type="pil")`` input, or
            ``None`` when the input is empty.

    Returns:
        dict mapping each label in ``class_names`` to its softmax
        probability (as a Python float), or ``None`` when no image was given.
    """
    if image is None:
        return None

    # ``transform`` normalises exactly three channels; RGBA, greyscale or
    # palette images would otherwise fail inside Normalize.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.softmax(logits, dim=1)[0]

    return {name: float(prob) for name, prob in zip(class_names, probabilities)}
# Example images shipped with the Space. glob returns entries in filesystem
# order, so sort them to make the examples gallery deterministic; skip
# directories so only files are handed to PIL.
EXAMPLES_DIR = "All"
MAX_EXAMPLES = 12
example_images = sorted(
    path
    for path in glob.glob(os.path.join(EXAMPLES_DIR, "*"))
    if os.path.isfile(path)
)
# gr.Examples expects one list per row (one value per input component);
# only the first MAX_EXAMPLES images are shown.
examples = [[path] for path in example_images[:MAX_EXAMPLES]]
# Function to get a random sample for the "Random Sample" button.
def get_random_sample():
    """Return a randomly chosen example image for the input component.

    Returns:
        A fully loaded ``PIL.Image.Image`` picked from ``example_images``.

    Raises:
        gr.Error: when no example images were found, so the UI shows a clear
            message instead of an ``IndexError`` from ``random.choice``.
    """
    if not example_images:
        raise gr.Error("No example images available.")
    random_image_path = random.choice(example_images)
    # Read the pixels and release the file handle right away; copy() loads
    # the data into an independent in-memory image.
    with Image.open(random_image_path) as img:
        return img.copy()
# Build the Gradio interface (Soft theme; no custom CSS is applied).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚬 Smoker Detection

        Upload an image or try a random sample to detect if a person is smoking.

        This model uses **ResNet34 with LoRA fine-tuning** (only 2.14% of parameters trained)
        and achieves **89.73% test accuracy**.

        **Model:** [notrito/smoker-detection](https://huggingface.co/notrito/smoker-detection)
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            with gr.Row():
                predict_btn = gr.Button("🔍 Predict", variant="primary")
                random_btn = gr.Button("🎲 Random Sample", variant="secondary")
        with gr.Column():
            # One entry per class, so every probability is displayed.
            output_label = gr.Label(num_top_classes=len(class_names), label="Prediction")

    # Clickable examples; cache_examples=True makes Gradio precompute and
    # cache predict's output for each one. Skipped when no example images
    # were found, since there would be nothing to show.
    if examples:
        gr.Markdown("### 📸 Try these examples:")
        gr.Examples(
            examples=examples,
            inputs=input_image,
            outputs=output_label,
            fn=predict,
            cache_examples=True,
        )

    gr.Markdown(
        """
        ---

        ### About this model

        - **Architecture:** ResNet34 + LoRA adapters (rank=8)
        - **Training:** Fine-tuned on 1,120 images
        - **Performance:** 89.73% test accuracy, 89.96% F1-score
        - **Efficiency:** Only 465K trainable parameters (2.14% of model)

        ### How it works

        LoRA (Low-Rank Adaptation) freezes the pretrained ImageNet weights and adds small trainable
        matrices to specific layers. This prevents overfitting on small datasets while maintaining
        the model's powerful feature extraction capabilities.

        ### Limitations

        - Trained on limited dataset (1,120 images)
        - Best for frontal/profile views with visible cigarettes
        - May not generalize to all smoking scenarios

        ### Links

        - [Model Card](https://huggingface.co/notrito/smoker-detection)
        - [Training Notebook](https://www.kaggle.com/code/notrito/smoker-detection-with-lora)

        **Author:** Noel Triguero
        """
    )

    # Event listeners are attached inside the Blocks context.
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_label)
    random_btn.click(fn=get_random_sample, inputs=None, outputs=input_image)


if __name__ == "__main__":
    demo.launch()