---
license: mit
datasets:
- imagenet-1k
language:
- en
metrics:
- accuracy
pipeline_tag: image-classification
---

[Alias-Free Convnets: Fractional Shift Invariance via Polynomial Activations](https://hmichaeli.github.io/alias_free_convnets/)

Official pretrained PyTorch model weights for this paper.

This repository provides two ConvNeXt-Tiny variants:

- `convnext_tiny_baseline`: ConvNeXt-Tiny with circular-padded convolutions.
- `convnext_tiny_afc`: the full ConvNeXt-Tiny-AFC model, which is invariant to circular shifts.
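
To illustrate what invariance to circular shifts means, and why circular padding alone does not guarantee it, here is a minimal sketch on a toy network (hypothetical, not the ConvNeXt architecture). A stride-1 convolution with `padding_mode="circular"` commutes with `torch.roll`, so global average pooling makes the output invariant to integer circular shifts. Adding a stride-2 convolution typically breaks this through aliasing, which is the kind of effect the AFC model is designed to remove. The shift amount and layer sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

shift = (5, -3)  # arbitrary integer circular shift in (height, width)


def roll(t):
    # Circular shift of an NCHW tensor along its spatial dimensions
    return torch.roll(t, shifts=shift, dims=(2, 3))


# Toy model: circular-padded stride-1 conv -> ReLU -> global average pooling -> linear head
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")
toy = nn.Sequential(conv, nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)).eval()

# Same structure, but the circular-padded conv downsamples with stride 2
strided = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, padding_mode="circular"),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    # Equivariance: conv(roll(x)) matches roll(conv(x)) up to float error
    equivariant = torch.allclose(conv(roll(x)), roll(conv(x)), atol=1e-5)
    # Invariance: pooled logits do not change under the shift
    invariant = torch.allclose(toy(roll(x)), toy(x), atol=1e-5)
    # With stride 2, an odd shift moves the sampling grid, so logits generally change
    strided_invariant = torch.allclose(strided(roll(x)), strided(x), atol=1e-5)

print(equivariant, invariant)  # True True
print(strided_invariant)  # typically False
```

The strided case is the toy analogue of downsampling without proper low-pass filtering; the paper addresses this (and aliasing from nonlinearities) inside the full network.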

For more details, see the [paper](https://arxiv.org/abs/2303.08085) or the official [implementation](https://github.com/hmichaeli/alias_free_convnets).

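The example below imports the model definition from the official implementation, so clone it first and run the example from the directory that contains the clone (so that `alias_free_convnets` is importable):

```bash
git clone https://github.com/hmichaeli/alias_free_convnets.git
```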

```python
from huggingface_hub import hf_hub_download
import torch
from torchvision import datasets, transforms
from alias_free_convnets.models.convnext_afc import convnext_afc_tiny

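# Baseline: ConvNeXt-Tiny with circular-padded convolutions
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_basline.pth")
ckpt = torch.load(path, map_location="cpu")  # load on CPU; the model is moved to `device` later
base_model = convnext_afc_tiny(pretrained=False, num_classes=1000)
base_model.load_state_dict(ckpt, strict=True)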

# AFC: ConvNeXt-Tiny-AFC (the constructor arguments must match the checkpoint)
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_afc.pth")
ckpt = torch.load(path, map_location="cpu")
afc_model = convnext_afc_tiny(
    pretrained=False,
    num_classes=1000,
    activation='up_poly_per_channel',
    activation_kwargs={'in_scale': 7, 'out_scale': 7, 'train_scale': True},
    blurpool_kwargs={"filter_type": "ideal", "scale_l2": False},
    normalization_type='CHW2',
    normalization_kwargs={},
    stem_mode='activation_residual',
    stem_activation='lpf_poly_per_channel',
    stem_activation_kwargs={"in_scale": 7, "out_scale": 7, "train_scale": True, "cutoff": 0.75},
)
# strict=False: missing or unexpected keys are skipped instead of raising an error
afc_model.load_state_dict(ckpt, strict=False)

# Evaluation data: resize to 256, center-crop to 224, normalize with ImageNet mean/std
interpolation = transforms.InterpolationMode.BICUBIC
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize(256, interpolation=interpolation),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
data_path = "/path/to/imagenet/val"  # ImageNet validation set, one subfolder per class
dataset_val = datasets.ImageFolder(data_path, transform=transform)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = torch.utils.data.DataLoader(
    dataset_val,
    sampler=sampler_val,
    batch_size=8,
    num_workers=8,
    drop_last=False,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def evaluate(data_loader, model, device):
    """Print and return the top-1 accuracy (%) of `model` on `data_loader`."""
    model.eval()
    correct = 0
    total = 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    acc = 100.0 * correct / total
    print("Acc@1 {:.3f}".format(acc))
    return acc

print("evaluate baseline")
base_model.to(device)
base_acc = evaluate(data_loader_val, base_model, device)

print("evaluate AFC")
afc_model.to(device)
afc_acc = evaluate(data_loader_val, afc_model, device)
```
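
Top-1 accuracy does not measure shift invariance itself. One way to probe it is to compare each image's top-1 prediction before and after a random integer circular shift. The sketch below is a hypothetical helper: `shift_consistency` is not part of the official repository, `max_shift=32` is an arbitrary choice, and it is a simplified consistency check that uses integer shifts only, whereas the paper also targets fractional shifts.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def shift_consistency(data_loader, model, device, max_shift=32, seed=0):
    """Percentage of images whose top-1 prediction is unchanged by a random
    integer circular shift (hypothetical helper, not part of the repository)."""
    model.eval()
    gen = torch.Generator().manual_seed(seed)  # reproducible shifts
    consistent = 0
    total = 0
    for inputs, _ in data_loader:
        inputs = inputs.to(device)
        # One random (dy, dx) shift per image, each in [-max_shift, max_shift]
        shifts = torch.randint(-max_shift, max_shift + 1, (inputs.size(0), 2), generator=gen)
        shifted = torch.stack([
            torch.roll(img, shifts=(dy, dx), dims=(1, 2))  # img is CHW
            for img, (dy, dx) in zip(inputs, shifts.tolist())
        ])
        pred = model(inputs).argmax(dim=1)
        pred_shifted = model(shifted).argmax(dim=1)
        consistent += (pred == pred_shifted).sum().item()
        total += inputs.size(0)
    return 100.0 * consistent / total


# Self-contained demo: global max pooling is exactly invariant to circular shifts,
# so this toy model keeps every prediction.
toy = nn.Sequential(nn.AdaptiveMaxPool2d(1), nn.Flatten(), nn.Linear(3, 10))
toy_data = TensorDataset(torch.randn(16, 3, 24, 24), torch.zeros(16, dtype=torch.long))
toy_loader = DataLoader(toy_data, batch_size=4)
print(shift_consistency(toy_loader, toy, torch.device("cpu"), max_shift=8))  # 100.0

# With the models and data loader defined in the example above:
# print("baseline consistency", shift_consistency(data_loader_val, base_model, device))
# print("AFC consistency", shift_consistency(data_loader_val, afc_model, device))
```

A model that is invariant to circular shifts, as `convnext_tiny_afc` is described above, should keep its predictions under these shifts up to floating-point effects, while circular padding alone (`convnext_tiny_baseline`) gives no such guarantee.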