ResNet18 for CIFAR-10

ResNet18 model adapted to 32x32 CIFAR-10 images (modified conv1, no stem max-pooling) and trained on CIFAR-10.

Performance

Metric      Value
Accuracy    0.8807
Precision   0.8823
Recall      0.8807
F1-Score    0.8806
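The card does not say how precision, recall and F1 were averaged over the 10 classes. Recall equals accuracy in the table. That is consistent with macro averaging on a class-balanced test set, and weighted averaging would also produce it.

The sketch below assumes macro averaging, in the style of scikit-learn's `average="macro"`. It computes all four metrics from integer labels. The function name `classification_metrics` is hypothetical.

```python
from collections import Counter


def classification_metrics(y_true, y_pred, num_classes=10):
    """Accuracy plus macro-averaged precision, recall and F1.

    Macro averaging is an ASSUMPTION: the model card does not state it.
    """
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")

    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1          # correct prediction for class t
        else:
            fp[p] += 1          # predicted p, but it was not p
            fn[t] += 1          # class t was missed

    accuracy = sum(tp.values()) / len(y_true)
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        prec = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        rec = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)

    # Macro average: unweighted mean over classes
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,
    }
```

With two balanced classes, `classification_metrics([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2)` gives accuracy 0.75 and macro recall 0.75. This mirrors the Accuracy = Recall pattern in the table.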

Usage with template_antrenare_pytorch.py

# In the template, set:
HUGGINGFACE_REPO_ID = "Tudorx95/resnet18-cifar10-pytorch"
MODEL_FILENAME = "ResNet18_CIFAR10.pth"

# The model is then loaded automatically in create_model()

Direct usage

import torch
import torchvision
import torch.nn as nn

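# Load the checkpoint (a dict holding the weights and the architecture config)
checkpoint = torch.load('ResNet18_CIFAR10.pth', map_location='cpu')

# Rebuild the architecture from the stored config
arch = checkpoint['architecture']
model = torchvision.models.resnet18(weights=None)
mod = arch['modifications']
# Replace the ImageNet stem convolution with the conv1 settings saved in the checkpoint
model.conv1 = nn.Conv2d(3, 64, **mod['conv1'])
# Drop the stem max-pooling so 32x32 inputs are not downsampled too early
model.maxpool = nn.Identity()
# Classification head sized from the stored config
model.fc = nn.Linear(mod['fc']['in_features'], mod['fc']['out_features'])

# Load the weights and switch to inference mode
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()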

CIFAR-10 classes

0: airplane, 1: automobile, 2: bird, 3: cat, 4: deer, 5: dog, 6: frog, 7: horse, 8: ship, 9: truck
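The index-to-name mapping above can be applied to the model's raw outputs. This is a minimal sketch, assuming `logits` has shape `(batch, 10)` as produced by `model(x)`. The names `CIFAR10_CLASSES` and `logits_to_labels` are hypothetical.

```python
import torch

# Index order taken from the class list above
CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")


def logits_to_labels(logits: torch.Tensor):
    """Map a (batch, 10) logits tensor to (class_name, probability) pairs."""
    probs = torch.softmax(logits, dim=1)   # per-image class probabilities
    conf, idx = probs.max(dim=1)           # most likely class per image
    return [(CIFAR10_CLASSES[i], c) for i, c in zip(idx.tolist(), conf.tolist())]
```

For example, a row whose largest logit sits at index 3 is reported as `"cat"`.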

Training

  • Epochs: 10
  • Batch Size: 128
  • Learning Rate: 0.001
  • Optimizer: Adam
  • Scheduler: StepLR (step_size=5, gamma=0.5)
  • Augmentation: RandomCrop, RandomHorizontalFlip

Dataset used to train Tudorx95/resnet18-cifar10-pytorch