VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials

This model is a convolutional version of the Kolmogorov-Arnold Network (KAN) with a VGG-11-like architecture, pretrained on the Imagenet1k dataset. KANs were originally presented in [1, 2]. The Gram version of KAN was originally presented in [3]. For more details, visit our torch-conv-kan repository on GitHub.

Model description

The model consists of 10 consecutive Gram ConvKAN layers with InstanceNorm2d and polynomial degree 5, followed by global average pooling and a linear classification head:

  1. BottleNeckKAGN Convolution, 32 filters, 3x3
  2. Max pooling, 2x2
  3. BottleNeckKAGN Convolution, 64 filters, 3x3
  4. Max pooling, 2x2
  5. BottleNeckKAGN Convolution, 128 filters, 3x3
  6. BottleNeckKAGN Convolution, 128 filters, 3x3
  7. Max pooling, 2x2
  8. BottleNeckKAGN Convolution, 256 filters, 3x3
  9. BottleNeckKAGN Convolution, 256 filters, 3x3
  10. Max pooling, 2x2
  11. BottleNeckKAGN Convolution, 256 filters, 3x3
  12. BottleNeckKAGN Convolution, 256 filters, 3x3
  13. Max pooling, 2x2
  14. BottleNeckKAGN Convolution, 512 filters, 3x3
  15. BottleNeckKAGN Convolution, 512 filters, 3x3
  16. BottleNeckSelfKAGNtention, 512 filters, 3x3
  17. Global Average pooling
  18. Output layer, 1000 nodes.
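
To make the layer list easier to follow, here is a minimal sketch that traces the feature-map shape through the stack for a 224x224 RGB input. It takes the filter counts in the list at face value and assumes that every 3x3 layer preserves the spatial size (stride 1, padding 1) and that each 2x2 max pooling halves it. Neither assumption is stated in this card, and `LAYERS` and `trace_shapes` are hypothetical names used only for illustration.

```python
# Each entry: (layer kind, output channels). Channel counts are copied from the
# list above; None means the layer leaves the channel count unchanged.
LAYERS = [
    ("conv", 32), ("pool", None),
    ("conv", 64), ("pool", None),
    ("conv", 128), ("conv", 128), ("pool", None),
    ("conv", 256), ("conv", 256), ("pool", None),
    ("conv", 256), ("conv", 256), ("pool", None),
    ("conv", 512), ("conv", 512),
    ("attention", 512),
    ("global_avg_pool", None),
]


def trace_shapes(height=224, width=224, channels=3):
    """Return (kind, channels, height, width) after every layer."""
    shapes = []
    for kind, out_channels in LAYERS:
        if kind in ("conv", "attention"):
            # Assumption: 3x3 layers use stride 1 and padding 1 (size-preserving).
            channels = out_channels
        elif kind == "pool":
            # Assumption: 2x2 max pooling with stride 2 halves each dimension.
            height, width = height // 2, width // 2
        elif kind == "global_avg_pool":
            height, width = 1, 1
        shapes.append((kind, channels, height, width))
    return shapes


for kind, c, h, w in trace_shapes():
    print(f"{kind:>16}: {c} x {h} x {w}")
```

Under these assumptions the five pooling steps reduce 224x224 to 7x7 before global average pooling collapses it to 1x1, which is then fed to the 1000-node output layer.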


Intended uses & limitations

You can use the raw model for image classification or as a pretrained model for further fine-tuning.

How to use

First, clone the repository:

git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt

Then you can initialize the model and load weights.

import torch
import torch.nn as nn
from models import vggkagn_bn


model = vggkagn_bn(
        3,      # input channels
        1000,   # number of classes
        groups=1,
        degree=5,
        dropout=0.05,
        l1_decay=0,
        width_scale=2,
        affine=True,
        norm_layer=nn.BatchNorm2d,
        expected_feature_shape=(1, 1),
        vgg_type='VGG11v4',
        last_attention=True,
        sa_inner_projection=None
)

model.from_pretrained('brivangl/vgg_kagn_bn11sa_v4')

Transforms used for validation on Imagenet1k:

import torch
from torchvision.transforms import v2


transforms_val = v2.Compose([
    v2.ToImage(),
    v2.Resize(256, antialias=True),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
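
Once an image has been preprocessed with the validation transforms, classification reduces to a forward pass followed by a softmax and a top-k lookup. The sketch below assumes that the model returns a tensor of logits with shape (batch_size, 1000) in evaluation mode, which this card does not state. `predict_topk` is a hypothetical helper, and the stand-in classifier exists only so the example runs without downloading weights.

```python
import torch
from torch import nn


@torch.no_grad()
def predict_topk(model, batch, k=5):
    """Return (probabilities, class indices) of the k most likely classes."""
    model.eval()                        # disable dropout, use running norm statistics
    logits = model(batch)               # assumed shape: (batch_size, num_classes)
    probs = torch.softmax(logits, dim=1)
    return probs.topk(k, dim=1)         # both outputs have shape (batch_size, k)


# Stand-in classifier (hypothetical); replace it with the pretrained model.
stand_in = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 1000))
batch = torch.rand(2, 3, 224, 224)      # two already-preprocessed images
top_probs, top_classes = predict_topk(stand_in, batch)
print(top_classes.shape)
```

For a single PIL image `img`, a batch can be built with `transforms_val(img).unsqueeze(0)`.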

Training data

This model was trained on the Imagenet1k dataset (1281167 images in the train set).

Training procedure

The model was trained for 200 full epochs with the AdamW optimizer, using the following parameters:

{'learning_rate': 0.0009, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 5e-06,
'adam_epsilon': 1e-08, 'lr_warmup_steps': 7500, 'lr_power': 0.3, 'lr_end': 1e-07, 'set_grads_to_none': False}

And these augmentations:

import torch
from torchvision.transforms import AutoAugmentPolicy, v2


transforms_train = v2.Compose([
    v2.ToImage(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomResizedCrop(224, antialias=True),
    # Pick one of the two AutoAugment policies at random for each sample.
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET)
                     ]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
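
The learning-rate parameters listed above (`lr_warmup_steps`, `lr_power`, `lr_end`) suggest a linear warmup followed by a polynomial decay. The card does not name the scheduler, so the sketch below is only one plausible reading of those parameters. The function name `lr_at_step` and the total step count are hypothetical, since the batch size and number of steps per epoch are not given.

```python
def lr_at_step(step, total_steps, lr_init=0.0009, lr_end=1e-07,
               warmup_steps=7500, power=0.3):
    """Learning rate under linear warmup followed by polynomial decay."""
    if step < warmup_steps:
        # Linear ramp from 0 up to lr_init over the warmup steps.
        return lr_init * step / warmup_steps
    if step >= total_steps:
        return lr_end
    # Fraction of the decay phase that is still ahead, in (0, 1].
    remaining = 1.0 - (step - warmup_steps) / (total_steps - warmup_steps)
    return (lr_init - lr_end) * remaining ** power + lr_end


TOTAL_STEPS = 1_000_000  # hypothetical; depends on batch size and epoch count
for step in (0, 3750, 7500, 500_000, TOTAL_STEPS):
    print(step, lr_at_step(step, TOTAL_STEPS))
```

With a power below 1, such as 0.3, the rate stays close to its peak for most of training and drops sharply only near the end.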

Evaluation results

On Imagenet1k Validation:

Accuracy, top1 | Accuracy, top5 | AUC (ovo) | AUC (ovr)
70.684         | 89.462         | 99.624    | 99.624
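
For reference, top-k accuracy counts a prediction as correct when the true label is among the k highest-scoring classes. The sketch below illustrates that definition on a toy example. It is not the evaluation script used for the numbers above, and `topk_accuracy` is a hypothetical helper name.

```python
def topk_accuracy(scores, labels, k):
    """Percentage of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(labels)


# Toy scores for four samples over three classes (made up for illustration).
toy_scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
]
toy_labels = [1, 1, 0, 0]
print(topk_accuracy(toy_scores, toy_labels, k=1))
print(topk_accuracy(toy_scores, toy_labels, k=2))
```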

On Imagenet1k Test: Coming soon

BibTeX entry and citation info

If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.

@misc{torch-conv-kan,
  author = {Ivan Drokin},
  title = {Torch Conv KAN},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/IvanDrokin/torch-conv-kan}}
}

References
