---
license: mit
datasets:
- imagenet-1k
language:
- en
metrics:
- accuracy
pipeline_tag: image-classification
---

[Alias-Free Convnets: Fractional Shift Invariance via Polynomial Activations](https://hmichaeli.github.io/alias_free_convnets/)

Official PyTorch trained models.

These are ConvNeXt-Tiny variants:

`convnext_tiny_baseline` is ConvNeXt-Tiny with circular-padded convolutions.

`convnext_tiny_afc` is the full ConvNeXt-Tiny-AFC, which is shift invariant to circular shifts.

For more details see the [paper](https://arxiv.org/abs/2303.08085) or the [implementation](https://github.com/hmichaeli/alias_free_convnets).

```bash
git clone https://github.com/hmichaeli/alias_free_convnets.git
```

```python
from huggingface_hub import hf_hub_download
import torch
from torchvision import datasets, transforms
from alias_free_convnets.models.convnext_afc import convnext_afc_tiny

# Baseline: ConvNeXt-Tiny with circular-padded convolutions.
# NOTE: "basline" is the actual (misspelled) filename in the hub repo -- do not "fix" it here.
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_basline.pth")
# map_location="cpu" so loading also works on CPU-only machines; the model is
# moved to the target device later, before evaluation.
ckpt = torch.load(path, map_location="cpu")
base_model = convnext_afc_tiny(pretrained=False, num_classes=1000)
base_model.load_state_dict(ckpt, strict=True)

# AFC: the full ConvNeXt-Tiny-AFC variant (shift invariant to circular shifts).
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_afc.pth")
# CPU-safe load; the model is moved to the evaluation device later.
ckpt = torch.load(path, map_location="cpu")
afc_model = convnext_afc_tiny(
    pretrained=False,
    num_classes=1000,
    # Model-construction arguments below must match how the checkpoint was
    # trained -- see the paper / repository for what each option means.
    activation='up_poly_per_channel',
    activation_kwargs={'in_scale': 7, 'out_scale': 7, 'train_scale': True},
    blurpool_kwargs={"filter_type": "ideal", "scale_l2": False},
    normalization_type='CHW2',
    stem_activation_kwargs={"in_scale": 7, "out_scale": 7, "train_scale": True, "cutoff": 0.75},
    normalization_kwargs={},
    stem_mode='activation_residual',
    stem_activation='lpf_poly_per_channel',
)
# strict=False: tolerate keys in the checkpoint that do not exactly match this build.
afc_model.load_state_dict(ckpt, strict=False)

# Standard ImageNet evaluation preprocessing: bicubic resize to 256,
# 224x224 center crop, tensor conversion, per-channel normalization.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
interpolation = transforms.InterpolationMode.BICUBIC
transform = transforms.Compose(
    [
        transforms.Resize(256, interpolation=interpolation),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
)

data_path = "/path/to/imagenet/val"  # point this at your ImageNet validation split
dataset_val = datasets.ImageFolder(data_path, transform=transform)
nb_classes = 1000

# Deterministic sequential order; keep the final partial batch.
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = torch.utils.data.DataLoader(
    dataset_val,
    sampler=sampler_val,
    batch_size=8,
    num_workers=8,
    drop_last=False,
)

# Prefer the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@torch.no_grad()
def evaluate(data_loader, model, device):
    """Compute, print, and return top-1 accuracy (percent) of `model`.

    Args:
        data_loader: iterable of (inputs, targets) batches.
        model: callable module producing class logits of shape (batch, num_classes).
        device: torch.device that batches are moved to before the forward pass.

    Returns:
        float: top-1 accuracy in percent; 0.0 for an empty loader instead of
        raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)  # top-1 class index per sample
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Guard the division: an empty validation loader yields 0.0, not a crash.
    acc = 100. * correct / total if total else 0.0
    print("Acc@1 {:.3f}".format(acc))
    # Return the accuracy so callers (e.g. `test_stats = evaluate(...)`)
    # actually receive a value instead of None.
    return acc

# Evaluate both checkpoints on the validation set.
for label, net in (("baseline", base_model), ("AFC", afc_model)):
    print("evaluate " + label)
    net.to(device)
    test_stats = evaluate(data_loader_val, net, device)

```