---
license: mit
datasets:
- cifar10
- cifar100
- zh-plus/tiny-imagenet
- imagenet-1k
- jxie/stl10
language:
- en
metrics:
- accuracy
tags:
- self-supervised learning
- barlow-twins
---
# Mixed Barlow Twins
[**Guarding Barlow Twins Against Overfitting with Mixed Samples**](https://arxiv.org/abs/2312.02151)<br>

[Wele Gedara Chaminda Bandara](https://www.wgcban.com) (Johns Hopkins University), [Celso M. De Melo](https://celsodemelo.net) (U.S. Army Research Laboratory), and [Vishal M. Patel](https://engineering.jhu.edu/vpatel36/) (Johns Hopkins University) <br>

## 1 Overview of Mixed Barlow Twins

TL;DR
- Mixed Barlow Twins aims to improve sample interaction during Barlow Twins training via linearly interpolated samples. 
- We introduce an additional regularization term to the original Barlow Twins objective, assuming that linear interpolation in the input space translates to linearly interpolated features in the feature space.
- Pre-training with this regularization effectively mitigates feature overfitting and further enhances the downstream performance on `CIFAR-10`, `CIFAR-100`, `TinyImageNet`, `STL-10`, and `ImageNet` datasets.

<img src="figs/mix-bt.svg" width="1024">

Given embeddings $Z^A$ and $Z^B$ of two augmented views $X^A$ and $X^B$, and the embedding $Z^M$ of the mixed view $X^M = \lambda X^A + (1-\lambda)\,\mathtt{Shuffle}(X^B)$, the regularizer drives the mixed cross-correlation matrices

$C^{MA} = (Z^M)^TZ^A$

$C^{MB} = (Z^M)^TZ^B$

toward the targets implied by the linear-interpolation assumption:

$C^{MA}_{gt} = \lambda (Z^A)^TZ^A + (1-\lambda)\mathtt{Shuffle}^*(Z^B)^TZ^A$

$C^{MB}_{gt} = \lambda (Z^A)^TZ^B + (1-\lambda)\mathtt{Shuffle}^*(Z^B)^TZ^B$
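
For intuition, here is a minimal PyTorch sketch of this regularization term. The function name and arguments are illustrative; the exact normalization and loss weighting follow the repository code rather than this sketch:

```python
import torch

def mixed_bt_regularizer(z_a, z_b, z_m, lam, perm):
    # z_a, z_b: (N, D) projector embeddings of the two augmented views
    # z_m:      (N, D) projector embedding of the mixed batch,
    #           where x_m = lam * x_a + (1 - lam) * x_b[perm]
    # lam:      interpolation coefficient (e.g. drawn from a Beta distribution)
    # perm:     random permutation used to shuffle the second view
    n = z_a.size(0)
    z_b_s = z_b[perm]                       # Shuffle*(Z^B)
    c_ma = z_m.T @ z_a / n                  # C^{MA}
    c_mb = z_m.T @ z_b / n                  # C^{MB}
    # "ground-truth" targets under the linear-interpolation assumption
    c_ma_gt = (lam * z_a.T @ z_a + (1 - lam) * z_b_s.T @ z_a) / n
    c_mb_gt = (lam * z_a.T @ z_b + (1 - lam) * z_b_s.T @ z_b) / n
    return ((c_ma - c_ma_gt) ** 2).sum() + ((c_mb - c_mb_gt) ** 2).sum()
```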

## 2 Usage
### 2.1 Requirements

Before using this repository, make sure you have the following prerequisites installed:

- [Anaconda](https://www.anaconda.com/download/)
- [PyTorch](https://pytorch.org)

You can install PyTorch with the following [command](https://pytorch.org/get-started/locally/) (on Linux):
```bash
conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
```

### 2.2 Installation

To get started, clone this repository:
```bash
git clone https://github.com/wgcban/mix-bt.git
```

Next, create the [conda](https://docs.conda.io/projects/conda/en/stable/) environment named `ssl-aug` by executing the following command:
```bash
conda env create -f environment.yml
```

All train-val-test statistics are automatically uploaded to [`wandb`](https://wandb.ai/home); please refer to the [`wandb` quick-start](https://wandb.ai/quickstart?utm_source=app-resource-center&utm_medium=app&utm_term=quickstart) documentation if you are not familiar with `wandb`.
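
If this is your first run on a machine, you may need to authenticate with `wandb` once, for example:

```python
# One-time setup: authenticate this machine with your wandb account
# (prompts for the API key from https://wandb.ai/authorize)
import wandb

wandb.login()
```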

### 2.3 Supported Pre-training Datasets

This repository supports the following pre-training datasets:
- `CIFAR-10`: https://www.cs.toronto.edu/~kriz/cifar.html
- `CIFAR-100`: https://www.cs.toronto.edu/~kriz/cifar.html
- `Tiny-ImageNet`: https://github.com/rmccorm4/Tiny-Imagenet-200
- `STL-10`: https://cs.stanford.edu/~acoates/stl10/
- `ImageNet`: https://www.image-net.org

`CIFAR-10`, `CIFAR-100`, and `STL-10` are directly available in PyTorch via `torchvision`.

To use `TinyImageNet`, please follow the preprocessing instructions in the [TinyImageNet-Script](https://gist.github.com/moskomule/2e6a9a463f50447beca4e64ab4699ac4). Download the dataset and place it in the `data` directory.
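
For the `torchvision`-hosted datasets, a quick way to pre-fetch everything into the `data` directory (the root assumed above) is:

```python
# Pre-fetch the torchvision-hosted pre-training datasets into `data/`
from torchvision import datasets

datasets.CIFAR10(root="data", download=True)
datasets.CIFAR100(root="data", download=True)
datasets.STL10(root="data", split="train+unlabeled", download=True)
```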

### 2.4 Supported Transfer Learning Datasets
Download the transfer learning datasets and place them under their respective paths, e.g., `data/DTD`. The supported transfer learning datasets include:
- `DTD`: https://www.robots.ox.ac.uk/~vgg/data/dtd/ 
- `MNIST`: http://yann.lecun.com/exdb/mnist/
- `FashionMNIST`: https://github.com/zalandoresearch/fashion-mnist
- `CUBirds`: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
- `VGGFlower`: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
- `Traffic Signs`: https://benchmark.ini.rub.de/gtsdb_dataset.html
- `Aircraft`: https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/

### 2.5 Supported SSL Methods

This repository supports the following Self-Supervised Learning (SSL) methods:

- [`SimCLR`](https://arxiv.org/abs/2002.05709): contrastive learning for SSL 
- [`BYOL`](https://arxiv.org/abs/2006.07733): distillation for SSL
- [`Whitening MSE`](http://proceedings.mlr.press/v139/ermolov21a/ermolov21a.pdf): infomax for SSL
- [`Barlow Twins`](https://arxiv.org/abs/2103.03230): infomax for SSL
- **`Mixed Barlow Twins (ours)`**: infomax + mixed samples for SSL

### 2.6 Pre-Training with Mixed Barlow Twins
To start pre-training and obtain k-NN evaluation results for Mixed Barlow Twins on `CIFAR-10`, `CIFAR-100`, `TinyImageNet`, and `STL-10` with `ResNet-18` or `ResNet-50` backbones, run one of:
```bash
sh scripts-pretrain-resnet18/[dataset].sh
```
```bash
sh scripts-pretrain-resnet50/[dataset].sh
```

To start pre-training on `ImageNet` with the `ResNet-50` backbone, please run:
```bash
sh scripts-pretrain-resnet50/imagenet.sh
```

### 2.7 Linear Evaluation of Pre-trained Models
Before running linear evaluation, *ensure that you specify the `model_path` argument correctly in the corresponding .sh file*. 

To obtain linear evaluation results on `CIFAR-10`, `CIFAR-100`, `TinyImageNet`, and `STL-10` with `ResNet-18` or `ResNet-50` backbones, run one of:
```bash
sh scripts-linear-resnet18/[dataset].sh
```
```bash
sh scripts-linear-resnet50/[dataset].sh
```

To obtain linear evaluation results on `ImageNet` with the `ResNet-50` backbone, please run:
```bash
sh scripts-linear-resnet50/imagenet_sup.sh
```


### 2.8 Transfer Learning of Pre-trained Models
To perform transfer learning from pre-trained models on `CIFAR-10`, `CIFAR-100`, and `STL-10` to fine-grained classification datasets, execute the following command, making sure to specify the `model_path` argument correctly:
```bash
sh scripts-transfer-resnet18/[dataset]-to-x.sh
```

## 3 Pre-Trained Checkpoints
Download the pre-trained models from the links below and store them in the `checkpoints/` directory. This repository provides pre-trained checkpoints for both [`ResNet-18`](https://arxiv.org/abs/1512.03385) and [`ResNet-50`](https://arxiv.org/abs/1512.03385) architectures.

### 3.1 ResNet-18
| Dataset        |  $d$   | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | KNN Acc. | Linear Acc. |
| ----------     | ---  | ---------- | ---------- | ------------------------ | -------- | ----------- |
| `CIFAR-10`       | 1024 | 0.0078125  | 4.0        | [4wdhbpcf_0.0078125_1024_256_cifar10_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/4wdhbpcf_0.0078125_1024_256_cifar10_model.pth)     | 90.52    | 92.58        |
| `CIFAR-100`     | 1024 | 0.0078125  | 4.0        | [76kk7scz_0.0078125_1024_256_cifar100_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/76kk7scz_0.0078125_1024_256_cifar100_model.pth)     | 61.25     | 69.31        |
| `TinyImageNet`   | 1024 | 0.0009765  | 4.0        | [02azq6fs_0.0009765_1024_256_tiny_imagenet_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/02azq6fs_0.0009765_1024_256_tiny_imagenet_model.pth)     | 38.11    | 51.67        |
| `STL-10`        | 1024 | 0.0078125  | 2.0        | [i7det4xq_0.0078125_1024_256_stl10_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/i7det4xq_0.0078125_1024_256_stl10_model.pth)     | 88.94     | 91.02        |

### 3.2 ResNet-50
| Dataset        |  $d$   | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | KNN Acc. | Linear Acc. |
| ----------     | ---  | ---------- | ---------- | ------------------------ | -------- | ----------- |
| `CIFAR-10`       | 1024 | 0.0078125  | 4.0        | [v3gwgusq_0.0078125_1024_256_cifar10_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/v3gwgusq_0.0078125_1024_256_cifar10_model.pth)     | 91.39     | 93.89        |
| `CIFAR-100`      | 1024 | 0.0078125  | 4.0        | [z6ngefw7_0.0078125_1024_256_cifar100_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/z6ngefw7_0.0078125_1024_256_cifar100_model.pth)     | 64.32     | 72.51        |
| `TinyImageNet`   | 1024 | 0.0009765  | 4.0        | [kxlkigsv_0.0009765_1024_256_tiny_imagenet_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/kxlkigsv_0.0009765_1024_256_tiny_imagenet_model.pth)     | 42.21     | 51.84        |
| `STL-10`        | 1024 | 0.0078125  | 2.0        | [pbknx38b_0.0078125_1024_256_stl10_model.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/pbknx38b_0.0078125_1024_256_stl10_model.pth)     | 87.79     | 91.70        |

**On `ImageNet`**
| # Epochs        |  $d$   | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | Linear Acc. |
| ----------     | ---  | ---------- | ---------- | ------------------------ | ----------- |
| 300             | 8192 | 0.0051  | 0.0 (BT)        | [3on0l4wl_0.0000_8192_1024_imagenet_resnet50.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/3on0l4wl_0.0000_8192_1024_imagenet_resnet50.pth)   | 71.3        |
| 300             | 8192 | 0.0051  | 0.0025       | [l418b9zw_0.0025_8192_1024_imagenet_resnet50.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/l418b9zw_0.0025_8192_1024_imagenet_resnet50.pth)     |  70.9        |
| 300             | 8192 | 0.0051  | 0.1        | [13awtq23_0.1000_8192_1024_imagenet_resnet50.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/13awtq23_0.1000_8192_1024_imagenet_resnet50.pth)   | 71.6        |
| 300             | 8192 | 0.0051  | 1.0        | [3fb1op86_1.0000_8192_1024_imagenet_resnet50.pth](https://huggingface.co/wgcban/mix-bt/blob/main/checkpoints/3fb1op86_1.0000_8192_1024_imagenet_resnet50.pth)     | **72.2**        |
| 300             | 8192 | 0.0051  | 3.0        | TBU     | TBU        |
| 300             | 8192 | 0.0051  | 5.0        | TBU     | TBU        |
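
To sanity-check a downloaded checkpoint before pointing the evaluation scripts at it, you can inspect it in Python (a sketch; the exact key layout inside each `.pth` file depends on the training script):

```python
import torch

# Peek at a downloaded checkpoint; the filename below is one of the
# ImageNet checkpoints from the table above.
state = torch.load(
    "checkpoints/3on0l4wl_0.0000_8192_1024_imagenet_resnet50.pth",
    map_location="cpu",
)
print(type(state))
if isinstance(state, dict):
    print(list(state.keys())[:10])  # the first few keys hint at the layout
```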

## 4 Training Statistics
Here we provide some training and validation (linear probing) statistics for Barlow Twins *vs.* Mixed Barlow Twins with the `ResNet-50` backbone on `ImageNet`:

<img src="figs/in-loss-bt.png" width="256"/> <img src="figs/in-loss-reg.png" width="256"/> <img src="figs/in-linear.png" width="256"/> 

## 5 Disclaimer
A large portion of the code comes from [Barlow Twins HSIC](https://github.com/yaohungt/Barlow-Twins-HSIC) (for experiments on the small datasets: `CIFAR-10`, `CIFAR-100`, `TinyImageNet`, and `STL-10`) and from the [official implementation of Barlow Twins](https://github.com/facebookresearch/barlowtwins) (for experiments on `ImageNet`), both of which are great resources for academic development.

Also, note that the implementations of the SOTA methods ([SimCLR](https://arxiv.org/abs/2002.05709), [BYOL](https://arxiv.org/abs/2006.07733), and [Whitening-MSE](https://arxiv.org/abs/2007.06346)) in `ssl-sota` are copied from the [Whitening-MSE repository](https://github.com/htdt/self-supervised).

We would like to thank all of them for making their repositories publicly available for the research community. 🙏

## 6 Reference
If you find our work useful, please consider citing it. Thanks!
```bibtex
@misc{bandara2023guarding,
      title={Guarding Barlow Twins Against Overfitting with Mixed Samples}, 
      author={Wele Gedara Chaminda Bandara and Celso M. De Melo and Vishal M. Patel},
      year={2023},
      eprint={2312.02151},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## 7 License
This code is released under the MIT license; you can find the complete license file [here](https://github.com/wgcban/mix-bt/blob/main/LICENSE).