damage-classifier-multi-task
Home Damage Classification Model
This model was trained to classify damage to household items, identifying the item category, the damage type, and the damage severity.
Model Type
- Architecture: Vision Transformer (ViT)
- Classification Approach: Multi-task
- Base Model: google/vit-base-patch16-224
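The base model's name encodes its input geometry: 224×224 pixel images cut into 16×16 pixel patches. The arithmetic below is a quick sanity check of what the shared backbone sees. It assumes the fine-tuned model keeps the base model's input size and the standard ViT-Base hidden size of 768.

```python
# Input geometry of google/vit-base-patch16-224 (assumed unchanged by fine-tuning)
image_size = 224   # pixels per image side
patch_size = 16    # pixels per patch side
hidden_size = 768  # ViT-Base embedding width

patches_per_side = image_size // patch_size  # 14 patches along each side
num_patches = patches_per_side ** 2          # 196 patch tokens
sequence_length = num_patches + 1            # plus one [CLS] token

print(f"{num_patches} patches, sequence length {sequence_length}, "
      f"each token a {hidden_size}-dim vector")
```

The classification heads typically read a single pooled vector from this sequence, often the [CLS] token. The card does not say which pooling the author used.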
Supported Categories
Item Categories
- Microwave
- Wall
- Window
- Fence
- Glass
- Fishbowl
Damage Types
- Scratch
- Dent
- Break
- Burn
- Water Damage
Severity Levels
- No Damage
- Minor Damage
- Moderate Damage
- Severe Damage
Multi-Task Architecture
This model uses a multi-task learning approach with:
- A shared Vision Transformer (ViT) backbone that extracts features from the input image
- Separate classification heads for:
  - Item category identification
  - Damage type classification
  - Damage severity assessment
This approach allows the model to share knowledge between related tasks while making separate predictions for each aspect.
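The card does not include the model class itself. The following is a hypothetical PyTorch sketch of the structure described above, not the author's code. The class name `MultiTaskClassifier`, the head attribute names, and the pluggable `backbone` (assumed to return one feature vector per image) are all assumptions. The output keys mirror the ones read in the Usage section.

```python
import torch
from torch import nn


class MultiTaskClassifier(nn.Module):
    """Hypothetical sketch: one shared backbone feeding three task-specific heads."""

    def __init__(self, backbone: nn.Module, hidden_size: int,
                 num_items: int = 6, num_damage_types: int = 5,
                 num_severities: int = 4):
        super().__init__()
        # Shared feature extractor. In the real model this would be the ViT,
        # reduced to one vector per image (e.g. the [CLS] token embedding).
        self.backbone = backbone
        # One linear head per task, all reading the same shared features.
        self.item_head = nn.Linear(hidden_size, num_items)
        self.damage_type_head = nn.Linear(hidden_size, num_damage_types)
        self.severity_head = nn.Linear(hidden_size, num_severities)

    def forward(self, pixel_values: torch.Tensor) -> dict:
        features = self.backbone(pixel_values)  # (batch, hidden_size)
        return {
            "item_logits": self.item_head(features),
            "damage_type_logits": self.damage_type_head(features),
            "severity_logits": self.severity_head(features),
        }


# Toy usage with a stand-in backbone (not a ViT) so the sketch runs offline.
toy_backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64))
model = MultiTaskClassifier(toy_backbone, hidden_size=64)
outputs = model(torch.randn(2, 3, 32, 32))
for name, logits in outputs.items():
    print(name, tuple(logits.shape))
```

Because all three heads read the same `features`, gradients from every task update the shared backbone during training. Each head still produces its own logits.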
Advantages of Multi-Task Learning
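- Shares learned features across related tasks through the common backbone
- Needs fewer training examples for each item/damage/severity combination, since each head learns from every image labeled for its own task
- Can still make predictions for combinations that never appeared together in the training data
- Produces an independent prediction for each aspect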
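The card does not state how the three heads are trained. A common choice is to sum one cross-entropy loss per head, shown here as an assumption rather than the author's recipe. The function name `multitask_loss` and the optional per-task `weights` are hypothetical. Each term compares one head only against its own label, so the model never has to see a particular item/damage/severity combination to learn each part of it.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def multitask_loss(outputs: dict, labels: dict,
                   weights: Optional[dict] = None) -> torch.Tensor:
    """Weighted sum of per-task cross-entropy losses (hypothetical objective)."""
    weights = weights or {}
    total = torch.zeros(())
    for task in ("item", "damage_type", "severity"):
        # Each head is scored only against its own label.
        task_loss = F.cross_entropy(outputs[f"{task}_logits"], labels[task])
        total = total + weights.get(task, 1.0) * task_loss
    return total


# Stand-in logits and labels for a batch of 4 images.
torch.manual_seed(0)
outputs = {
    "item_logits": torch.randn(4, 6),
    "damage_type_logits": torch.randn(4, 5),
    "severity_logits": torch.randn(4, 4),
}
labels = {
    "item": torch.tensor([0, 1, 2, 3]),
    "damage_type": torch.tensor([0, 4, 2, 1]),
    "severity": torch.tensor([3, 0, 1, 2]),
}
loss = multitask_loss(outputs, labels)
print(loss.item())
```

Equal weights are the simplest default. Raising one task's weight, e.g. `{"severity": 2.0}`, makes the shared backbone favor features useful for that task.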
Usage
```python
from transformers import ViTImageProcessor
from PIL import Image
import torch

# Load model and image processor.
# torch.load returns a full model only if pytorch_model.bin was saved with
# torch.save(model); if it holds a state_dict, instantiate the multi-task model
# class first and call model.load_state_dict(...) instead.
# weights_only=False is required on PyTorch >= 2.6 to unpickle a full model;
# only load files you trust.
model = torch.load("pytorch_model.bin", weights_only=False)  # Or use your preferred loading method
model.eval()  # disable dropout for inference
image_processor = ViTImageProcessor.from_pretrained("USER/REPO_NAME")

# Prepare image
image = Image.open("path/to/image.jpg").convert("RGB")
inputs = image_processor(images=image, return_tensors="pt")

# Get predictions (no gradients needed at inference time)
with torch.no_grad():
    outputs = model(**inputs)

# Process multi-task outputs
item_logits = outputs["item_logits"]
damage_logits = outputs["damage_type_logits"]
severity_logits = outputs["severity_logits"]

# Get predicted classes (assumes a batch of one image)
item_class = torch.argmax(item_logits, dim=1).item()
damage_class = torch.argmax(damage_logits, dim=1).item()
severity_class = torch.argmax(severity_logits, dim=1).item()

# Map to class names (replace with your class mappings;
# the order must match the label indices used in training)
item_categories = ["microwave", "wall", "window", "fence", "glass", "fishbowl"]
damage_types = ["scratch", "dent", "break", "burn", "water_damage"]
severity_levels = ["no_damage", "minor_damage", "moderate_damage", "severe_damage"]

print(f"Item: {item_categories[item_class]}")
print(f"Damage Type: {damage_types[damage_class]}")
print(f"Severity: {severity_levels[severity_class]}")
```
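Argmax alone discards how confident each head is. One possible extension applies softmax to each head's logits and reports the top label with its probability, using the same output keys and label lists as the usage example. The helper name `decode_predictions` and the `label_names` dictionary layout are hypothetical.

```python
import torch


def decode_predictions(outputs: dict, label_names: dict) -> dict:
    """Return {task: (label, probability)} for the first image in the batch."""
    results = {}
    for task, names in label_names.items():
        # Softmax turns each head's logits into a probability distribution.
        probs = torch.softmax(outputs[f"{task}_logits"], dim=1)[0]  # (num_classes,)
        confidence, index = probs.max(dim=0)
        results[task] = (names[index.item()], confidence.item())
    return results


label_names = {
    "item": ["microwave", "wall", "window", "fence", "glass", "fishbowl"],
    "damage_type": ["scratch", "dent", "break", "burn", "water_damage"],
    "severity": ["no_damage", "minor_damage", "moderate_damage", "severe_damage"],
}

# Stand-in logits for one image; in practice pass the model's outputs instead.
example_outputs = {
    "item_logits": torch.tensor([[0.1, 2.0, 0.3, 0.0, 0.2, -1.0]]),
    "damage_type_logits": torch.tensor([[0.5, 0.1, 3.0, 0.0, 0.2]]),
    "severity_logits": torch.tensor([[0.0, 0.2, 1.5, 0.4]]),
}
for task, (label, confidence) in decode_predictions(example_outputs, label_names).items():
    print(f"{task}: {label} ({confidence:.1%})")
```

Since the heads are independent, each probability is conditioned only on the image, not on the other heads' predictions.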
For a more complete example, see the inference script in the GitHub repository.