damage-classifier-multi-task

Home Damage Classification Model

This model was trained to classify damage to household items, identifying the item category, the damage type, and the damage severity.

Model Type

  • Architecture: Vision Transformer (ViT)
  • Classification Approach: Multi-task
  • Base Model: google/vit-base-patch16-224
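The base checkpoint's name encodes its input geometry: 224×224 RGB images cut into non-overlapping 16×16 patches, embedded at the ViT-Base width of 768. The short calculation below derives the resulting token sequence length. It describes the base checkpoint only; whether this fine-tuned model reads the [CLS] token or a pooled output is not stated here.

```python
# Input geometry implied by the base checkpoint "google/vit-base-patch16-224".
IMAGE_SIZE = 224   # input images are resized to 224x224
PATCH_SIZE = 16    # each patch covers 16x16 pixels
HIDDEN_SIZE = 768  # embedding width of ViT-Base

patches_per_side = IMAGE_SIZE // PATCH_SIZE  # 224 / 16 = 14
num_patches = patches_per_side ** 2          # 14 * 14 = 196
seq_len = num_patches + 1                    # one extra [CLS] token -> 197

print(f"{num_patches} patches, sequence length {seq_len}, hidden size {HIDDEN_SIZE}")
# prints: 196 patches, sequence length 197, hidden size 768
```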

Supported Categories

Item Categories

  • Microwave
  • Wall
  • Window
  • Fence
  • Glass
  • Fishbowl

Damage Types

  • Scratch
  • Dent
  • Break
  • Burn
  • Water Damage

Severity Levels

  • No Damage
  • Minor Damage
  • Moderate Damage
  • Severe Damage
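The three label sets above can be encoded as index-to-label maps, one per task. The sketch below assumes the index order matches the lists in the usage example; that order should be checked against the training code. It also counts the label space. A single flat classifier would need one class per item/damage/severity combination, while three separate heads need only one output per label.

```python
from itertools import product

# Label order is an assumption; verify it against the training configuration.
ITEM_CATEGORIES = ["microwave", "wall", "window", "fence", "glass", "fishbowl"]
DAMAGE_TYPES = ["scratch", "dent", "break", "burn", "water_damage"]
SEVERITY_LEVELS = ["no_damage", "minor_damage", "moderate_damage", "severe_damage"]


def label_maps(labels):
    """Return (id2label, label2id) dictionaries for one task (hypothetical helper)."""
    id2label = dict(enumerate(labels))
    label2id = {name: idx for idx, name in id2label.items()}
    return id2label, label2id


item_id2label, item_label2id = label_maps(ITEM_CATEGORIES)
damage_id2label, damage_label2id = label_maps(DAMAGE_TYPES)
severity_id2label, severity_label2id = label_maps(SEVERITY_LEVELS)

# Flat classification: one class per combination (6 * 5 * 4).
num_combinations = len(list(product(ITEM_CATEGORIES, DAMAGE_TYPES, SEVERITY_LEVELS)))
# Multi-task classification: one output per label per head (6 + 5 + 4).
num_head_outputs = len(ITEM_CATEGORIES) + len(DAMAGE_TYPES) + len(SEVERITY_LEVELS)

print(num_combinations, num_head_outputs)  # prints: 120 15
```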

Multi-Task Architecture

This model uses a multi-task learning approach with:

  1. A shared Vision Transformer (ViT) backbone that extracts features from the input image
  2. Separate classification heads for:
    • Item category identification
    • Damage type classification
    • Damage severity assessment

This approach allows the model to share knowledge between related tasks while making separate predictions for each aspect.
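The structure described above can be sketched in PyTorch as one shared backbone feeding three linear heads. This is an illustrative sketch, not this model's actual code. The class name, the use of single `nn.Linear` heads, and the choice of the [CLS] token as the shared feature are all assumptions. The output keys match the usage example below. `DummyBackbone` is a hypothetical stand-in so the example runs without downloading weights.

```python
from types import SimpleNamespace

import torch
from torch import nn


class MultiTaskDamageClassifier(nn.Module):
    """Shared backbone with three classification heads (illustrative sketch)."""

    def __init__(self, backbone: nn.Module, hidden_size: int = 768,
                 num_items: int = 6, num_damage_types: int = 5, num_severities: int = 4):
        super().__init__()
        self.backbone = backbone  # any module returning an object with .last_hidden_state
        self.item_head = nn.Linear(hidden_size, num_items)
        self.damage_type_head = nn.Linear(hidden_size, num_damage_types)
        self.severity_head = nn.Linear(hidden_size, num_severities)

    def forward(self, pixel_values: torch.Tensor) -> dict:
        hidden = self.backbone(pixel_values=pixel_values).last_hidden_state
        features = hidden[:, 0]  # [CLS] token embedding, shape (batch, hidden_size)
        return {
            "item_logits": self.item_head(features),
            "damage_type_logits": self.damage_type_head(features),
            "severity_logits": self.severity_head(features),
        }


class DummyBackbone(nn.Module):
    """Hypothetical stand-in for a ViT: returns a (batch, 197, hidden) token sequence."""

    def __init__(self, hidden_size: int = 768, seq_len: int = 197):
        super().__init__()
        self.seq_len = seq_len
        self.proj = nn.Linear(3, hidden_size)

    def forward(self, pixel_values: torch.Tensor):
        pooled = pixel_values.mean(dim=(2, 3))   # (batch, 3): average over height and width
        tokens = self.proj(pooled).unsqueeze(1)  # (batch, 1, hidden)
        return SimpleNamespace(last_hidden_state=tokens.expand(-1, self.seq_len, -1))


model = MultiTaskDamageClassifier(DummyBackbone())
outputs = model(torch.randn(2, 3, 224, 224))
print({key: tuple(value.shape) for key, value in outputs.items()})
# prints: {'item_logits': (2, 6), 'damage_type_logits': (2, 5), 'severity_logits': (2, 4)}
```

With the `transformers` library, `ViTModel.from_pretrained("google/vit-base-patch16-224")` could take the place of `DummyBackbone`. Its forward pass accepts `pixel_values` and returns `last_hidden_state`.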

Advantages of Multi-Task Learning

  • Shares knowledge across related tasks
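  • Does not require training examples for every item/damage/severity combination
  • Can still produce predictions for combinations that were rare or absent in training, because each head predicts its own label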
  • Independent predictions for each aspect
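The training objective is not documented here. A common choice for this kind of setup is a weighted sum of per-head cross-entropy losses. The sketch below assumes that objective with equal weights; `multi_task_loss` and the label dictionary keys are hypothetical. Each image supervises all three heads through its own per-task labels, so no "combination" class is ever formed. In a model with a shared backbone, gradients from all three terms flow into that backbone, which is how the tasks share knowledge.

```python
import torch
import torch.nn.functional as F

# Output keys follow the usage example; label keys are hypothetical.
TASKS = [
    ("item_logits", "item"),
    ("damage_type_logits", "damage_type"),
    ("severity_logits", "severity"),
]


def multi_task_loss(outputs: dict, labels: dict, weights=(1.0, 1.0, 1.0)) -> torch.Tensor:
    """Weighted sum of per-head cross-entropy losses; equal weights are an assumption."""
    return sum(
        w * F.cross_entropy(outputs[logit_key], labels[label_key])
        for w, (logit_key, label_key) in zip(weights, TASKS)
    )


# Toy batch of 4 images, each with one label per task.
torch.manual_seed(0)
outputs = {
    "item_logits": torch.randn(4, 6, requires_grad=True),
    "damage_type_logits": torch.randn(4, 5, requires_grad=True),
    "severity_logits": torch.randn(4, 4, requires_grad=True),
}
labels = {
    "item": torch.tensor([0, 1, 2, 3]),
    "damage_type": torch.tensor([0, 4, 2, 1]),
    "severity": torch.tensor([3, 0, 1, 2]),
}
loss = multi_task_loss(outputs, labels)
loss.backward()  # with a real model, all three terms backpropagate into the shared backbone
print(f"total loss: {loss.item():.4f}")
```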

Usage

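from transformers import ViTImageProcessor
from PIL import Image
import torch

# Load model and image processor.
# Assumes pytorch_model.bin stores the whole pickled model, so the model's class
# definition must be importable. If the file holds only a state_dict, instantiate
# the model class from the training code and call model.load_state_dict(...) instead.
model = torch.load("pytorch_model.bin", map_location="cpu", weights_only=False)
model.eval()  # disable dropout for inference
image_processor = ViTImageProcessor.from_pretrained("USER/REPO_NAME")

# Prepare image
image = Image.open("path/to/image.jpg").convert("RGB")
inputs = image_processor(images=image, return_tensors="pt")

# Get predictions (no gradient tracking needed at inference time)
with torch.no_grad():
    outputs = model(**inputs)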

# Process multi-task outputs
item_logits = outputs['item_logits']
damage_logits = outputs['damage_type_logits']
severity_logits = outputs['severity_logits']

# Get predicted classes
item_class = torch.argmax(item_logits, dim=1).item()
damage_class = torch.argmax(damage_logits, dim=1).item()
severity_class = torch.argmax(severity_logits, dim=1).item()

# Map to class names (assumed to follow the label order used in training; verify before use)
item_categories = ["microwave", "wall", "window", "fence", "glass", "fishbowl"]
damage_types = ["scratch", "dent", "break", "burn", "water_damage"]
severity_levels = ["no_damage", "minor_damage", "moderate_damage", "severe_damage"]

print(f"Item: {item_categories[item_class]}")
print(f"Damage Type: {damage_types[damage_class]}")
print(f"Severity: {severity_levels[severity_class]}")

For a more complete example, see the inference script in the GitHub repository.
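The argmax step above returns only class indices. The sketch below also reports each head's softmax probability and handles batches of any size. `decode_predictions` and `HEADS` are hypothetical names, and the label order is the same assumption as in the usage example. Softmax probabilities are a rough confidence signal and are not guaranteed to be calibrated.

```python
import torch

# Label order assumed to match training (same lists as the usage example).
ITEM_CATEGORIES = ["microwave", "wall", "window", "fence", "glass", "fishbowl"]
DAMAGE_TYPES = ["scratch", "dent", "break", "burn", "water_damage"]
SEVERITY_LEVELS = ["no_damage", "minor_damage", "moderate_damage", "severe_damage"]

# Hypothetical mapping: task name -> (output key, label list).
HEADS = {
    "item": ("item_logits", ITEM_CATEGORIES),
    "damage_type": ("damage_type_logits", DAMAGE_TYPES),
    "severity": ("severity_logits", SEVERITY_LEVELS),
}


def decode_predictions(outputs: dict) -> list:
    """Return, for each image in the batch, the top label and its softmax probability per head."""
    batch_size = outputs["item_logits"].shape[0]
    results = [{} for _ in range(batch_size)]
    for task, (key, names) in HEADS.items():
        probs = torch.softmax(outputs[key], dim=-1)
        confidence, index = probs.max(dim=-1)  # top probability and its class index
        for i in range(batch_size):
            results[i][task] = (names[index[i].item()], confidence[i].item())
    return results


# Example with hand-made logits for a single image.
example_outputs = {
    "item_logits": torch.tensor([[0.0, 5.0, 0.0, 0.0, 0.0, 0.0]]),
    "damage_type_logits": torch.tensor([[0.0, 0.0, 0.0, 0.0, 3.0]]),
    "severity_logits": torch.tensor([[2.0, 0.0, 0.0, 0.0]]),
}
results = decode_predictions(example_outputs)
print(results[0]["item"][0], results[0]["damage_type"][0], results[0]["severity"][0])
# prints: wall water_damage no_damage
```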
