---
license: mit
datasets:
- imagenet-1k
language:
- en
metrics:
- accuracy
pipeline_tag: image-classification
---

[Alias-Free Convnets: Fractional Shift Invariance via Polynomial Activations](https://hmichaeli.github.io/alias_free_convnets/)

Official pretrained PyTorch weights for two ConvNeXt-Tiny variants:

- `convnext_tiny_baseline` is ConvNeXt-Tiny with circular-padded convolutions.
- `convnext_tiny_afc` is the full ConvNeXt-Tiny-AFC model, which is invariant to circular shifts.

For more details, see the [paper](https://arxiv.org/abs/2303.08085) or the [implementation](https://github.com/hmichaeli/alias_free_convnets).

The model definitions are imported from the official repository, so clone it first:

```bash
git clone https://github.com/hmichaeli/alias_free_convnets.git
```

```python
from huggingface_hub import hf_hub_download
import torch
from torchvision import datasets, transforms

from alias_free_convnets.models.convnext_afc import convnext_afc_tiny

# Baseline: ConvNeXt-Tiny with circular-padded convolutions.
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_basline.pth")
ckpt = torch.load(path, map_location="cpu")
base_model = convnext_afc_tiny(pretrained=False, num_classes=1000)
base_model.load_state_dict(ckpt, strict=True)

# AFC: fully alias-free model with polynomial activations and ideal low-pass filters.
path = hf_hub_download(repo_id="hmichaeli/convnext-afc", filename="convnext_tiny_afc.pth")
ckpt = torch.load(path, map_location="cpu")
afc_model = convnext_afc_tiny(
    pretrained=False,
    num_classes=1000,
    activation='up_poly_per_channel',
    activation_kwargs={'in_scale': 7, 'out_scale': 7, 'train_scale': True},
    blurpool_kwargs={"filter_type": "ideal", "scale_l2": False},
    normalization_type='CHW2',
    stem_activation_kwargs={"in_scale": 7, "out_scale": 7, "train_scale": True, "cutoff": 0.75},
    normalization_kwargs={},
    stem_mode='activation_residual',
    stem_activation='lpf_poly_per_channel',
)
afc_model.load_state_dict(ckpt, strict=False)

# Standard ImageNet validation preprocessing.
interpolation = transforms.InterpolationMode.BICUBIC
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize(256, interpolation=interpolation),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

data_path = "/path/to/imagenet/val"
dataset_val = datasets.ImageFolder(data_path, transform=transform)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = torch.utils.data.DataLoader(
    dataset_val,
    sampler=sampler_val,
    batch_size=8,
    num_workers=8,
    drop_last=False,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def evaluate(data_loader, model, device):
    """Compute top-1 accuracy over the given data loader."""
    model.eval()
    correct = 0
    total = 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    acc = 100. * correct / total
    print("Acc@1 {:.3f}".format(acc))
    return acc


print("evaluate baseline")
base_model.to(device)
evaluate(data_loader_val, base_model, device)

print("evaluate AFC")
afc_model.to(device)
evaluate(data_loader_val, afc_model, device)
```
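A quick way to see the property the card advertises is to compare predictions on an image and a circularly shifted copy of it. The snippet below is a minimal sketch, not part of the official evaluation code; it assumes the `afc_model`, `transform`, and `device` objects defined above, and the image path is a placeholder.

```python
# Sanity check: prediction consistency under circular shifts.
# Assumes `afc_model`, `transform`, and `device` from the snippet above;
# the image path is a placeholder.
import torch
from PIL import Image

img = Image.open("/path/to/some_image.jpg").convert("RGB")
x = transform(img).unsqueeze(0).to(device)

# Circularly shift the input by a few pixels along height and width.
x_shifted = torch.roll(x, shifts=(3, 5), dims=(-2, -1))

afc_model.eval()
with torch.no_grad():
    pred = afc_model(x).argmax(dim=1)
    pred_shifted = afc_model(x_shifted).argmax(dim=1)

print("prediction unchanged under circular shift:", bool((pred == pred_shifted).all()))
```

According to the paper's claim, the AFC model should give identical predictions under such integer circular shifts, while the baseline model, which only adds circular padding, is not guaranteed to pass this check on every input.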