Dataset: `uoft-cs/cifar10`
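A ResNet18 model adapted to and trained on CIFAR-10. The first convolution (`conv1`) is reconfigured, the initial max-pooling layer is removed, and the classifier (`fc`) is replaced with a new head.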
| Metric | Value |
|---|---|
| Accuracy | 0.8807 |
| Precision | 0.8823 |
| Recall | 0.8807 |
| F1-Score | 0.8806 |
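The card does not state how precision, recall and F1 were averaged over the 10 classes. A recall exactly equal to accuracy is consistent with weighted averaging, or with macro averaging if the evaluation used the class-balanced CIFAR-10 test split (1,000 images per class). The sketch below assumes macro averaging; `macro_metrics` is a hypothetical helper, not part of the model's code.

```python
def macro_metrics(y_true, y_pred, num_classes=10):
    """Return (accuracy, precision, recall, f1), with the last three macro-averaged."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        # Convention (assumption): an empty denominator yields 0.0
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)

    # Unweighted mean over classes = macro average
    n = num_classes
    return accuracy, sum(precisions) / n, sum(recalls) / n, sum(f1s) / n
```

On a balanced toy set such as `macro_metrics([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3)`, macro recall equals accuracy (4/6), mirroring the Accuracy and Recall rows above.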
```python
# In the template, set:
HUGGINGFACE_REPO_ID = "Tudorx95/resnet18-cifar10-pytorch"
MODEL_FILENAME = "ResNet18_CIFAR10.pth"
# The model is then loaded automatically in create_model()
```
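The template's `create_model()` is not shown in this card. Outside the template, one way to obtain the checkpoint file locally is `hf_hub_download` from the `huggingface_hub` library, which downloads the file once and then reuses its local cache. The wrapper name `fetch_checkpoint` is hypothetical.

```python
from huggingface_hub import hf_hub_download

HUGGINGFACE_REPO_ID = "Tudorx95/resnet18-cifar10-pytorch"
MODEL_FILENAME = "ResNet18_CIFAR10.pth"


def fetch_checkpoint(repo_id: str = HUGGINGFACE_REPO_ID,
                     filename: str = MODEL_FILENAME) -> str:
    """Download the checkpoint from the Hub (or reuse the cache); return its local path."""
    return hf_hub_download(repo_id=repo_id, filename=filename)
```

The returned path can then be passed to `torch.load(path, map_location='cpu')` in place of the bare filename.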
To load the model manually:

```python
import torch
import torch.nn as nn
import torchvision

# Load the checkpoint onto the CPU (works without a GPU)
checkpoint = torch.load('ResNet18_CIFAR10.pth', map_location='cpu')

# Rebuild the architecture from the config stored in the checkpoint
arch = checkpoint['architecture']
mod = arch['modifications']
model = torchvision.models.resnet18(weights=None)  # no pretrained ImageNet weights
model.conv1 = nn.Conv2d(3, 64, **mod['conv1'])  # conv1 kwargs come from the checkpoint
model.maxpool = nn.Identity()  # remove the initial max-pooling layer
model.fc = nn.Linear(mod['fc']['in_features'], mod['fc']['out_features'])

# Load the trained weights and switch to inference mode
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```
Class indices: 0: airplane, 1: automobile, 2: bird, 3: cat, 4: deer, 5: dog, 6: frog, 7: horse, 8: ship, 9: truck
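The index-to-label mapping above can be kept as a tuple ordered by class index, so a predicted index converts directly to a name. `CIFAR10_CLASSES` and `class_name` are hypothetical names used for illustration.

```python
# Ordered so that CIFAR10_CLASSES[i] is the label of class index i
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def class_name(index: int) -> str:
    """Map a predicted class index (0-9) to its CIFAR-10 label."""
    if not 0 <= index < len(CIFAR10_CLASSES):
        raise ValueError(f"class index out of range: {index}")
    return CIFAR10_CLASSES[index]
```

For example, `class_name(3)` returns `'cat'`.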