
ConvNeXt V2 fine-tuned for camera-angle classification

A base-size ConvNeXt V2 model fine-tuned to classify camera angles. It was fine-tuned on the CineScale dataset for 30 epochs.

The model classifies an image into one of five camera-angle classes: dutch, high, low, neutral, and overhead.

Evaluation

On the test set (test.csv), the model reaches an accuracy of 93.32% and a macro-F1 of 90.01%.
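Accuracy weights every test image equally, while macro-F1 computes one F1 score per class and then averages the five scores with equal weight, so a weaker class pulls macro-F1 below accuracy. The sketch below illustrates the two definitions in plain Python. The helper name and the toy labels are hypothetical and are not drawn from the CineScale test set; scoring a class with no true or predicted samples as F1 = 0 is an assumption (scikit-learn, for example, makes this configurable via `zero_division`).

```python
def accuracy_and_macro_f1(y_true, y_pred, labels):
    """Return (accuracy, macro-F1) for two parallel lists of class labels."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    f1_scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        # F1 = 2TP / (2TP + FP + FN); assumed 0 when the class never appears
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)

    # Macro averaging: every class counts equally, regardless of its frequency
    return accuracy, sum(f1_scores) / len(f1_scores)


LABELS = ["dutch", "high", "low", "neutral", "overhead"]
# Hypothetical toy data: one dutch shot is mistaken for neutral
y_true = ["neutral"] * 6 + ["high", "low", "dutch", "overhead"]
y_pred = ["neutral"] * 6 + ["high", "low", "neutral", "overhead"]
acc, macro_f1 = accuracy_and_macro_f1(y_true, y_pred, LABELS)
print(f"accuracy={acc:.3f} macro-F1={macro_f1:.3f}")
```

With these toy labels accuracy is 0.9 while macro-F1 is about 0.78: the single missed dutch image drives that class's F1 to zero, which costs a fifth of the macro average.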

How to use

```python
import torch
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import v2
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("gullalc/convnextv2-base-22k-224-cinescale-angle")
im_size = 224

## https://www.pexels.com/photo/man-in-black-dress-walking-in-between-brown-wooden-pews-9614069/
# Read a local copy of the image as a uint8 RGB tensor of shape (3, H, W)
image = read_image("demo/angle_demo.jpg", mode=ImageReadMode.RGB)

# Inference preprocessing: shorter side -> 224, center crop to 224x224,
# scale to float32 in [0, 1], then normalize with the ImageNet mean/std
transform = v2.Compose([
    v2.Resize(im_size, antialias=True),
    v2.CenterCrop((im_size, im_size)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

inputs = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values=inputs)

# Index of the highest logit for the single image in the batch
predicted_label = model.config.id2label[outputs.logits.argmax(-1).item()]
print(predicted_label)
# --> high
```
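The snippet above reports only the arg-max label. To also see how confident the model is, the logits can be turned into probabilities with a softmax (`torch.softmax(outputs.logits, dim=-1)` does this directly on the tensor). The following is a minimal plain-Python sketch; `top_k_predictions` is a hypothetical helper, not part of transformers, and it assumes one image's logits as a list of floats plus an `id2label` dict with integer keys, as `model.config.id2label` provides.

```python
import math


def top_k_predictions(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs for one image."""
    # Numerically stable softmax: subtract the largest logit before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort class indices by probability, highest first
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]
```

For the demo image it could be called as `top_k_predictions(outputs.logits[0].tolist(), model.config.id2label)`.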

Training Details

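```python
import torch
from torchvision.transforms import v2

im_size = 224

## Training transforms
# RandomOrder applies every transform in the list, in a freshly shuffled order per call
randomorder = v2.RandomOrder([
    v2.RandomHorizontalFlip(),      # p=0.5 by default
    v2.GaussianBlur(5),             # kernel size 5; sigma sampled from (0.1, 2.0) by default
    v2.RandomAdjustSharpness(2),    # sharpness_factor=2, p=0.5 by default
    v2.RandomGrayscale(p=0.2),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
])

train_transform = v2.Compose([
    v2.Resize(im_size, antialias=True),
    v2.RandomResizedCrop((im_size, im_size), antialias=True),
    randomorder,
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```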
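Both the training and inference pipelines start with `v2.Resize(im_size)`, which, given an int, matches the shorter image edge to 224 pixels and scales the longer edge to keep the aspect ratio. At inference `CenterCrop` then keeps only the central 224x224 square, while during training `RandomResizedCrop` (defaults `scale=(0.08, 1.0)`, `ratio=(3/4, 4/3)`) samples a region of the resized image and resizes it to 224x224. The helper below is a hypothetical sketch of the resize arithmetic only; truncating the long edge with `int()` reflects how torchvision appears to round and should be treated as an assumption.

```python
def resized_size(height, width, size=224):
    """Sketch of the (height, width) produced by v2.Resize(size) for an int size."""
    # Identify the shorter edge; it is scaled to exactly `size`
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long / short)  # long edge truncated (assumption)
    # Map back to (height, width) depending on orientation
    return (new_long, new_short) if width <= height else (new_short, new_long)


print(resized_size(1080, 1920))  # a landscape 1920x1080 frame
```

For a 1920x1080 frame this gives 224x398 (height x width), so the inference center crop discards roughly 44% of the width at the left and right edges.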

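```python
from transformers import TrainingArguments

## Training Arguments
training_args = TrainingArguments(
    output_dir="convnextv2-cinescale-angle",  # hypothetical: the card does not give the output directory
    eval_strategy="epoch",  # spelled `evaluation_strategy` in older transformers releases
    save_strategy="epoch",  # must match eval_strategy when load_best_model_at_end=True
    learning_rate=5e-5,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=4,  # 128 x 4 = 512 images per optimizer step, per device
    per_device_eval_batch_size=128,
    num_train_epochs=30,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # compute_metrics must return an "f1" key
    dataloader_num_workers=32,
    torch_compile=True,
)
```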
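The arguments above imply 128 x 4 = 512 images per optimizer step on each device. The card states neither the number of training images nor the number of devices, so the hypothetical helper below, used with a made-up dataset size, only shows how those figures turn into steps per epoch, total steps, and warmup steps. Warmup is taken as ceil(total_steps x warmup_ratio), which is how recent Trainer versions derive it from `warmup_ratio`; the exact handling of a final partial accumulation step varies between transformers versions, so treat the counts as approximate. Separately, `metric_for_best_model="f1"` means the `compute_metrics` function passed to the Trainer (not shown in the card) must return an `"f1"` entry.

```python
import math


def schedule_summary(num_train_images, num_devices=1, per_device_batch=128,
                     grad_accum=4, epochs=30, warmup_ratio=0.1):
    """Approximate optimizer-step counts implied by the TrainingArguments above."""
    effective_batch = per_device_batch * grad_accum * num_devices
    # Micro-batches per device per epoch (the last partial batch is kept)
    micro_batches = math.ceil(num_train_images / (per_device_batch * num_devices))
    # One optimizer step per `grad_accum` micro-batches (rounding is approximate)
    steps_per_epoch = max(micro_batches // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, steps_per_epoch, total_steps, warmup_steps


print(schedule_summary(20_000))  # hypothetical: 20,000 training images on one device
```

For that hypothetical setup the helper returns `(512, 39, 1170, 117)`: 39 optimizer steps per epoch, 1,170 in total, of which the first 117 are learning-rate warmup.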
Safetensors · Model size: 87.7M params · Tensor type: F32