hmichaeli commited on
Commit
77185e0
1 Parent(s): c975e16

Add evaluation script

Browse files
Files changed (1) hide show
  1. README.md +51 -0
README.md CHANGED
@@ -27,6 +27,7 @@ git clone https://github.com/hmichaeli/alias_free_convnets.git
27
  ```python
28
  from huggingface_hub import hf_hub_download
29
  import torch
 
30
  from alias_free_convnets.models.convnext_afc import convnext_afc_tiny
31
 
32
  # baseline
@@ -51,6 +52,56 @@ afc_model = convnext_afc_tiny(
51
  )
52
  afc_model.load_state_dict(ckpt, strict=False)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  ```
55
 
56
 
 
27
  ```python
28
  from huggingface_hub import hf_hub_download
29
  import torch
30
+ from torchvision import datasets, transforms
31
  from alias_free_convnets.models.convnext_afc import convnext_afc_tiny
32
 
33
  # baseline
 
52
  )
53
  afc_model.load_state_dict(ckpt, strict=False)
54
 
55
+ # evaluate model
56
+ interpolation = transforms.InterpolationMode.BICUBIC
57
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
58
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
59
+ transform = transforms.Compose([
60
+ transforms.Resize(256, interpolation=interpolation),
61
+ transforms.CenterCrop(224),
62
+ transforms.ToTensor(),
63
+ transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
64
+ ])
65
+ data_path = "/path/to/imagenet/val"
66
+ dataset_val = datasets.ImageFolder(data_path, transform=transform)
67
+ nb_classes = 1000
68
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
69
+ data_loader_val = torch.utils.data.DataLoader(
70
+ dataset_val, sampler=sampler_val,
71
+ batch_size=8,
72
+ num_workers=8,
73
+ drop_last=False
74
+ )
75
+
76
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
+
78
+ @torch.no_grad()
79
+ def evaluate(data_loader, model, device):
80
+ model.eval()
81
+ correct = 0
82
+ total = 0
83
+ for batch_idx, (inputs, targets) in enumerate(data_loader):
84
+ inputs, targets = inputs.to(device), targets.to(device)
85
+ outputs = model(inputs)
86
+ _, predicted = outputs.max(1)
87
+ total += targets.size(0)
88
+ correct += predicted.eq(targets).sum().item()
89
+
90
+ acc = 100. * correct / total
91
+ print("Acc@1 {:.3f}".format(acc))
92
+
93
+
94
+
95
+ print("evaluate baseline")
96
+ base_model.to(device)
97
+ test_stats = evaluate(data_loader_val, base_model, device)
98
+
99
+ print("evaluate AFC")
100
+ afc_model.to(device)
101
+ test_stats = evaluate(data_loader_val, afc_model, device)
102
+
103
+
104
+
105
  ```
106
 
107