metadata
base_model: microsoft/resnet-18
license: mit
tags:
  - image-classification
  - pytorch
  - computer-vision
  - fall-detection

Fall Detection Model (ResNet-18 Fine-tuned)

This model is ResNet-18 fine-tuned for binary image classification: given an image, it predicts whether the image shows a fall.

Model Details

  • Base Model: microsoft/resnet-18
  • Dataset: hiennguyen9874/fall-detection-dataset
  • Task: Binary image classification (fall/no_fall)
  • Classes:
    • 0: no_fall
    • 1: fall
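In Transformers, class IDs like these are stored in the model config as `id2label` and `label2id` dictionaries. The sketch below writes the documented mapping out in plain Python. It assumes the checkpoint's config uses exactly the strings `no_fall` and `fall`; after loading, you can print `model.config.id2label` to confirm. The helper `label_for` is hypothetical and only illustrates the lookup.

```python
# Class-ID mapping as documented in this card (assumed to match the checkpoint's config)
ID2LABEL = {0: "no_fall", 1: "fall"}

# Inverse mapping: label -> class ID
LABEL2ID = {label: class_id for class_id, label in ID2LABEL.items()}


def label_for(class_id: int) -> str:
    """Hypothetical helper: map a predicted class ID to its human-readable label."""
    if class_id not in ID2LABEL:
        raise ValueError(f"unknown class id: {class_id}")
    return ID2LABEL[class_id]


print(label_for(1))         # fall
print(LABEL2ID["no_fall"])  # 0
```

When fine-tuning from microsoft/resnet-18, dictionaries like these are typically passed as the `id2label` and `label2id` arguments of `from_pretrained`, together with `ignore_mismatched_sizes=True` because the base model's head has 1000 ImageNet classes. This is how such labels usually end up in a saved config.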

How to Use

1. Load the Model and Image Processor

from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

repo_id = "popkek00/fall_detection_model"  # Hugging Face Hub repository ID of this model

model = AutoModelForImageClassification.from_pretrained(repo_id).to(device)
image_processor = AutoImageProcessor.from_pretrained(repo_id)

model.eval() # Set model to evaluation mode

2. Prepare an Image for Inference

# Load your own image, e.g. from a local file (the path below is a placeholder):
# example_image = Image.open("path/to/image.jpg").convert("RGB")
# Image.open also accepts a file-like object such as io.BytesIO (e.g. bytes downloaded from a URL)

# For demonstration, create a solid red dummy image instead
example_image = Image.new("RGB", (224, 224), color="red")

# Process the image
inputs = image_processor(images=example_image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(device)
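After resizing, the image processor rescales 8-bit pixel values to [0, 1] and standardizes each channel with a mean and standard deviation. The sketch below reproduces that per-pixel arithmetic in plain Python. It assumes the processor saved with this checkpoint keeps the ImageNet statistics used by microsoft/resnet-18; inspect `image_processor.image_mean` and `image_processor.image_std` to confirm. `normalize_pixel` is a hypothetical helper for illustration only.

```python
# Per-channel statistics; assumed to match image_processor.image_mean / image_std
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Rescale an 8-bit (R, G, B) pixel to [0, 1], then standardize each channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


# Every pixel of the solid red dummy image is (255, 0, 0)
red = normalize_pixel((255, 0, 0))
print([round(c, 4) for c in red])  # [2.2489, -2.0357, -1.8044]
```

This is why the values in `pixel_values` are not in the 0 to 255 range: a channel at its dataset mean maps to 0, and each unit corresponds to one standard deviation.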

3. Get Predictions

with torch.no_grad():
    outputs = model(pixel_values)

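logits = outputs.logits  # shape: (batch_size, num_labels), i.e. (1, 2) for one image and two classes
probabilities = torch.softmax(logits, dim=1)
predicted_class_id = probabilities.argmax(dim=1).item()  # argmax per image; .item() requires a batch of 1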

# Get the human-readable label from the model's config
predicted_label = model.config.id2label[predicted_class_id]
confidence = probabilities[0, predicted_class_id].item() * 100

print(f"Predicted label: {predicted_label} (Confidence: {confidence:.2f}%)")
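Step 3 turns raw logits into probabilities with a softmax over the class dimension, $p_i = e^{z_i} / \sum_j e^{z_j}$, and then picks the most probable class. The sketch below reproduces that arithmetic in plain Python for a batch of logit rows, mirroring `torch.softmax(logits, dim=1)` followed by a per-row argmax. The logit values are made up for illustration, and the label mapping assumes the `0: no_fall`, `1: fall` convention listed under Model Details. `softmax` and `predict` are hypothetical helpers, not part of the model's API.

```python
import math

ID2LABEL = {0: "no_fall", 1: "fall"}  # assumed to match model.config.id2label


def softmax(row):
    """Numerically stable softmax: subtracting the max leaves the result unchanged."""
    peak = max(row)
    exps = [math.exp(z - peak) for z in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict(batch_logits):
    """Return (label, confidence in %) per row, like softmax(dim=1) + argmax(dim=1)."""
    results = []
    for row in batch_logits:
        probs = softmax(row)
        class_id = max(range(len(probs)), key=probs.__getitem__)
        results.append((ID2LABEL[class_id], probs[class_id] * 100))
    return results


# Made-up logits for two images: one leaning no_fall, one leaning fall
for label, confidence in predict([[2.0, 0.5], [-1.0, 1.0]]):
    print(f"Predicted label: {label} (Confidence: {confidence:.2f}%)")
# Predicted label: no_fall (Confidence: 81.76%)
# Predicted label: fall (Confidence: 88.08%)
```

With two classes, the softmax probability of `fall` equals the logistic sigmoid of the logit difference $z_1 - z_0$. As a result, the reported confidence for the predicted class is never below 50%.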