---
license: apache-2.0
language:
  - en
tags:
  - image-classification
  - walnut
  - defect-detection
  - efficientnet
  - timm
  - pytorch
  - surface-defect
  - quality-control
pipeline_tag: image-classification
library_name: timm
base_model: timm/efficientnet_b3.ra2_in1k
metrics:
  - accuracy
  - f1
model-index:
  - name: Walnut Shell Defect Classifier
    results:
      - task:
          type: image-classification
          name: Image Classification
        dataset:
          name: Nut Surface Defect Dataset (nutsv2ifolder_split)
          type: weihaoreal/nut-surface-defect-dataset
        metrics:
          - type: accuracy
            value: 0.9855
            name: Validation Accuracy
          - type: f1
            value: 0.98
            name: Macro F1
---

Walnut Shell Defect Classifier

EfficientNet-B3 fine-tuned to classify walnut shell defects into 4 categories. It was trained on the Nut Surface Defect Dataset, with the dataset's original classes remapped to a walnut-specific defect taxonomy.

Classes

| Output Label | Remapped From (Dataset) |
|---|---|
| Healthy | Excellent |
| Black Spot | Rusting |
| Shriveled | Scratches |
| Damaged | Deformation + Fracture |
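One way to express this remapping in a training pipeline is a lookup from dataset class name to model output index. This is a minimal sketch, not the author's code: the dataset folder names are assumed to match the names in the table (the actual directory names in `nutsv2ifolder_split` may differ), and `remap_target` is a hypothetical helper. The index order follows the `CLASSES` list used in the inference example.

```python
# Assumed dataset class names (taken from the table; real folder spelling may differ).
DATASET_TO_LABEL = {
    "Excellent": "Healthy",
    "Rusting": "Black Spot",
    "Scratches": "Shriveled",
    "Deformation": "Damaged",
    "Fracture": "Damaged",  # two source classes merge into one output label
}

# Output order used by the inference code in this card.
CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]
LABEL_TO_INDEX = {name: i for i, name in enumerate(CLASSES)}


def remap_target(dataset_class: str) -> int:
    """Map a dataset class name to the 4-way model target index (hypothetical helper)."""
    try:
        return LABEL_TO_INDEX[DATASET_TO_LABEL[dataset_class]]
    except KeyError:
        raise ValueError(f"unknown dataset class: {dataset_class!r}") from None


print(remap_target("Fracture"))  # 3 ("Damaged")
```

Because Deformation and Fracture share index 3, the model cannot distinguish them at inference time. With `torchvision.datasets.ImageFolder`, one option is to invert the dataset's `class_to_idx` and set its `target_transform` so each folder index passes through `remap_target`.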

Metrics (Epoch 8 — Best Checkpoint)

| Class | Precision | Recall | F1 |
|---|---|---|---|
| Healthy | 0.88 | 1.00 | 0.93 |
| Black Spot | 1.00 | 0.99 | 1.00 |
| Shriveled | 1.00 | 0.98 | 0.99 |
| Damaged | 1.00 | 0.98 | 0.99 |
| Macro Avg | 0.97 | 0.99 | 0.98 |
| Weighted Avg | 0.99 | 0.99 | 0.99 |

Val Accuracy: 98.55% | Macro F1: 0.98
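The macro averages are unweighted means over the four classes, so each class counts equally regardless of how many validation images it has. The quick check below uses the rounded per-class values from the metrics table, so it can only agree with the reported figures up to rounding. The weighted averages cannot be reproduced this way because per-class support counts are not reported in this card.

```python
# Per-class validation metrics copied from the table (rounded to 2 decimal places).
PER_CLASS = {
    "Healthy":    {"precision": 0.88, "recall": 1.00, "f1": 0.93},
    "Black Spot": {"precision": 1.00, "recall": 0.99, "f1": 1.00},
    "Shriveled":  {"precision": 1.00, "recall": 0.98, "f1": 0.99},
    "Damaged":    {"precision": 1.00, "recall": 0.98, "f1": 0.99},
}


def f1(precision: float, recall: float) -> float:
    """F1 is the harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def macro_average(metric: str) -> float:
    """Unweighted mean over classes: every class counts equally."""
    values = [scores[metric] for scores in PER_CLASS.values()]
    return sum(values) / len(values)


for metric in ("precision", "recall", "f1"):
    print(f"macro {metric}: {macro_average(metric):.4f}")

# Recomputing Healthy's F1 from its rounded P/R gives ~0.936 instead of 0.93;
# the table's F1 was presumably computed from unrounded precision and recall.
print(f"Healthy F1 from rounded P/R: {f1(0.88, 1.00):.3f}")
```

The macro F1 of these rounded values is 0.9775, which rounds to the reported 0.98.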

Training Setup

| Parameter | Value |
|---|---|
| Base Model | EfficientNet-B3, ImageNet-pretrained (`timm/efficientnet_b3.ra2_in1k`) |
| Image Size | 512×512 px |
| Batch Size | 18 per GPU × 2 T4 GPUs = 36 effective |
| Optimizer | AdamW (lr=2e-5, weight decay=1e-2) |
| Scheduler | Cosine annealing with 3-epoch warmup |
| Precision | FP16 mixed precision (`torch.cuda.amp`) |
| Drop Rate | 0.4 |
| Label Smoothing | 0.05 |
| Early Stopping Patience | 7 epochs |
| Hardware | Kaggle, 2× NVIDIA T4 (16 GB each) |
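The settings above can be wired together roughly as follows. This is a minimal PyTorch sketch, not the author's training script. The total epoch budget (`EPOCHS = 30`), the linear warmup shape, stepping the scheduler once per epoch, the metric monitored for early stopping, and the tiny stand-in model are all assumptions; in practice the model would be the timm EfficientNet-B3. Multi-GPU wrapping (DataParallel or DDP, which produces the `module.` key prefix handled in the inference code) is omitted.

```python
import math

import torch
from torch import nn

# From the Training Setup table, except EPOCHS, which is a hypothetical budget.
EPOCHS = 30
WARMUP_EPOCHS = 3
PATIENCE = 7

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # FP16 autocast/GradScaler only enabled on GPU


def lr_factor(epoch: int) -> float:
    """Linear warmup over WARMUP_EPOCHS, then cosine decay towards 0."""
    if epoch < WARMUP_EPOCHS:
        return (epoch + 1) / WARMUP_EPOCHS
    progress = (epoch - WARMUP_EPOCHS) / max(1, EPOCHS - WARMUP_EPOCHS)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


# Stand-in model (assumption); the real one would be
# timm.create_model("efficientnet_b3", pretrained=True, num_classes=4, drop_rate=0.4)
model = nn.Sequential(nn.Flatten(), nn.Dropout(0.4), nn.Linear(3 * 8 * 8, 4)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(images: torch.Tensor, targets: torch.Tensor) -> float:
    """One optimizer step with FP16 autocast and loss scaling (no-ops on CPU)."""
    model.train()
    images, targets = images.to(device), targets.to(device)
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


class EarlyStopping:
    """Signal a stop once a higher-is-better score fails to improve for `patience` epochs."""

    def __init__(self, patience: int = PATIENCE):
        self.patience = patience
        self.best = -math.inf
        self.bad_epochs = 0

    def step(self, score: float) -> bool:
        if score > self.best:
            self.best, self.bad_epochs = score, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True means stop training
```

In an epoch loop, one would call `train_step` for each batch, then `scheduler.step()` once, then `early_stop.step(val_score)` and break when it returns `True`, saving a checkpoint whenever `best` improves. Which validation score the author monitored is not stated in this card.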

Inference

```python
import torch
import timm
from PIL import Image
import torchvision.transforms as transforms

CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]

# Rebuild the training architecture; the weights come from the checkpoint.
model = timm.create_model("efficientnet_b3", pretrained=False,
                          num_classes=4, drop_rate=0.4)
ckpt = torch.load("best_model.pth", map_location="cpu")
# Multi-GPU (DataParallel/DDP) checkpoints prefix every key with "module.".
state = {k.removeprefix("module."): v for k, v in ckpt["model_state_dict"].items()}
model.load_state_dict(state)
model.eval()

# 512x512 resize plus ImageNet mean/std normalization.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("walnut.jpg").convert("RGB")
x = transform(img).unsqueeze(0)                  # shape (1, 3, 512, 512)
with torch.no_grad():
    probs = torch.softmax(model(x), dim=1)[0]    # shape (4,): one image's class probabilities
conf, idx = probs.max(dim=0)
print({"defect_class": CLASSES[idx.item()], "confidence": round(conf.item(), 4)})
```
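To score several walnuts at once, the same model can be run on a stacked batch and the top-k classes reported per image. This is a minimal sketch; `predict_batch` and its `top_k` parameter are hypothetical names, not part of timm or this repository, and the batch is assumed to be preprocessed with the `transform` from the inference example.

```python
import torch
from torch import nn

CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]


@torch.no_grad()
def predict_batch(model: nn.Module, batch: torch.Tensor, top_k: int = 2):
    """Return the top-k (label, probability) pairs for each image in `batch`.

    `batch` is a float tensor of shape (N, 3, H, W) that has already been
    resized and normalized like the single-image example.
    """
    model.eval()
    probs = torch.softmax(model(batch), dim=1)   # (N, num_classes)
    top_p, top_i = probs.topk(top_k, dim=1)      # both (N, top_k), sorted descending
    return [
        [(CLASSES[i], round(p, 4)) for p, i in zip(row_p.tolist(), row_i.tolist())]
        for row_p, row_i in zip(top_p, top_i)
    ]


# Usage with the model and transform from the inference example (file names are hypothetical):
# batch = torch.stack([transform(Image.open(p).convert("RGB")) for p in ["a.jpg", "b.jpg"]])
# print(predict_batch(model, batch))
```

Reporting the second-ranked class alongside the first can help flag borderline shells for manual review.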

License

Apache 2.0