diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..15d0543cce6babec7fd594a74653b1d65f95dfaf 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+figs/in-loss-reg.png filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..22655ba5065b54e57670ea6ef465138cc91245a5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Wele Gedara Chaminda Bandara
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md
new file mode 100644
index 0000000000000000000000000000000000000000..f8a5479f28fd299fa88a5f2e951f85ece5738a4f
--- /dev/null
+++ b/MODEL_ZOO.md
@@ -0,0 +1,20 @@
+The following tables list the pre-trained models. Here `d` is the projector output dimension, `Lambda_BT` is the Barlow Twins off-diagonal loss weight, and `Lambda_Reg` is the mixup regularization weight.
+# ResNet-18 Pre-trained Models
+| Dataset | d | Lambda_BT | Lambda_Reg | Path to Pretrained Model | KNN Acc. | Linear Acc. |
+| ---------- | --- | ---------- | ---------- | ------------------------ | -------- | ----------- |
+| CIFAR-10 | 1024 | 0.0078125 | 4.0 | 4wdhbpcf_0.0078125_1024_256_cifar10_model.pth | 90.52 | 92.58 |
+| CIFAR-100 | 1024 | 0.0078125 | 4.0 | 76kk7scz_0.0078125_1024_256_cifar100_model.pth | 61.25 | 69.31 |
+| TinyImageNet | 1024 | 0.0009765 | 4.0 | 02azq6fs_0.0009765_1024_256_tiny_imagenet_model.pth | 38.11 | 51.67 |
+| STL-10 | 1024 | 0.0078125 | 2.0 | i7det4xq_0.0078125_1024_256_stl10_model.pth | 88.94 | 91.02 |
+
+
+
+# ResNet-50 Pre-trained Models
+| Dataset | d | Lambda_BT | Lambda_Reg | Path to Pretrained Model | KNN Acc. | Linear Acc. |
+| ---------- | --- | ---------- | ---------- | ------------------------ | -------- | ----------- |
+| CIFAR-10 | 1024 | 0.0078125 | 4.0 | v3gwgusq_0.0078125_1024_256_cifar10_model.pth | 91.39 | 93.89 |
+| CIFAR-100 | 1024 | 0.0078125 | 4.0 | z6ngefw7_0.0078125_1024_256_cifar100_model_2000.pth | 64.32 | 72.51 |
+| TinyImageNet | 1024 | 0.0009765 | 4.0 | kxlkigsv_0.0009765_1024_256_tiny_imagenet_model_2000.pth | 42.21 | 51.84 |
+| STL-10 | 1024 | 0.0078125 | 2.0 | pbknx38b_0.0078125_1024_256_stl10_model.pth | 87.79 | 91.70 |
+| ImageNet | 1024 | 0.0051 | 0.1 | 13awtq23_0.0051_8192_1024_imagenet_0.1_resnet50.pth | - | 72.1 |
+
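
For reference, a minimal sketch of restoring one of these checkpoints for linear probing, mirroring the `Net` wrapper from `linear.py` in this repo (assumptions: the repo root is on `PYTHONPATH` and the checkpoint file, here the CIFAR-10 ResNet-50 row, sits in the working directory):

```python
import torch
from linear import Net  # linear-probe wrapper defined in this repo

# CIFAR-10 / ResNet-50 checkpoint from the table above.
net = Net(num_class=10,
          pretrained_path='v3gwgusq_0.0078125_1024_256_cifar10_model.pth',
          dataset='cifar10', arch='resnet50').cuda()
for p in net.f.parameters():  # freeze the pre-trained backbone
    p.requires_grad = False
```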
diff --git a/README.md b/README.md
index 79e17098b2528fb1d64eb7ce558a0dcdb529ddf5..ea025e9477a50efa105be88e27a6f8d175614ee9 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,3 @@
----
-license: mit
----
# Mixed Barlow Twins
[**Guarding Barlow Twins Against Overfitting with Mixed Samples**](https://arxiv.org/abs/2312.02151)
diff --git a/augmentations/augmentations_cifar.py b/augmentations/augmentations_cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8ec6c8a30602612846409fdffb8c8297cb95da1
--- /dev/null
+++ b/augmentations/augmentations_cifar.py
@@ -0,0 +1,190 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base augmentations operators."""
+
+import numpy as np
+from PIL import Image, ImageOps, ImageEnhance
+
+# ImageNet code should change this value
+IMAGE_SIZE = 32
+import torch
+from torchvision import transforms
+
+
+def int_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval .
+
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+
+ Returns:
+ An int that results from scaling `maxval` according to `level`.
+ """
+ return int(level * maxval / 10)
+
+
+def float_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval.
+
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+
+ Returns:
+ A float that results from scaling `maxval` according to `level`.
+ """
+ return float(level) * maxval / 10.
+
+
+def sample_level(n):
+ return np.random.uniform(low=0.1, high=n)
+
+
+def autocontrast(pil_img, _):
+ return ImageOps.autocontrast(pil_img)
+
+
+def equalize(pil_img, _):
+ return ImageOps.equalize(pil_img)
+
+
+def posterize(pil_img, level):
+ level = int_parameter(sample_level(level), 4)
+ return ImageOps.posterize(pil_img, 4 - level)
+
+
+def rotate(pil_img, level):
+ degrees = int_parameter(sample_level(level), 30)
+ if np.random.uniform() > 0.5:
+ degrees = -degrees
+ return pil_img.rotate(degrees, resample=Image.BILINEAR)
+
+
+def solarize(pil_img, level):
+ level = int_parameter(sample_level(level), 256)
+ return ImageOps.solarize(pil_img, 256 - level)
+
+
+def shear_x(pil_img, level):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, level, 0, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def shear_y(pil_img, level):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, 0, level, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_x(pil_img, level):
+ level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, level, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_y(pil_img, level):
+ level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, 0, 0, 1, level),
+ resample=Image.BILINEAR)
+
+
+# operation that overlaps with ImageNet-C's test set
+def color(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Color(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def contrast(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Contrast(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def brightness(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Brightness(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def sharpness(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Sharpness(pil_img).enhance(level)
+
+def random_resized_crop(pil_img, level):
+ return transforms.RandomResizedCrop(32)(pil_img)
+
+def random_flip(pil_img, level):
+ return transforms.RandomHorizontalFlip(p=0.5)(pil_img)
+
+def grayscale(pil_img, level):
+ return transforms.Grayscale(num_output_channels=3)(pil_img)
+
+augmentations = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y, grayscale #random_resized_crop, random_flip
+]
+
+augmentations_all = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y, color, contrast, brightness, sharpness, grayscale #, random_resized_crop, random_flip
+]
+
+def aug_cifar(image, preprocess, mixture_width=3, mixture_depth=-1, aug_severity=3):
+ """Perform AugMix augmentations and compute mixture.
+
+ Args:
+ image: PIL.Image input image
+ preprocess: Preprocessing function which should return a torch tensor.
+
+ Returns:
+    mix: Convex (Dirichlet-weighted) combination of the augmented branches;
+      unlike the STL/Tiny variants, the interpolation with the clean image
+      is commented out below.
+ """
+ aug_list = augmentations_all
+ # if args.all_ops:
+ # aug_list = augmentations.augmentations_all
+
+ ws = np.float32(np.random.dirichlet([1] * mixture_width))
+ m = np.float32(np.random.beta(1, 1))
+
+ mix = torch.zeros_like(preprocess(image))
+ for i in range(mixture_width):
+ image_aug = image.copy()
+ depth = mixture_depth if mixture_depth > 0 else np.random.randint(
+ 1, 4)
+ for _ in range(depth):
+ op = np.random.choice(aug_list)
+ image_aug = op(image_aug, aug_severity)
+ # Preprocessing commutes since all coefficients are convex
+ mix += ws[i] * preprocess(image_aug)
+
+ # mixed = (1 - m) * preprocess(image) + m * mix
+ return mix
\ No newline at end of file
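
A short usage sketch for `aug_cifar` (illustration only; `example.png` is a placeholder, and `preprocess` must return a tensor as the docstring requires):

```python
import torchvision.transforms as T
from PIL import Image
from augmentations.augmentations_cifar import aug_cifar

preprocess = T.Compose([T.ToTensor(),
                        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
img = Image.open('example.png').convert('RGB').resize((32, 32))
# 3x32x32 tensor: Dirichlet-weighted mixture of 3 augmented branches.
mixed = aug_cifar(img, preprocess)
```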
diff --git a/augmentations/augmentations_stl.py b/augmentations/augmentations_stl.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb62c61fe8036cfe6c006c2dc62fa6173cf45bc1
--- /dev/null
+++ b/augmentations/augmentations_stl.py
@@ -0,0 +1,190 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base augmentations operators."""
+
+import numpy as np
+from PIL import Image, ImageOps, ImageEnhance
+
+# ImageNet code should change this value
+IMAGE_SIZE = 64
+import torch
+from torchvision import transforms
+
+
+def int_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval .
+
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+
+ Returns:
+ An int that results from scaling `maxval` according to `level`.
+ """
+ return int(level * maxval / 10)
+
+
+def float_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval.
+
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+
+ Returns:
+ A float that results from scaling `maxval` according to `level`.
+ """
+ return float(level) * maxval / 10.
+
+
+def sample_level(n):
+ return np.random.uniform(low=0.1, high=n)
+
+
+def autocontrast(pil_img, _):
+ return ImageOps.autocontrast(pil_img)
+
+
+def equalize(pil_img, _):
+ return ImageOps.equalize(pil_img)
+
+
+def posterize(pil_img, level):
+ level = int_parameter(sample_level(level), 4)
+ return ImageOps.posterize(pil_img, 4 - level)
+
+
+def rotate(pil_img, level):
+ degrees = int_parameter(sample_level(level), 30)
+ if np.random.uniform() > 0.5:
+ degrees = -degrees
+ return pil_img.rotate(degrees, resample=Image.BILINEAR)
+
+
+def solarize(pil_img, level):
+ level = int_parameter(sample_level(level), 256)
+ return ImageOps.solarize(pil_img, 256 - level)
+
+
+def shear_x(pil_img, level):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, level, 0, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def shear_y(pil_img, level):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, 0, level, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_x(pil_img, level):
+ level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, level, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_y(pil_img, level):
+ level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, 0, 0, 1, level),
+ resample=Image.BILINEAR)
+
+
+# operation that overlaps with ImageNet-C's test set
+def color(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Color(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def contrast(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Contrast(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def brightness(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Brightness(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def sharpness(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Sharpness(pil_img).enhance(level)
+
+def random_resized_crop(pil_img, level):
+    return transforms.RandomResizedCrop(IMAGE_SIZE)(pil_img)
+
+def random_flip(pil_img, level):
+ return transforms.RandomHorizontalFlip(p=0.5)(pil_img)
+
+def grayscale(pil_img, level):
+ return transforms.Grayscale(num_output_channels=3)(pil_img)
+
+augmentations = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y, grayscale #random_resized_crop, random_flip
+]
+
+augmentations_all = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y, color, contrast, brightness, sharpness, grayscale #, random_resized_crop, random_flip
+]
+
+def aug_stl(image, preprocess, mixture_width=3, mixture_depth=-1, aug_severity=3):
+ """Perform AugMix augmentations and compute mixture.
+
+ Args:
+ image: PIL.Image input image
+ preprocess: Preprocessing function which should return a torch tensor.
+
+ Returns:
+ mixed: Augmented and mixed image.
+ """
+ aug_list = augmentations
+ # if args.all_ops:
+ # aug_list = augmentations.augmentations_all
+
+ ws = np.float32(np.random.dirichlet([1] * mixture_width))
+ m = np.float32(np.random.beta(1, 1))
+
+ mix = torch.zeros_like(preprocess(image))
+ for i in range(mixture_width):
+ image_aug = image.copy()
+ depth = mixture_depth if mixture_depth > 0 else np.random.randint(
+ 1, 4)
+ for _ in range(depth):
+ op = np.random.choice(aug_list)
+ image_aug = op(image_aug, aug_severity)
+ # Preprocessing commutes since all coefficients are convex
+ mix += ws[i] * preprocess(image_aug)
+
+ mixed = (1 - m) * preprocess(image) + m * mix
+ return mixed
\ No newline at end of file
diff --git a/augmentations/augmentations_tiny.py b/augmentations/augmentations_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dde30244e7a4daa7c451fb5e3ad9b54e62a4404
--- /dev/null
+++ b/augmentations/augmentations_tiny.py
@@ -0,0 +1,190 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base augmentations operators."""
+
+import numpy as np
+from PIL import Image, ImageOps, ImageEnhance
+
+# ImageNet code should change this value
+IMAGE_SIZE = 64
+import torch
+from torchvision import transforms
+
+
+def int_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval .
+
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+
+ Returns:
+ An int that results from scaling `maxval` according to `level`.
+ """
+ return int(level * maxval / 10)
+
+
+def float_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval.
+
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+
+ Returns:
+ A float that results from scaling `maxval` according to `level`.
+ """
+ return float(level) * maxval / 10.
+
+
+def sample_level(n):
+ return np.random.uniform(low=0.1, high=n)
+
+
+def autocontrast(pil_img, _):
+ return ImageOps.autocontrast(pil_img)
+
+
+def equalize(pil_img, _):
+ return ImageOps.equalize(pil_img)
+
+
+def posterize(pil_img, level):
+ level = int_parameter(sample_level(level), 4)
+ return ImageOps.posterize(pil_img, 4 - level)
+
+
+def rotate(pil_img, level):
+ degrees = int_parameter(sample_level(level), 30)
+ if np.random.uniform() > 0.5:
+ degrees = -degrees
+ return pil_img.rotate(degrees, resample=Image.BILINEAR)
+
+
+def solarize(pil_img, level):
+ level = int_parameter(sample_level(level), 256)
+ return ImageOps.solarize(pil_img, 256 - level)
+
+
+def shear_x(pil_img, level):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, level, 0, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def shear_y(pil_img, level):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, 0, level, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_x(pil_img, level):
+ level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, level, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_y(pil_img, level):
+ level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
+ Image.AFFINE, (1, 0, 0, 0, 1, level),
+ resample=Image.BILINEAR)
+
+
+# operation that overlaps with ImageNet-C's test set
+def color(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Color(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def contrast(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Contrast(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def brightness(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Brightness(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def sharpness(pil_img, level):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Sharpness(pil_img).enhance(level)
+
+def random_resized_crop(pil_img, level):
+    return transforms.RandomResizedCrop(IMAGE_SIZE)(pil_img)
+
+def random_flip(pil_img, level):
+ return transforms.RandomHorizontalFlip(p=0.5)(pil_img)
+
+def grayscale(pil_img, level):
+ return transforms.Grayscale(num_output_channels=3)(pil_img)
+
+augmentations = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y, grayscale #random_resized_crop, random_flip
+]
+
+augmentations_all = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y, color, contrast, brightness, sharpness, grayscale #, random_resized_crop, random_flip
+]
+
+def aug_tiny(image, preprocess, mixture_width=3, mixture_depth=-1, aug_severity=3):
+ """Perform AugMix augmentations and compute mixture.
+
+ Args:
+ image: PIL.Image input image
+ preprocess: Preprocessing function which should return a torch tensor.
+
+ Returns:
+ mixed: Augmented and mixed image.
+ """
+ aug_list = augmentations
+ # if args.all_ops:
+ # aug_list = augmentations.augmentations_all
+
+ ws = np.float32(np.random.dirichlet([1] * mixture_width))
+ m = np.float32(np.random.beta(1, 1))
+
+ mix = torch.zeros_like(preprocess(image))
+ for i in range(mixture_width):
+ image_aug = image.copy()
+ depth = mixture_depth if mixture_depth > 0 else np.random.randint(
+ 1, 4)
+ for _ in range(depth):
+ op = np.random.choice(aug_list)
+ image_aug = op(image_aug, aug_severity)
+ # Preprocessing commutes since all coefficients are convex
+ mix += ws[i] * preprocess(image_aug)
+
+ mixed = (1 - m) * preprocess(image) + m * mix
+ return mixed
\ No newline at end of file
diff --git a/data_statistics.py b/data_statistics.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae212dc093d48169255e0b7ee3358d7219881c0c
--- /dev/null
+++ b/data_statistics.py
@@ -0,0 +1,61 @@
+def get_data_mean_and_stdev(dataset):
+    if dataset in ('cifar10', 'cifar100'):  # lowercase names, matching the callers in evaluate_transfer.py
+ mean = [0.5, 0.5, 0.5]
+ std = [0.5, 0.5, 0.5]
+    elif dataset == 'stl-10':
+ mean = [0.491, 0.482, 0.447]
+ std = [0.247, 0.244, 0.262]
+ elif dataset == 'ImageNet':
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ elif dataset == 'aircraft':
+ mean = [0.486, 0.507, 0.525]
+ std = [0.266, 0.260, 0.276]
+ elif dataset == 'cu_birds':
+ mean = [0.483, 0.491, 0.424]
+ std = [0.228, 0.224, 0.259]
+ elif dataset == 'dtd':
+ mean = [0.533, 0.474, 0.426]
+ std = [0.261, 0.250, 0.259]
+ elif dataset == 'fashionmnist':
+ mean = [0.348, 0.348, 0.348]
+ std = [0.347, 0.347, 0.347]
+ elif dataset == 'mnist':
+ mean = [0.170, 0.170, 0.170]
+ std = [0.320, 0.320, 0.320]
+ elif dataset == 'traffic_sign':
+ mean = [0.335, 0.291, 0.295]
+ std = [0.267, 0.249, 0.251]
+ elif dataset == 'vgg_flower':
+ mean = [0.518, 0.410, 0.329]
+ std = [0.296, 0.249, 0.285]
+ else:
+ raise Exception('Dataset %s not supported.'%dataset)
+ return mean, std
+
+def get_data_nclass(dataset):
+ if dataset == 'cifar10':
+ nclass = 10
+    elif dataset == 'cifar100':
+ nclass = 100
+ elif dataset == 'stl-10':
+ nclass = 10
+ elif dataset == 'ImageNet':
+ nclass = 1000
+ elif dataset == 'aircraft':
+ nclass = 102
+ elif dataset == 'cu_birds':
+ nclass = 200
+ elif dataset == 'dtd':
+ nclass = 47
+ elif dataset == 'fashionmnist':
+ nclass = 10
+ elif dataset == 'mnist':
+ nclass = 10
+ elif dataset == 'traffic_sign':
+ nclass = 43
+ elif dataset == 'vgg_flower':
+ nclass = 102
+ else:
+ raise Exception('Dataset %s not supported.'%dataset)
+ return nclass
\ No newline at end of file
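
These helpers plug straight into torchvision normalization, as `evaluate_transfer.py` does below; a minimal sketch:

```python
import torchvision.transforms as T
from data_statistics import get_data_mean_and_stdev, get_data_nclass

mean, std = get_data_mean_and_stdev('dtd')
transform = T.Compose([T.Resize((64, 64)), T.ToTensor(), T.Normalize(mean, std)])
num_classes = get_data_nclass('dtd')  # 47
```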
diff --git a/download_imagenet.sh b/download_imagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7b0237336a36713a219486f3a4084d653a748898
--- /dev/null
+++ b/download_imagenet.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+# https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4
+cd /mnt/store/wbandar1/datasets
+wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar --no-check-certificate
+wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar --no-check-certificate
+
+#
+# script to extract ImageNet dataset
+# ILSVRC2012_img_train.tar (about 138 GB)
+# ILSVRC2012_img_val.tar (about 6.3 GB)
+# make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar in your current directory
+#
+# https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md
+#
+# train/
+# ├── n01440764
+# │ ├── n01440764_10026.JPEG
+# │ ├── n01440764_10027.JPEG
+# │ ├── ......
+# ├── ......
+# val/
+# ├── n01440764
+# │ ├── ILSVRC2012_val_00000293.JPEG
+# │ ├── ILSVRC2012_val_00002138.JPEG
+# │ ├── ......
+# ├── ......
+#
+#
+# Extract the training data:
+#
+mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
+tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
+find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
+cd ..
+#
+# Extract the validation data and move images to subfolders:
+#
+mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
+wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
+#
+# Check total files after extract
+#
+# $ find train/ -name "*.JPEG" | wc -l
+# 1281167
+# $ find val/ -name "*.JPEG" | wc -l
+# 50000
+#
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c493d4a4a5ab8876cff5cde6a849db6f293b5cd1
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,188 @@
+name: ssl-aug
+channels:
+ - pytorch
+ - anaconda
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - bottleneck=1.3.4=py38hce1f21e_0
+ - brotlipy=0.7.0=py38h27cfd23_1003
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2022.6.15=ha878542_0
+ - cairo=1.16.0=hcf35c78_1003
+ - certifi=2022.6.15=py38h578d9bd_0
+ - cffi=1.15.0=py38h7f8727e_0
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - cryptography=37.0.1=py38h9ce1e76_0
+ - cudatoolkit=11.3.1=h2bc3f7f_2
+ - dataclasses=0.8=pyh6d0b6a4_7
+ - dbus=1.13.18=hb2f20db_0
+ - expat=2.4.8=h27087fc_0
+ - ffmpeg=4.3.2=hca11adc_0
+ - fontconfig=2.14.0=h8e229c2_0
+ - freetype=2.11.0=h70c0345_0
+ - fvcore=0.1.5.post20220512=pyhd8ed1ab_0
+ - gettext=0.19.8.1=hd7bead4_3
+ - gh=2.12.1=ha8f183a_0
+ - giflib=5.2.1=h7b6447c_0
+ - glib=2.66.3=h58526e2_0
+ - gmp=6.2.1=h295c915_3
+ - gnutls=3.6.15=he1e5248_0
+ - graphite2=1.3.14=h295c915_1
+ - gst-plugins-base=1.14.5=h0935bb2_2
+ - gstreamer=1.14.5=h36ae1b5_2
+ - harfbuzz=2.4.0=h9f30f68_3
+ - hdf5=1.10.6=hb1b8bf9_0
+ - icu=64.2=he1b5a44_1
+ - idna=3.3=pyhd3eb1b0_0
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - iopath=0.1.9=pyhd8ed1ab_0
+ - jasper=1.900.1=hd497a04_4
+ - jpeg=9e=h7f8727e_0
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - libblas=3.9.0=12_linux64_mkl
+ - libcblas=3.9.0=12_linux64_mkl
+ - libclang=9.0.1=default_hb4e5071_5
+ - libedit=3.1.20210910=h7f8727e_0
+ - libffi=3.2.1=hf484d3e_1007
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
+ - libgfortran4=7.5.0=ha8ba4b0_17
+ - libglib=2.66.3=hbe7bbb4_0
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.2=h7f8727e_0
+ - liblapack=3.9.0=12_linux64_mkl
+ - liblapacke=3.9.0=12_linux64_mkl
+ - libllvm9=9.0.1=h4a3c616_1
+ - libopencv=4.5.1=py38h703c3c0_0
+ - libpng=1.6.37=hbc83047_0
+ - libprotobuf=3.15.8=h780b84a_1
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.16.0=h27cfd23_0
+ - libtiff=4.2.0=h2818925_1
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=2.32.1=h7f98852_1000
+ - libuv=1.40.0=h7b6447c_0
+ - libwebp=1.2.2=h55f646e_0
+ - libwebp-base=1.2.2=h7f8727e_0
+ - libxcb=1.15=h7f8727e_0
+ - libxkbcommon=0.10.0=he1b5a44_0
+ - libxml2=2.9.9=hea5a465_1
+ - lz4-c=1.9.3=h295c915_1
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py38h7f8727e_0
+ - mkl_fft=1.3.1=py38hd3c417c_0
+ - mkl_random=1.2.2=py38h51133e4_0
+ - ncurses=6.3=h7f8727e_2
+ - nettle=3.7.3=hbbd107a_1
+ - nspr=4.33=h295c915_0
+ - nss=3.46.1=hab99668_0
+ - numexpr=2.8.1=py38h6abb31d_0
+ - numpy=1.22.3=py38he7a7128_0
+ - numpy-base=1.22.3=py38hf524024_0
+ - opencv=4.5.1=py38h578d9bd_0
+ - openh264=2.1.1=h4ff587b_0
+ - openssl=1.1.1o=h166bdaf_0
+ - packaging=21.3=pyhd3eb1b0_0
+ - pandas=1.4.2=py38h295c915_0
+ - pcre=8.45=h295c915_0
+ - pillow=9.0.1=py38h22f2fdc_0
+ - pip=21.2.4=py38h06a4308_0
+ - pixman=0.38.0=h7b6447c_0
+ - portalocker=2.3.0=py38h06a4308_0
+ - protobuf=3.15.8=py38h709712a_0
+ - py-opencv=4.5.1=py38h81c977d_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=22.0.0=pyhd3eb1b0_0
+ - pyparsing=3.0.9=pyhd8ed1ab_0
+ - pysocks=1.7.1=py38h06a4308_0
+ - python=3.8.0=h0371630_2
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - python_abi=3.8=2_cp38
+ - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
+ - pytorch-mutex=1.0=cuda
+ - pytz=2021.3=pyhd3eb1b0_0
+ - pyyaml=6.0=py38h7f8727e_1
+ - qt=5.12.5=hd8c4c69_1
+ - readline=7.0=h7b6447c_5
+ - requests=2.27.1=pyhd3eb1b0_0
+ - setuptools=61.2.0=py38h06a4308_0
+ - six=1.16.0=pyhd3eb1b0_1
+ - sqlite=3.33.0=h62c20be_0
+ - tabulate=0.8.9=py38h06a4308_0
+ - tensorboardx=2.5.1=pyhd8ed1ab_0
+ - termcolor=1.1.0=py38h06a4308_1
+ - tk=8.6.12=h1ccaba5_0
+ - torchvision=0.12.0=py38_cu113
+ - tqdm=4.64.0=py38h06a4308_0
+ - typing_extensions=4.1.1=pyh06a4308_0
+ - wheel=0.37.1=pyhd3eb1b0_0
+ - x264=1!161.3030=h7f98852_1
+ - xorg-kbproto=1.0.7=h7f98852_1002
+ - xorg-libice=1.0.10=h7f98852_0
+ - xorg-libsm=1.2.3=hd9c2040_1000
+ - xorg-libx11=1.7.2=h7f98852_0
+ - xorg-libxext=1.3.4=h7f98852_1
+ - xorg-libxrender=0.9.10=h7f98852_1003
+ - xorg-renderproto=0.11.1=h7f98852_1002
+ - xorg-xextproto=7.3.0=h7f98852_1002
+ - xorg-xproto=7.0.31=h27cfd23_1007
+ - xz=5.2.5=h7f8727e_1
+ - yacs=0.1.6=pyhd3eb1b0_1
+ - yaml=0.2.5=h7b6447c_0
+ - zip=3.0=h7f98852_1
+ - zlib=1.2.12=h7f8727e_2
+ - zstd=1.5.2=ha4553b6_0
+ - pip:
+ - absl-py==1.1.0
+ - appdirs==1.4.4
+ - cachetools==5.2.0
+ - click==8.1.7
+ - contourpy==1.0.6
+ - cycler==0.11.0
+ - decord==0.6.0
+ - deepspeed==0.5.8
+ - docker-pycreds==0.4.0
+ - einops==0.4.1
+ - filelock==3.7.1
+ - fonttools==4.38.0
+ - future==0.18.2
+ - gitdb==4.0.10
+ - gitpython==3.1.33
+ - google-auth==2.7.0
+ - google-auth-oauthlib==0.4.6
+ - grpcio==1.46.3
+ - hjson==3.0.2
+ - imageio==2.22.2
+ - importlib-metadata==4.11.4
+ - kiwisolver==1.4.4
+ - markdown==3.3.7
+ - matplotlib==3.6.1
+ - ninja==1.10.2.3
+ - oauthlib==3.2.0
+ - pathtools==0.1.2
+ - psutil==5.9.1
+ - pyasn1==0.4.8
+ - pyasn1-modules==0.2.8
+ - requests-oauthlib==1.3.1
+ - rsa==4.8
+ - scipy==1.9.0
+ - sentry-sdk==1.30.0
+ - setproctitle==1.3.2
+ - smmap==5.0.0
+ - tensorboard==2.9.1
+ - tensorboard-data-server==0.6.1
+ - tensorboard-plugin-wit==1.8.1
+ - thop==0.1.1-2209072238
+ - timm==0.4.12
+ - triton==1.1.1
+ - urllib3==1.26.16
+ - wandb==0.15.9
+ - werkzeug==2.1.2
+ - zipp==3.8.0
+prefix: /home/wbandar1/anaconda3/envs/ssl-aug
diff --git a/evaluate_imagenet.py b/evaluate_imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..59d72294a5e58759de9fbeebfda505db8d163462
--- /dev/null
+++ b/evaluate_imagenet.py
@@ -0,0 +1,289 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import argparse
+import json
+import os
+import random
+import signal
+import sys
+import time
+import urllib
+
+from torch import nn, optim
+from torchvision import models, datasets, transforms
+import torch
+import torchvision
+import wandb
+
+parser = argparse.ArgumentParser(description='Evaluate resnet50 features on ImageNet')
+parser.add_argument('data', type=Path, metavar='DIR',
+ help='path to dataset')
+parser.add_argument('pretrained', type=Path, metavar='FILE',
+ help='path to pretrained model')
+parser.add_argument('--weights', default='freeze', type=str,
+ choices=('finetune', 'freeze'),
+ help='finetune or freeze resnet weights')
+parser.add_argument('--train-percent', default=100, type=int,
+ choices=(100, 10, 1),
+                    help='size of the training set in percent')
+parser.add_argument('--workers', default=8, type=int, metavar='N',
+ help='number of data loader workers')
+parser.add_argument('--epochs', default=100, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--batch-size', default=256, type=int, metavar='N',
+ help='mini-batch size')
+parser.add_argument('--lr-backbone', default=0.0, type=float, metavar='LR',
+ help='backbone base learning rate')
+parser.add_argument('--lr-classifier', default=0.3, type=float, metavar='LR',
+ help='classifier base learning rate')
+parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
+ help='weight decay')
+parser.add_argument('--print-freq', default=100, type=int, metavar='N',
+ help='print frequency')
+parser.add_argument('--checkpoint-dir', default='/mnt/store/wbandar1/projects/ssl-aug-artifacts/', type=Path,
+ metavar='DIR', help='path to checkpoint directory')
+
+
+def main():
+ args = parser.parse_args()
+ if args.train_percent in {1, 10}:
+ args.train_files = urllib.request.urlopen(f'https://raw.githubusercontent.com/google-research/simclr/master/imagenet_subsets/{args.train_percent}percent.txt').readlines()
+ args.ngpus_per_node = torch.cuda.device_count()
+ if 'SLURM_JOB_ID' in os.environ:
+ signal.signal(signal.SIGUSR1, handle_sigusr1)
+ signal.signal(signal.SIGTERM, handle_sigterm)
+ # single-node distributed training
+ args.rank = 0
+ args.dist_url = f'tcp://localhost:{random.randrange(49152, 65535)}'
+ args.world_size = args.ngpus_per_node
+ torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)
+
+
+def main_worker(gpu, args):
+ args.rank += gpu
+ torch.distributed.init_process_group(
+ backend='nccl', init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+
+ # initializing wandb
+ if args.rank == 0:
+ run = wandb.init(project="bt-in1k-eval", config=args, dir='/mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs/')
+ run_id = wandb.run.id
+ args.checkpoint_dir=Path(os.path.join(args.checkpoint_dir, run_id))
+
+ if args.rank == 0:
+ args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
+ print(' '.join(sys.argv))
+ print(' '.join(sys.argv), file=stats_file)
+
+ torch.cuda.set_device(gpu)
+ torch.backends.cudnn.benchmark = True
+
+ model = models.resnet50().cuda(gpu)
+ state_dict = torch.load(args.pretrained, map_location='cpu')
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+ assert missing_keys == ['fc.weight', 'fc.bias'] and unexpected_keys == []
+ model.fc.weight.data.normal_(mean=0.0, std=0.01)
+ model.fc.bias.data.zero_()
+ if args.weights == 'freeze':
+ model.requires_grad_(False)
+ model.fc.requires_grad_(True)
+ classifier_parameters, model_parameters = [], []
+ for name, param in model.named_parameters():
+ if name in {'fc.weight', 'fc.bias'}:
+ classifier_parameters.append(param)
+ else:
+ model_parameters.append(param)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
+
+ criterion = nn.CrossEntropyLoss().cuda(gpu)
+
+ param_groups = [dict(params=classifier_parameters, lr=args.lr_classifier)]
+ if args.weights == 'finetune':
+ param_groups.append(dict(params=model_parameters, lr=args.lr_backbone))
+ optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=args.weight_decay)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
+
+ # automatically resume from checkpoint if it exists
+ if (args.checkpoint_dir / 'checkpoint.pth').is_file():
+ ckpt = torch.load(args.checkpoint_dir / 'checkpoint.pth',
+ map_location='cpu')
+ start_epoch = ckpt['epoch']
+ best_acc = ckpt['best_acc']
+ model.load_state_dict(ckpt['model'])
+ optimizer.load_state_dict(ckpt['optimizer'])
+ scheduler.load_state_dict(ckpt['scheduler'])
+ else:
+ start_epoch = 0
+ best_acc = argparse.Namespace(top1=0, top5=0)
+
+ # Data loading code
+ traindir = args.data / 'train'
+ valdir = args.data / 'val'
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+ train_dataset = datasets.ImageFolder(traindir, transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ normalize,
+ ]))
+ val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ normalize,
+ ]))
+
+ if args.train_percent in {1, 10}:
+ train_dataset.samples = []
+ for fname in args.train_files:
+ fname = fname.decode().strip()
+ cls = fname.split('_')[0]
+ train_dataset.samples.append(
+ (traindir / cls / fname, train_dataset.class_to_idx[cls]))
+
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+ kwargs = dict(batch_size=args.batch_size // args.world_size, num_workers=args.workers, pin_memory=True)
+ train_loader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **kwargs)
+ val_loader = torch.utils.data.DataLoader(val_dataset, **kwargs)
+
+ start_time = time.time()
+ for epoch in range(start_epoch, args.epochs):
+ # train
+ if args.weights == 'finetune':
+ model.train()
+ elif args.weights == 'freeze':
+ model.eval()
+ else:
+ assert False
+ train_sampler.set_epoch(epoch)
+ for step, (images, target) in enumerate(train_loader, start=epoch * len(train_loader)):
+ output = model(images.cuda(gpu, non_blocking=True))
+ loss = criterion(output, target.cuda(gpu, non_blocking=True))
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if step % args.print_freq == 0:
+ torch.distributed.reduce(loss.div_(args.world_size), 0)
+ if args.rank == 0:
+ pg = optimizer.param_groups
+ lr_classifier = pg[0]['lr']
+ lr_backbone = pg[1]['lr'] if len(pg) == 2 else 0
+ stats = dict(epoch=epoch, step=step, lr_backbone=lr_backbone,
+ lr_classifier=lr_classifier, loss=loss.item(),
+ time=int(time.time() - start_time))
+ print(json.dumps(stats))
+ print(json.dumps(stats), file=stats_file)
+ run.log(
+ {
+ "epoch": epoch,
+ "step": step,
+ "lr_backbone": lr_backbone,
+ "lr_classifier": lr_classifier,
+ "loss": loss.item(),
+ "time": int(time.time() - start_time),
+ }
+ )
+
+ # evaluate
+ model.eval()
+ if args.rank == 0:
+ top1 = AverageMeter('Acc@1')
+ top5 = AverageMeter('Acc@5')
+ with torch.no_grad():
+ for images, target in val_loader:
+ output = model(images.cuda(gpu, non_blocking=True))
+ acc1, acc5 = accuracy(output, target.cuda(gpu, non_blocking=True), topk=(1, 5))
+ top1.update(acc1[0].item(), images.size(0))
+ top5.update(acc5[0].item(), images.size(0))
+ best_acc.top1 = max(best_acc.top1, top1.avg)
+ best_acc.top5 = max(best_acc.top5, top5.avg)
+ stats = dict(epoch=epoch, acc1=top1.avg, acc5=top5.avg, best_acc1=best_acc.top1, best_acc5=best_acc.top5)
+ print(json.dumps(stats))
+ print(json.dumps(stats), file=stats_file)
+ run.log(
+ {
+ "epoch": epoch,
+ "eval_acc1": top1.avg,
+ "eval_acc5": top5.avg,
+ "eval_best_acc1": best_acc.top1,
+ "eval_best_acc5": best_acc.top5,
+ }
+ )
+
+ # sanity check
+ if args.weights == 'freeze':
+ reference_state_dict = torch.load(args.pretrained, map_location='cpu')
+ model_state_dict = model.module.state_dict()
+ for k in reference_state_dict:
+ assert torch.equal(model_state_dict[k].cpu(), reference_state_dict[k]), k
+
+ scheduler.step()
+ if args.rank == 0:
+ state = dict(
+ epoch=epoch + 1, best_acc=best_acc, model=model.state_dict(),
+ optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict())
+ torch.save(state, args.checkpoint_dir / 'checkpoint.pth')
+ wandb.finish()
+
+
+def handle_sigusr1(signum, frame):
+ os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
+ exit()
+
+
+def handle_sigterm(signum, frame):
+ pass
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self, name, fmt=':f'):
+ self.name = name
+ self.fmt = fmt
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ with torch.no_grad():
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
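
To make the `accuracy` helper concrete, a tiny worked example (illustration only; importing it from `evaluate_imagenet` is safe since the module-level code only builds the argument parser):

```python
import torch
from evaluate_imagenet import accuracy

logits = torch.tensor([[0.10, 0.90, 0.00],   # top-1 prediction: class 1 (correct)
                       [0.80, 0.05, 0.15]])  # top-1 prediction: class 0 (wrong)
target = torch.tensor([1, 2])
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 100.0 -- class 2 is in sample 2's top-2
```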
diff --git a/evaluate_transfer.py b/evaluate_transfer.py
new file mode 100644
index 0000000000000000000000000000000000000000..18146e526027f4e521499ecac981b06f4e6d2350
--- /dev/null
+++ b/evaluate_transfer.py
@@ -0,0 +1,168 @@
+import argparse
+
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from thop import profile, clever_format
+from torch.utils.data import DataLoader
+from transfer_datasets import TRANSFER_DATASET
+import torchvision.transforms as transforms
+from data_statistics import get_data_mean_and_stdev, get_data_nclass
+from tqdm import tqdm
+
+import utils
+
+import wandb
+
+import torchvision
+
+def load_transform(dataset, size=32):
+ mean, std = get_data_mean_and_stdev(dataset)
+ transform = transforms.Compose([
+ transforms.Resize((size, size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=mean, std=std)])
+ return transform
+
+class Net(nn.Module):
+ def __init__(self, num_class, pretrained_path, dataset, arch):
+ super(Net, self).__init__()
+
+ if arch=='resnet18':
+ embedding_size = 512
+ elif arch=='resnet50':
+ embedding_size = 2048
+ else:
+ raise NotImplementedError
+
+ # encoder
+ from model import Model
+ self.f = Model(dataset=dataset, arch=arch).f
+ # classifier
+ self.fc = nn.Linear(embedding_size, num_class, bias=True)
+ self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
+
+ def forward(self, x):
+ x = self.f(x)
+ feature = torch.flatten(x, start_dim=1)
+ out = self.fc(feature)
+ return out
+
+# train or test for one epoch
+def train_val(net, data_loader, train_optimizer):
+ is_train = train_optimizer is not None
+ net.train() if is_train else net.eval()
+
+ total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
+ with (torch.enable_grad() if is_train else torch.no_grad()):
+ for data, target in data_bar:
+ data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
+ out = net(data)
+ loss = loss_criterion(out, target)
+
+ if is_train:
+ train_optimizer.zero_grad()
+ loss.backward()
+ train_optimizer.step()
+
+ total_num += data.size(0)
+ total_loss += loss.item() * data.size(0)
+ prediction = torch.argsort(out, dim=-1, descending=True)
+ total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+ total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+
+ data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}% model: {}'
+ .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num,
+ total_correct_1 / total_num * 100, total_correct_5 / total_num * 100,
+ model_path.split('/')[-1]))
+
+ return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Linear Evaluation')
+ parser.add_argument('--dataset', default='cifar10', type=str, help='Pre-trained dataset.', choices=['cifar10', 'cifar100', 'stl10', 'tiny_imagenet'])
+ parser.add_argument('--transfer_dataset', default='cifar10', type=str, help='Transfer dataset (i.e., testing dataset)', choices=['cifar10', 'cifar100', 'stl-10', 'aircraft', 'cu_birds', 'dtd', 'fashionmnist', 'mnist', 'traffic_sign', 'vgg_flower'])
+ parser.add_argument('--arch', default='resnet50', type=str, help='Backbone architecture for experiments', choices=['resnet50', 'resnet18'])
+ parser.add_argument('--model_path', type=str, default='results/Barlow_Twins/0.005_64_128_model.pth',
+ help='The base string of the pretrained model path')
+ parser.add_argument('--batch_size', type=int, default=128, help='Number of images in each mini-batch')
+ parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train')
+ parser.add_argument('--screen', type=str, help='screen session id')
+ # wandb related args
+ parser.add_argument('--wandb_group', type=str, help='group for wandb')
+
+ args = parser.parse_args()
+
+ wandb.init(project=f"Barlow-Twins-MixUp-TransferLearn-[{args.dataset}-to-X]-{args.arch}", config=args, dir='/data/wbandar1/projects/ssl-aug-artifacts/wandb_logs/', group=args.wandb_group, name=f'{args.transfer_dataset}')
+ run_id = wandb.run.id
+
+ model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs
+ dataset = args.dataset
+ transfer_dataset = args.transfer_dataset
+
+ if dataset in ['cifar10', 'cifar100']:
+ print("reshaping data into 32x32")
+ resize = 32
+ else:
+ print("reshaping data into 64x64")
+ resize = 64
+
+ train_data = TRANSFER_DATASET[args.transfer_dataset](train=True, image_transforms=load_transform(args.transfer_dataset, resize))
+ test_data = TRANSFER_DATASET[args.transfer_dataset](train=False, image_transforms=load_transform(args.transfer_dataset, resize))
+
+ train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
+ test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
+
+ model = Net(num_class=get_data_nclass(args.transfer_dataset), pretrained_path=model_path, dataset=dataset, arch=args.arch).cuda()
+ for param in model.f.parameters():
+ param.requires_grad = False
+
+    # optimizer with lr scheduler
+ # lr_start, lr_end = 1e-2, 1e-6
+ # gamma = (lr_end / lr_start) ** (1 / epochs)
+ # optimizer = optim.Adam(model.fc.parameters(), lr=lr_start, weight_decay=5e-6)
+ # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
+
+    # adopted from
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 80], gamma=0.1)
+
+    # optimizer with no scheduler
+ # optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)
+
+ loss_criterion = nn.CrossEntropyLoss()
+ results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [],
+ 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []}
+
+ save_name = model_path.split('.pth')[0] + '_linear.csv'
+
+ best_acc = 0.0
+ for epoch in range(1, epochs + 1):
+ train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer)
+ results['train_loss'].append(train_loss)
+ results['train_acc@1'].append(train_acc_1)
+ results['train_acc@5'].append(train_acc_5)
+ test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None)
+ results['test_loss'].append(test_loss)
+ results['test_acc@1'].append(test_acc_1)
+ results['test_acc@5'].append(test_acc_5)
+ # save statistics
+ # data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
+ # data_frame.to_csv(save_name, index_label='epoch')
+ if test_acc_1 > best_acc:
+ best_acc = test_acc_1
+ wandb.log(
+ {
+ "train_loss": train_loss,
+ "train_acc@1": train_acc_1,
+ "train_acc@5": train_acc_5,
+ "test_loss": test_loss,
+ "test_acc@1": test_acc_1,
+ "test_acc@5": test_acc_5,
+ "best_acc": best_acc
+ }
+ )
+ scheduler.step()
+ wandb.finish()
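
For reference, a standalone sketch (with a dummy parameter) of the step schedule the `MultiStepLR` above produces over the default 100 epochs:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.01, momentum=0.9)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, [60, 80], gamma=0.1)
for epoch in range(1, 101):
    # ... train one epoch ...
    sched.step()
    if epoch in (59, 60, 80):
        print(epoch, opt.param_groups[0]['lr'])  # 0.01 -> 0.001 -> 0.0001
```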
diff --git a/figs/in-linear.png b/figs/in-linear.png
new file mode 100644
index 0000000000000000000000000000000000000000..609d71b77872f85a995b77109e4cd73ec89e1eb7
Binary files /dev/null and b/figs/in-linear.png differ
diff --git a/figs/in-loss-bt.png b/figs/in-loss-bt.png
new file mode 100644
index 0000000000000000000000000000000000000000..3c69f3023d39f453ffc78602ae9dc837a7af557e
Binary files /dev/null and b/figs/in-loss-bt.png differ
diff --git a/figs/in-loss-reg.png b/figs/in-loss-reg.png
new file mode 100644
index 0000000000000000000000000000000000000000..78df7c38cb303634f9ed55deb8d61f650026aa3d
--- /dev/null
+++ b/figs/in-loss-reg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab2e3e99017cd134a3f49878929bce151abcfa917cb8ceca436e401e2caeed4e
+size 1268898
diff --git a/figs/mix-bt.jpg b/figs/mix-bt.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8cae7967077e5182c809ec9333e50fe939dd8848
Binary files /dev/null and b/figs/mix-bt.jpg differ
diff --git a/figs/mix-bt.svg b/figs/mix-bt.svg
new file mode 100644
index 0000000000000000000000000000000000000000..7efd4e5e5a7b08969d0c28db44083d9d3dd288f7
--- /dev/null
+++ b/figs/mix-bt.svg
@@ -0,0 +1,3 @@
+
+
+
\ No newline at end of file
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d45cbcd336cecc7775fd9f2ec0acbd11b3024d9
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torchvision.models.resnet import resnet50 as _resnet50
+
+dependencies = ['torch', 'torchvision']
+
+
+def resnet50(pretrained=True, **kwargs):
+ model = _resnet50(pretrained=False, **kwargs)
+ if pretrained:
+ url = 'https://dl.fbaipublicfiles.com/barlowtwins/ep1000_bs2048_lrw0.2_lrb0.0048_lambd0.0051/resnet50.pth'
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
+ model.load_state_dict(state_dict, strict=False)
+ return model
\ No newline at end of file
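
This `hubconf.py` mirrors the upstream Barlow Twins one, so the model can be pulled through `torch.hub`; a sketch against the official `facebookresearch/barlowtwins` entry point (substitute this repo's GitHub path to use the hubconf above):

```python
import torch

model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # ImageNet-style input
```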
diff --git a/linear.py b/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..228110d04104a69629ec3df3c008d55479bb77cb
--- /dev/null
+++ b/linear.py
@@ -0,0 +1,166 @@
+import argparse
+
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from thop import profile, clever_format
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR10, CIFAR100
+from tqdm import tqdm
+
+import utils
+
+import wandb
+
+import torchvision
+
+class Net(nn.Module):
+ def __init__(self, num_class, pretrained_path, dataset, arch):
+ super(Net, self).__init__()
+
+ if arch=='resnet18':
+ embedding_size = 512
+ elif arch=='resnet50':
+ embedding_size = 2048
+ else:
+ raise NotImplementedError
+
+ # encoder
+ from model import Model
+ self.f = Model(dataset=dataset, arch=arch).f
+ # classifier
+ self.fc = nn.Linear(embedding_size, num_class, bias=True)
+ self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
+
+ def forward(self, x):
+ x = self.f(x)
+ feature = torch.flatten(x, start_dim=1)
+ out = self.fc(feature)
+ return out
+
+# train or test for one epoch
+def train_val(net, data_loader, train_optimizer):
+ is_train = train_optimizer is not None
+ net.train() if is_train else net.eval()
+
+ total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
+ with (torch.enable_grad() if is_train else torch.no_grad()):
+ for data, target in data_bar:
+ data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
+ out = net(data)
+ loss = loss_criterion(out, target)
+
+ if is_train:
+ train_optimizer.zero_grad()
+ loss.backward()
+ train_optimizer.step()
+
+ total_num += data.size(0)
+ total_loss += loss.item() * data.size(0)
+ prediction = torch.argsort(out, dim=-1, descending=True)
+ total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+ total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+
+ data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}% model: {}'
+ .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num,
+ total_correct_1 / total_num * 100, total_correct_5 / total_num * 100,
+ model_path.split('/')[-1]))
+
+ return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Linear Evaluation')
+    parser.add_argument('--dataset', default='cifar10', type=str, help='Dataset: cifar10, cifar100, stl10, or tiny_imagenet')
+ parser.add_argument('--arch', default='resnet50', type=str, help='Backbone architecture for experiments', choices=['resnet50', 'resnet18'])
+ parser.add_argument('--model_path', type=str, default='results/Barlow_Twins/0.005_64_128_model.pth',
+ help='The base string of the pretrained model path')
+ parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch')
+ parser.add_argument('--epochs', type=int, default=200, help='Number of sweeps over the dataset to train')
+
+ args = parser.parse_args()
+
+ wandb.init(project=f"Barlow-Twins-MixUp-Linear-{args.dataset}-{args.arch}", config=args, dir='/data/wbandar1/projects/ssl-aug-artifacts/wandb_logs/')
+ run_id = wandb.run.id
+
+ model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs
+ dataset = args.dataset
+ if dataset == 'cifar10':
+ train_data = CIFAR10(root='data', train=True,\
+ transform=utils.CifarPairTransform(train_transform = True, pair_transform=False), download=True)
+ test_data = CIFAR10(root='data', train=False,\
+ transform=utils.CifarPairTransform(train_transform = False, pair_transform=False), download=True)
+    elif dataset == 'cifar100':
+ train_data = CIFAR100(root='data', train=True,\
+ transform=utils.CifarPairTransform(train_transform = True, pair_transform=False), download=True)
+ test_data = CIFAR100(root='data', train=False,\
+ transform=utils.CifarPairTransform(train_transform = False, pair_transform=False), download=True)
+ elif dataset == 'stl10':
+ train_data = torchvision.datasets.STL10(root='data', split="train", \
+ transform=utils.StlPairTransform(train_transform = True, pair_transform=False), download=True)
+ test_data = torchvision.datasets.STL10(root='data', split="test", \
+ transform=utils.StlPairTransform(train_transform = False, pair_transform=False), download=True)
+ elif dataset == 'tiny_imagenet':
+ train_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/train', \
+ utils.TinyImageNetPairTransform(train_transform=True, pair_transform=False))
+ test_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/val', \
+ utils.TinyImageNetPairTransform(train_transform = False, pair_transform=False))
+
+ train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
+ test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
+
+ model = Net(num_class=len(train_data.classes), pretrained_path=model_path, dataset=dataset, arch=args.arch).cuda()
+ for param in model.f.parameters():
+ param.requires_grad = False
+
+ if dataset == 'cifar10' or dataset == 'cifar100':
+ flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
+ elif dataset == 'tiny_imagenet' or dataset == 'stl10':
+ flops, params = profile(model, inputs=(torch.randn(1, 3, 64, 64).cuda(),))
+ flops, params = clever_format([flops, params])
+ print('# Model Params: {} FLOPs: {}'.format(params, flops))
+
+    # optimizer with lr scheduler
+ lr_start, lr_end = 1e-2, 1e-6
+ gamma = (lr_end / lr_start) ** (1 / epochs)
+ optimizer = optim.Adam(model.fc.parameters(), lr=lr_start, weight_decay=5e-6)
+ scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
+
+    # optimizer with no scheduler
+ # optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)
+
+ loss_criterion = nn.CrossEntropyLoss()
+ results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [],
+ 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []}
+
+ save_name = model_path.split('.pth')[0] + '_linear.csv'
+
+ best_acc = 0.0
+ for epoch in range(1, epochs + 1):
+ train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer)
+ scheduler.step()
+ results['train_loss'].append(train_loss)
+ results['train_acc@1'].append(train_acc_1)
+ results['train_acc@5'].append(train_acc_5)
+ test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None)
+ results['test_loss'].append(test_loss)
+ results['test_acc@1'].append(test_acc_1)
+ results['test_acc@5'].append(test_acc_5)
+ # save statistics
+ # data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
+ # data_frame.to_csv(save_name, index_label='epoch')
+ #if test_acc_1 > best_acc:
+ # best_acc = test_acc_1
+ # torch.save(model.state_dict(), 'results/linear_model.pth')
+ wandb.log(
+ {
+ "train_loss": train_loss,
+ "train_acc@1": train_acc_1,
+ "train_acc@5": train_acc_5,
+ "test_loss": test_loss,
+ "test_acc@1": test_acc_1,
+ "test_acc@5": test_acc_5
+ }
+ )
+ wandb.finish()
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..695920d9951ba5dd15bd905953a3f97299551150
--- /dev/null
+++ b/main.py
@@ -0,0 +1,271 @@
+import argparse
+import os
+
+import pandas as pd
+import torch
+import numpy as np
+import torch.optim as optim
+import torch.nn.functional as F
+from thop import profile, clever_format
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts
+from tqdm import tqdm
+
+import utils
+from model import Model
+import math
+
+import torchvision
+
+import wandb
+
+if torch.cuda.is_available():
+ torch.backends.cudnn.benchmark = True
+
+def off_diagonal(x):
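+    # return a flattened view of the off-diagonal elements of a square matrix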
+ n, m = x.shape
+ assert n == m
+ return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+
+def adjust_learning_rate(args, optimizer, loader, step):
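+    # scale the base lr with the batch size, warm up linearly for 10 epochs,
+    # then follow a cosine decay down to 0.1% of the base lr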
+ max_steps = args.epochs * len(loader)
+ warmup_steps = 10 * len(loader)
+ base_lr = args.batch_size / 256
+ if step < warmup_steps:
+ lr = base_lr * step / warmup_steps
+ else:
+ step -= warmup_steps
+ max_steps -= warmup_steps
+ q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+ end_lr = base_lr * 0.001
+ lr = base_lr * q + end_lr * (1 - q)
+ optimizer.param_groups[0]['lr'] = lr * args.lr
+
+def train(args, epoch, net, data_loader, train_optimizer):
+ net.train()
+ total_loss, total_loss_bt, total_loss_mix, total_num, train_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
+ for step, data_tuple in enumerate(train_bar, start=epoch * len(train_bar)):
+ if args.lr_shed == "cosine":
+ adjust_learning_rate(args, train_optimizer, data_loader, step)
+ (pos_1, pos_2), _ = data_tuple
+ pos_1, pos_2 = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True)
+ _, out_1 = net(pos_1)
+ _, out_2 = net(pos_2)
+
+ out_1_norm = (out_1 - out_1.mean(dim=0)) / out_1.std(dim=0)
+ out_2_norm = (out_2 - out_2.mean(dim=0)) / out_2.std(dim=0)
+ c = torch.matmul(out_1_norm.T, out_2_norm) / batch_size
+ on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
+ off_diag = off_diagonal(c).pow_(2).sum()
+ loss_bt = on_diag + lmbda * off_diag
+
+ ## MixUp (Our Contribution) ##
+ if args.is_mixup.lower() == 'true':
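+            # mix view 1 with a shuffled view 2 using alpha ~ Beta(1, 1); the
+            # cross-correlation between the mixed embedding and each view should
+            # then match the alpha-weighted combination of the unmixed ones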
+ index = torch.randperm(batch_size).cuda(non_blocking=True)
+ alpha = np.random.beta(1.0, 1.0)
+ pos_m = alpha * pos_1 + (1 - alpha) * pos_2[index, :]
+
+ _, out_m = net(pos_m)
+ out_m_norm = (out_m - out_m.mean(dim=0)) / out_m.std(dim=0)
+
+ cc_m_1 = torch.matmul(out_m_norm.T, out_1_norm) / batch_size
+ cc_m_1_gt = alpha*torch.matmul(out_1_norm.T, out_1_norm) / batch_size + \
+ (1-alpha)*torch.matmul(out_2_norm[index,:].T, out_1_norm) / batch_size
+
+ cc_m_2 = torch.matmul(out_m_norm.T, out_2_norm) / batch_size
+ cc_m_2_gt = alpha*torch.matmul(out_1_norm.T, out_2_norm) / batch_size + \
+ (1-alpha)*torch.matmul(out_2_norm[index,:].T, out_2_norm) / batch_size
+
+ loss_mix = args.mixup_loss_scale*lmbda*((cc_m_1-cc_m_1_gt).pow_(2).sum() + (cc_m_2-cc_m_2_gt).pow_(2).sum())
+ else:
+ loss_mix = torch.zeros(1).cuda()
+ ## MixUp (Our Contribution) ##
+
+ loss = loss_bt + loss_mix
+ train_optimizer.zero_grad()
+ loss.backward()
+ train_optimizer.step()
+
+ total_num += batch_size
+ total_loss += loss.item() * batch_size
+ total_loss_bt += loss_bt.item() * batch_size
+ total_loss_mix += loss_mix.item() * batch_size
+
+ train_bar.set_description('Train Epoch: [{}/{}] lr: {:.3f}x10-3 Loss: {:.4f} lmbda:{:.4f} bsz:{} f_dim:{} dataset: {}'.format(\
+ epoch, epochs, train_optimizer.param_groups[0]['lr'] * 1000, total_loss / total_num, lmbda, batch_size, feature_dim, dataset))
+ return total_loss_bt / total_num, total_loss_mix / total_num, total_loss / total_num
+
+
+def test(net, memory_data_loader, test_data_loader):
+ net.eval()
+ total_top1, total_top5, total_num, feature_bank, target_bank = 0.0, 0.0, 0, [], []
+ with torch.no_grad():
+ # generate feature bank and target bank
+ for data_tuple in tqdm(memory_data_loader, desc='Feature extracting'):
+ (data, _), target = data_tuple
+ target_bank.append(target)
+ feature, out = net(data.cuda(non_blocking=True))
+ feature_bank.append(feature)
+ # [D, N]
+ feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
+ # [N]
+ feature_labels = torch.cat(target_bank, dim=0).contiguous().to(feature_bank.device)
+ # loop test data to predict the label by weighted knn search
+ test_bar = tqdm(test_data_loader)
+ for data_tuple in test_bar:
+ (data, _), target = data_tuple
+ data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
+ feature, out = net(data)
+
+ total_num += data.size(0)
+ # compute cos similarity between each feature vector and feature bank ---> [B, N]
+ sim_matrix = torch.mm(feature, feature_bank)
+ # [B, K]
+ sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
+ # [B, K]
+ sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices)
+ sim_weight = (sim_weight / temperature).exp()
+
+ # counts for each class
+ one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device)
+ # [B*K, C]
+ one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
+ # weighted score ---> [B, C]
+ pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)
+
+ pred_labels = pred_scores.argsort(dim=-1, descending=True)
+ total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+ total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+ test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
+ .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100))
+ return total_top1 / total_num * 100, total_top5 / total_num * 100
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Training Barlow Twins')
+ parser.add_argument('--dataset', default='cifar10', type=str, help='Dataset: cifar10, cifar100, tiny_imagenet, stl10', choices=['cifar10', 'cifar100', 'tiny_imagenet', 'stl10'])
+ parser.add_argument('--arch', default='resnet50', type=str, help='Backbone architecture', choices=['resnet50', 'resnet18'])
+ parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for embedding vector')
+ parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax (kNN evaluation)')
+ parser.add_argument('--k', default=200, type=int, help='Top k most similar images used to predict the label')
+ parser.add_argument('--batch_size', default=512, type=int, help='Number of images in each mini-batch')
+ parser.add_argument('--epochs', default=1000, type=int, help='Number of sweeps over the dataset to train')
+ parser.add_argument('--lr', default=1e-3, type=float, help='Base learning rate')
+ parser.add_argument('--lr_shed', default="step", choices=["step", "cosine"], type=str, help='Learning rate scheduler: step / cosine')
+
+ # for barlow twins
+ parser.add_argument('--lmbda', default=0.005, type=float, help='Lambda that controls the on- and off-diagonal terms')
+ parser.add_argument('--corr_neg_one', dest='corr_neg_one', action='store_true')
+ parser.add_argument('--corr_zero', dest='corr_neg_one', action='store_false')
+ parser.set_defaults(corr_neg_one=False)
+
+ # for mixup
+ parser.add_argument('--is_mixup', dest='is_mixup', type=str, default='false', choices=['true', 'false'])
+ parser.add_argument('--mixup_loss_scale', dest='mixup_loss_scale', type=float, default=5.0)
+
+ # GPU id (just for record)
+ parser.add_argument('--gpu', dest='gpu', type=int, default=0)
+
+ args = parser.parse_args()
+ is_mixup = args.is_mixup.lower() == 'true'
+
+ wandb.init(project=f"Barlow-Twins-MixUp-{args.dataset}-{args.arch}", config=args, dir='results/wandb_logs/')
+ run_id = wandb.run.id
+ dataset = args.dataset
+ feature_dim, temperature, k = args.feature_dim, args.temperature, args.k
+ batch_size, epochs = args.batch_size, args.epochs
+ lmbda = args.lmbda
+ corr_neg_one = args.corr_neg_one
+
+ if dataset == 'cifar10':
+ train_data = torchvision.datasets.CIFAR10(root='/data/wbandar1/datasets', train=True, \
+ transform=utils.CifarPairTransform(train_transform = True), download=True)
+ memory_data = torchvision.datasets.CIFAR10(root='/data/wbandar1/datasets', train=True, \
+ transform=utils.CifarPairTransform(train_transform = False), download=True)
+ test_data = torchvision.datasets.CIFAR10(root='/data/wbandar1/datasets', train=False, \
+ transform=utils.CifarPairTransform(train_transform = False), download=True)
+ elif dataset == 'cifar100':
+ train_data = torchvision.datasets.CIFAR100(root='/data/wbandar1/datasets', train=True, \
+ transform=utils.CifarPairTransform(train_transform = True), download=True)
+ memory_data = torchvision.datasets.CIFAR100(root='/data/wbandar1/datasets', train=True, \
+ transform=utils.CifarPairTransform(train_transform = False), download=True)
+ test_data = torchvision.datasets.CIFAR100(root='/data/wbandar1/datasets', train=False, \
+ transform=utils.CifarPairTransform(train_transform = False), download=True)
+ elif dataset == 'stl10':
+ train_data = torchvision.datasets.STL10(root='/data/wbandar1/datasets', split="train+unlabeled", \
+ transform=utils.StlPairTransform(train_transform = True), download=True)
+ memory_data = torchvision.datasets.STL10(root='/data/wbandar1/datasets', split="train", \
+ transform=utils.StlPairTransform(train_transform = False), download=True)
+ test_data = torchvision.datasets.STL10(root='/data/wbandar1/datasets', split="test", \
+ transform=utils.StlPairTransform(train_transform = False), download=True)
+ elif dataset == 'tiny_imagenet':
+        # the dataset must be downloaded and preprocessed beforehand
+        if not os.path.isdir('/data/wbandar1/datasets/tiny-imagenet-200'):
+            raise ValueError("First preprocess the Tiny ImageNet dataset (see preprocess_datasets/preprocess_tinyimagenet.sh)...")
+
+ train_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/train', \
+ utils.TinyImageNetPairTransform(train_transform = True))
+ memory_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/train', \
+ utils.TinyImageNetPairTransform(train_transform = False))
+ test_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/val', \
+ utils.TinyImageNetPairTransform(train_transform = False))
+
+ train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True,
+ drop_last=True)
+ memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
+ test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
+
+ # model setup and optimizer config
+ model = Model(feature_dim, dataset, args.arch).cuda()
+ if dataset == 'cifar10' or dataset == 'cifar100':
+ flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
+ elif dataset == 'tiny_imagenet' or dataset == 'stl10':
+ flops, params = profile(model, inputs=(torch.randn(1, 3, 64, 64).cuda(),))
+ flops, params = clever_format([flops, params])
+ print('# Model Params: {} FLOPs: {}'.format(params, flops))
+
+ optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
+ if args.lr_shed == "step":
+ m = [args.epochs - a for a in [50, 25]]
+ scheduler = MultiStepLR(optimizer, milestones=m, gamma=0.2)
+ c = len(memory_data.classes)
+
+ results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []}
+ save_name_pre = '{}_{}_{}_{}_{}'.format(run_id, lmbda, feature_dim, batch_size, dataset)
+ run_id_dir = os.path.join('results/', run_id)
+ if not os.path.exists(run_id_dir):
+ print('Creating directory {}'.format(run_id_dir))
+ os.mkdir(run_id_dir)
+
+ best_acc = 0.0
+ for epoch in range(1, epochs + 1):
+ loss_bt, loss_mix, train_loss = train(args, epoch, model, train_loader, optimizer)
+ if args.lr_shed == "step":
+ scheduler.step()
+ wandb.log(
+ {
+ "epoch": epoch,
+ "lr": optimizer.param_groups[0]['lr'],
+ "loss_bt": loss_bt,
+ "loss_mix": loss_mix,
+ "train_loss": train_loss}
+ )
+ if epoch % 5 == 0:
+ test_acc_1, test_acc_5 = test(model, memory_loader, test_loader)
+
+ results['train_loss'].append(train_loss)
+ results['test_acc@1'].append(test_acc_1)
+ results['test_acc@5'].append(test_acc_5)
+ data_frame = pd.DataFrame(data=results, index=range(5, epoch + 1, 5))
+ data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
+ wandb.log(
+ {
+ "test_acc@1": test_acc_1,
+ "test_acc@5": test_acc_5
+ }
+ )
+ if test_acc_1 > best_acc:
+ best_acc = test_acc_1
+ torch.save(model.state_dict(), 'results/{}/{}_model.pth'.format(run_id, save_name_pre))
+ if epoch % 50 == 0:
+ torch.save(model.state_dict(), 'results/{}/{}_model_{}.pth'.format(run_id, save_name_pre, epoch))
+ wandb.finish()
diff --git a/main_imagenet.py b/main_imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..085545c612bdc7ab4341215c5f769a6c0427b3b1
--- /dev/null
+++ b/main_imagenet.py
@@ -0,0 +1,463 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import argparse
+import json
+import math
+import os
+import random
+import signal
+import subprocess
+import sys
+import time
+import numpy as np
+import wandb
+
+from PIL import Image, ImageOps, ImageFilter
+from torch import nn, optim
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+parser = argparse.ArgumentParser(description='Barlow Twins Training')
+parser.add_argument('data', type=Path, metavar='DIR',
+ help='path to dataset')
+parser.add_argument('--workers', default=8, type=int, metavar='N',
+ help='number of data loader workers')
+parser.add_argument('--epochs', default=300, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--batch-size', default=512, type=int, metavar='N',
+ help='mini-batch size')
+parser.add_argument('--learning-rate-weights', default=0.2, type=float, metavar='LR',
+ help='base learning rate for weights')
+parser.add_argument('--learning-rate-biases', default=0.0048, type=float, metavar='LR',
+ help='base learning rate for biases and batch norm parameters')
+parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
+ help='weight decay')
+parser.add_argument('--lambd', default=0.0051, type=float, metavar='L',
+ help='weight on off-diagonal terms')
+parser.add_argument('--projector', default='8192-8192-8192', type=str,
+ metavar='MLP', help='projector MLP')
+parser.add_argument('--print-freq', default=1, type=int, metavar='N',
+ help='print frequency')
+parser.add_argument('--checkpoint-dir', default='/mnt/store/wbandar1/projects/ssl-aug-artifacts/', type=Path,
+ metavar='DIR', help='path to checkpoint directory')
+parser.add_argument('--is_mixup', default='false', type=str,
+ metavar='L', help='mixup regularization', choices=['true', 'false'])
+parser.add_argument('--lambda_mixup', default=0.1, type=float, metavar='L',
+ help='Hyperparamter for the regularization loss')
+
+def main():
+ args = parser.parse_args()
+ args.is_mixup = args.is_mixup.lower() == 'true'
+ args.ngpus_per_node = torch.cuda.device_count()
+
+ run = wandb.init(project="Barlow-Twins-MixUp-ImageNet", config=args, dir='/mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs/')
+ run_id = wandb.run.id
+ args.checkpoint_dir=Path(os.path.join(args.checkpoint_dir, run_id))
+
+ if 'SLURM_JOB_ID' in os.environ:
+ # single-node and multi-node distributed training on SLURM cluster
+ # requeue job on SLURM preemption
+ signal.signal(signal.SIGUSR1, handle_sigusr1)
+ signal.signal(signal.SIGTERM, handle_sigterm)
+ # find a common host name on all nodes
+ # assume scontrol returns hosts in the same order on all nodes
+ cmd = 'scontrol show hostnames ' + os.getenv('SLURM_JOB_NODELIST')
+ stdout = subprocess.check_output(cmd.split())
+ host_name = stdout.decode().splitlines()[0]
+ args.rank = int(os.getenv('SLURM_NODEID')) * args.ngpus_per_node
+ args.world_size = int(os.getenv('SLURM_NNODES')) * args.ngpus_per_node
+ args.dist_url = f'tcp://{host_name}:58472'
+ else:
+ # single-node distributed training
+ args.rank = 0
+ args.dist_url = 'tcp://localhost:58472'
+ args.world_size = args.ngpus_per_node
+ torch.multiprocessing.spawn(main_worker, (args,run,), args.ngpus_per_node)
+ wandb.finish()
+
+
+def main_worker(gpu, args, run):
+ args.rank += gpu
+ torch.distributed.init_process_group(
+ backend='nccl', init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+
+ if args.rank == 0:
+ args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
+ print(' '.join(sys.argv))
+ print(' '.join(sys.argv), file=stats_file)
+
+ torch.cuda.set_device(gpu)
+ torch.backends.cudnn.benchmark = True
+
+ model = BarlowTwins(args).cuda(gpu)
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ param_weights = []
+ param_biases = []
+ for param in model.parameters():
+ if param.ndim == 1:
+ param_biases.append(param)
+ else:
+ param_weights.append(param)
+ parameters = [{'params': param_weights}, {'params': param_biases}]
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
+ optimizer = LARS(parameters, lr=0, weight_decay=args.weight_decay,
+ weight_decay_filter=True,
+ lars_adaptation_filter=True)
+
+ # automatically resume from checkpoint if it exists
+ if (args.checkpoint_dir / 'checkpoint.pth').is_file():
+ ckpt = torch.load(args.checkpoint_dir / 'checkpoint.pth',
+ map_location='cpu')
+ start_epoch = ckpt['epoch']
+ model.load_state_dict(ckpt['model'])
+ optimizer.load_state_dict(ckpt['optimizer'])
+ else:
+ start_epoch = 0
+
+ dataset = torchvision.datasets.ImageFolder(args.data / 'train', Transform())
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+ assert args.batch_size % args.world_size == 0
+ per_device_batch_size = args.batch_size // args.world_size
+ loader = torch.utils.data.DataLoader(
+ dataset, batch_size=per_device_batch_size, num_workers=args.workers,
+ pin_memory=True, sampler=sampler)
+
+ start_time = time.time()
+ scaler = torch.cuda.amp.GradScaler(growth_interval=100, enabled=True)
+ for epoch in range(start_epoch, args.epochs):
+ sampler.set_epoch(epoch)
+ for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)):
+ y1 = y1.cuda(gpu, non_blocking=True)
+ y2 = y2.cuda(gpu, non_blocking=True)
+ adjust_learning_rate(args, optimizer, loader, step)
+ mixup_loss_scale = adjust_mixup_scale(loader, step, args.lambda_mixup)
+ optimizer.zero_grad()
+ with torch.cuda.amp.autocast(enabled=True):
+ loss_bt, loss_reg = model(y1, y2, args.is_mixup)
+ loss_regs = mixup_loss_scale * loss_reg
+ loss = loss_bt + loss_regs
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ if step % args.print_freq == 0:
+ if args.rank == 0:
+ stats = dict(epoch=epoch, step=step,
+ lr_weights=optimizer.param_groups[0]['lr'],
+ lr_biases=optimizer.param_groups[1]['lr'],
+ loss=loss.item(),
+ time=int(time.time() - start_time))
+ print(json.dumps(stats))
+ print(json.dumps(stats), file=stats_file)
+ if args.is_mixup:
+ run.log(
+ {
+ "epoch": epoch,
+ "step": step,
+ "lr_weights": optimizer.param_groups[0]['lr'],
+ "lr_biases": optimizer.param_groups[1]['lr'],
+ "loss": loss.item(),
+ "loss_bt": loss_bt.item(),
+ "loss_reg(unscaled)": loss_reg.item(),
+ "reg_scale": mixup_loss_scale,
+ "loss_reg(scaled)": loss_regs.item(),
+ "time": int(time.time() - start_time)}
+ )
+ else:
+ run.log(
+ {
+ "epoch": epoch,
+ "step": step,
+ "lr_weights": optimizer.param_groups[0]['lr'],
+ "lr_biases": optimizer.param_groups[1]['lr'],
+ "loss": loss.item(),
+ "loss_bt": loss.item(),
+ "loss_reg(unscaled)": 0.,
+ "reg_scale": 0.,
+ "loss_reg(scaled)": 0.,
+ "time": int(time.time() - start_time)}
+ )
+ if args.rank == 0:
+ # save checkpoint
+ state = dict(epoch=epoch + 1, model=model.state_dict(),
+ optimizer=optimizer.state_dict())
+ torch.save(state, args.checkpoint_dir / 'checkpoint.pth')
+ if args.rank == 0:
+ # save final model
+ print("Saving final model ...")
+ torch.save(model.module.backbone.state_dict(),
+ args.checkpoint_dir / 'resnet50.pth')
+ print("Finished saving final model ...")
+
+
+def adjust_learning_rate(args, optimizer, loader, step):
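+    # linear warmup for 10 epochs, then cosine decay to 0.1% of the
+    # batch-size-scaled base lr; weights and biases use separate scales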
+ max_steps = args.epochs * len(loader)
+ warmup_steps = 10 * len(loader)
+ base_lr = args.batch_size / 256
+ if step < warmup_steps:
+ lr = base_lr * step / warmup_steps
+ else:
+ step -= warmup_steps
+ max_steps -= warmup_steps
+ q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+ end_lr = base_lr * 0.001
+ lr = base_lr * q + end_lr * (1 - q)
+ optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
+ optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
+
+def adjust_mixup_scale(loader, step, lambda_mixup):
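+    # linearly warm up the mixup regularization weight over the first 10 epochs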
+ warmup_steps = 10 * len(loader)
+ if step < warmup_steps:
+ return lambda_mixup * step / warmup_steps
+ else:
+ return lambda_mixup
+
+def handle_sigusr1(signum, frame):
+ os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
+ exit()
+
+
+def handle_sigterm(signum, frame):
+ pass
+
+
+def off_diagonal(x):
+ # return a flattened view of the off-diagonal elements of a square matrix
+ n, m = x.shape
+ assert n == m
+ return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+
+
+class BarlowTwins(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ self.backbone = torchvision.models.resnet50(zero_init_residual=True)
+ self.backbone.fc = nn.Identity()
+
+ # projector
+ sizes = [2048] + list(map(int, args.projector.split('-')))
+ layers = []
+ for i in range(len(sizes) - 2):
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
+ layers.append(nn.BatchNorm1d(sizes[i + 1]))
+ layers.append(nn.ReLU(inplace=True))
+ layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
+ self.projector = nn.Sequential(*layers)
+
+ # normalization layer for the representations z1 and z2
+ # self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
+
+ # def forward(self, y1, y2):
+ # z1 = self.projector(self.backbone(y1))
+ # z2 = self.projector(self.backbone(y2))
+
+ # # empirical cross-correlation matrix
+ # c = self.bn(z1).T @ self.bn(z2)
+
+ # # sum the cross-correlation matrix between all gpus
+ # c.div_(self.args.batch_size)
+ # torch.distributed.all_reduce(c)
+
+ # on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
+ # off_diag = off_diagonal(c).pow_(2).sum()
+ # loss = on_diag + self.args.lambd * off_diag
+ # return loss
+
+ def forward(self, y1, y2, is_mixup):
+ batch_size = y1.shape[0]
+
+ ### original barlow twins ###
+ z1 = self.projector(self.backbone(y1))
+ z2 = self.projector(self.backbone(y2))
+
+        # normalize each embedding dimension across the batch
+ z1 = (z1 - z1.mean(dim=0)) / z1.std(dim=0)
+ z2 = (z2 - z2.mean(dim=0)) / z2.std(dim=0)
+
+ # empirical cross-correlation matrix
+ c = z1.T @ z2
+
+ # sum the cross-correlation matrix between all gpus
+ c.div_(self.args.batch_size)
+ torch.distributed.all_reduce(c)
+
+ on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
+ off_diag = off_diagonal(c).pow_(2).sum()
+ loss = on_diag + self.args.lambd * off_diag
+
+ if is_mixup:
+ ##############################################
+ ### mixup regularization: Implementation 1 ###
+ ##############################################
+
+ # index = torch.randperm(batch_size).cuda(non_blocking=True)
+ # alpha = np.random.beta(1.0, 1.0)
+ # ym = alpha * y1 + (1. - alpha) * y2[index, :]
+ # zm = self.projector(self.backbone(ym))
+
+            # # normalization
+ # zm = (zm - zm.mean(dim=0)) / zm.std(dim=0)
+
+ # # cc
+ # cc_m_1 = zm.T @ z1
+ # cc_m_1.div_(batch_size)
+ # cc_m_1_gt = alpha*(z1.T @ z1) + (1.-alpha)*(z2[index,:].T @ z1)
+ # cc_m_1_gt.div_(batch_size)
+
+ # cc_m_2 = zm.T @ z2
+ # cc_m_2.div_(batch_size)
+ # cc_m_2_gt = alpha*(z2.T @ z2) + (1.-alpha)*(z2[index,:].T @ z2)
+ # cc_m_2_gt.div_(batch_size)
+
+ # # mixup reg. loss
+ # lossm = 0.5*self.args.lambd*((cc_m_1-cc_m_1_gt).pow_(2).sum() + (cc_m_2-cc_m_2_gt).pow_(2).sum())
+
+ ##############################################
+ ### mixup regularization: Implementation 2 ###
+ ##############################################
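+            # same regularizer as Implementation 1, but each cross-correlation matrix
+            # is normalized by the global batch size and summed across GPUs (all_reduce)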
+ index = torch.randperm(batch_size).cuda(non_blocking=True)
+ alpha = np.random.beta(1.0, 1.0)
+ ym = alpha * y1 + (1. - alpha) * y2[index, :]
+ zm = self.projector(self.backbone(ym))
+
+            # normalize each embedding dimension across the batch
+ zm = (zm - zm.mean(dim=0)) / zm.std(dim=0)
+
+ # cc
+ cc_m_1 = zm.T @ z1
+ cc_m_1.div_(self.args.batch_size)
+ cc_m_1_gt = alpha*(z1.T @ z1) + (1.-alpha)*(z2[index,:].T @ z1)
+ cc_m_1_gt.div_(self.args.batch_size)
+
+ cc_m_2 = zm.T @ z2
+ cc_m_2.div_(self.args.batch_size)
+ cc_m_2_gt = alpha*(z2.T @ z2) + (1.-alpha)*(z2[index,:].T @ z2)
+ cc_m_2_gt.div_(self.args.batch_size)
+
+ # gathering all cc
+ torch.distributed.all_reduce(cc_m_1)
+ torch.distributed.all_reduce(cc_m_1_gt)
+ torch.distributed.all_reduce(cc_m_2)
+ torch.distributed.all_reduce(cc_m_2_gt)
+
+ # mixup reg. loss
+ lossm = 0.5*self.args.lambd*((cc_m_1-cc_m_1_gt).pow_(2).sum() + (cc_m_2-cc_m_2_gt).pow_(2).sum())
+ else:
+            lossm = torch.zeros(1, device=y1.device)  # keep on the same device as the BT loss
+ return loss, lossm
+
+class LARS(optim.Optimizer):
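+    # layer-wise adaptive rate scaling (LARS): rescales each parameter's update by
+    # the ratio of the parameter norm to the update norm; biases and norm parameters
+    # (ndim == 1) can be excluded from weight decay and adaptation via the filters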
+ def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
+ weight_decay_filter=False, lars_adaptation_filter=False):
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
+ eta=eta, weight_decay_filter=weight_decay_filter,
+ lars_adaptation_filter=lars_adaptation_filter)
+ super().__init__(params, defaults)
+
+
+ def exclude_bias_and_norm(self, p):
+ return p.ndim == 1
+
+ @torch.no_grad()
+ def step(self):
+ for g in self.param_groups:
+ for p in g['params']:
+ dp = p.grad
+
+ if dp is None:
+ continue
+
+ if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
+ dp = dp.add(p, alpha=g['weight_decay'])
+
+ if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
+ param_norm = torch.norm(p)
+ update_norm = torch.norm(dp)
+ one = torch.ones_like(param_norm)
+ q = torch.where(param_norm > 0.,
+ torch.where(update_norm > 0,
+ (g['eta'] * param_norm / update_norm), one), one)
+ dp = dp.mul(q)
+
+ param_state = self.state[p]
+ if 'mu' not in param_state:
+ param_state['mu'] = torch.zeros_like(p)
+ mu = param_state['mu']
+ mu.mul_(g['momentum']).add_(dp)
+
+ p.add_(mu, alpha=-g['lr'])
+
+
+
+class GaussianBlur(object):
+ def __init__(self, p):
+ self.p = p
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ sigma = random.random() * 1.9 + 0.1
+ return img.filter(ImageFilter.GaussianBlur(sigma))
+ else:
+ return img
+
+
+class Solarization(object):
+ def __init__(self, p):
+ self.p = p
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ return ImageOps.solarize(img)
+ else:
+ return img
+
+
+class Transform:
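+    # two asymmetric augmentation pipelines: view 1 always gets Gaussian blur,
+    # view 2 gets blur rarely (p=0.1) plus occasional solarization (p=0.2)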
+ def __init__(self):
+ self.transform = transforms.Compose([
+ transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.RandomApply(
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4,
+ saturation=0.2, hue=0.1)],
+ p=0.8
+ ),
+ transforms.RandomGrayscale(p=0.2),
+ GaussianBlur(p=1.0),
+ Solarization(p=0.0),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.transform_prime = transforms.Compose([
+ transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.RandomApply(
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4,
+ saturation=0.2, hue=0.1)],
+ p=0.8
+ ),
+ transforms.RandomGrayscale(p=0.2),
+ GaussianBlur(p=0.1),
+ Solarization(p=0.2),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+
+ def __call__(self, x):
+ y1 = self.transform(x)
+ y2 = self.transform_prime(x)
+ return y1, y2
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3cd4165bdce3853b8fe42557a0f22bce8554132
--- /dev/null
+++ b/model.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models.resnet import resnet50, resnet18
+
+
+class Model(nn.Module):
+ def __init__(self, feature_dim=128, dataset='cifar10', arch='resnet50'):
+ super(Model, self).__init__()
+
+ self.f = []
+ if arch == 'resnet18':
+ temp_model = resnet18().named_children()
+ embedding_size = 512
+ elif arch == 'resnet50':
+ temp_model = resnet50().named_children()
+ embedding_size = 2048
+ else:
+ raise NotImplementedError
+
+ for name, module in temp_model:
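+            # adapt the ResNet stem to small inputs: replace the 7x7/stride-2 conv
+            # with a 3x3/stride-1 conv, and drop the max-pool for 32x32 CIFAR images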
+ if name == 'conv1':
+ module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+ if dataset == 'cifar10' or dataset == 'cifar100':
+ if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
+ self.f.append(module)
+ elif dataset == 'tiny_imagenet' or dataset == 'stl10':
+ if not isinstance(module, nn.Linear):
+ self.f.append(module)
+ # encoder
+ self.f = nn.Sequential(*self.f)
+ # projection head
+ self.g = nn.Sequential(nn.Linear(embedding_size, 512, bias=False), nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))
+
+ def forward(self, x):
+ x = self.f(x)
+ feature = torch.flatten(x, start_dim=1)
+ out = self.g(feature)
+ return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
diff --git a/preprocess_datasets/preprocess_tinyimagenet.sh b/preprocess_datasets/preprocess_tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cdbe83cc60f8d9d05763d39804376b8c5446eba7
--- /dev/null
+++ b/preprocess_datasets/preprocess_tinyimagenet.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+# download and unzip dataset
+cd /data/wbandar1/datasets
+wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
+unzip tiny-imagenet-200.zip
+
+current="$(pwd)/tiny-imagenet-200"
+
+# training data
+cd $current/train
+for DIR in $(ls); do
+ cd $DIR
+ rm *.txt
+ mv images/* .
+ rm -r images
+ cd ..
+done
+
+# validation data
+cd $current/val
+annotate_file="val_annotations.txt"
+length=$(cat $annotate_file | wc -l)
+for i in $(seq 1 $length); do
+ # fetch i th line
+ line=$(sed -n ${i}p $annotate_file)
+ # get file name and directory name
+ file=$(echo $line | cut -f1 -d" " )
+ directory=$(echo $line | cut -f2 -d" ")
+ mkdir -p $directory
+ mv images/$file $directory
+done
+rm -r images
+echo "done"
\ No newline at end of file
diff --git a/scripts-linear-resnet18/cifar10.sh b/scripts-linear-resnet18/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9b2b573b6b386a9ad2a322e8eacc10012d9456e9
--- /dev/null
+++ b/scripts-linear-resnet18/cifar10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=cifar10
+arch=resnet18
+batch_size=512
+model_path=checkpoints/4wdhbpcf_0.0078125_1024_256_cifar10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet18/cifar100.sh b/scripts-linear-resnet18/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5ab2adcef428cf40981dd7390e93affced82026c
--- /dev/null
+++ b/scripts-linear-resnet18/cifar100.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=cifar100
+arch=resnet18
+batch_size=512
+model_path=checkpoints/76kk7scz_0.0078125_1024_256_cifar100_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet18/stl10.sh b/scripts-linear-resnet18/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..92bdff0e86b9312186142632dda2eea4835453e4
--- /dev/null
+++ b/scripts-linear-resnet18/stl10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=stl10
+arch=resnet18
+batch_size=512
+model_path=checkpoints/i7det4xq_0.0078125_1024_256_stl10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet18/tinyimagenet.sh b/scripts-linear-resnet18/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f14e204eece17499542f6c657c83697b0e6a2d11
--- /dev/null
+++ b/scripts-linear-resnet18/tinyimagenet.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=tiny_imagenet
+arch=resnet18
+batch_size=512
+model_path=checkpoints/02azq6fs_0.0009765_1024_256_tiny_imagenet_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet50/cifar10.sh b/scripts-linear-resnet50/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..79b2b086c04960a94de1e0313fdd059a5872f078
--- /dev/null
+++ b/scripts-linear-resnet50/cifar10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=cifar10
+arch=resnet50
+batch_size=512
+model_path=checkpoints/v3gwgusq_0.0078125_1024_256_cifar10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet50/cifar100.sh b/scripts-linear-resnet50/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..92a265541ed38ac68adb456b3cfbf4f905bdcd2a
--- /dev/null
+++ b/scripts-linear-resnet50/cifar100.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=cifar100
+arch=resnet50
+batch_size=512
+model_path=checkpoints/z6ngefw7_0.0078125_1024_256_cifar100_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet50/imagenet_sup.sh b/scripts-linear-resnet50/imagenet_sup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b4f5a45ff423eb5ed7917563e799112d34873a0c
--- /dev/null
+++ b/scripts-linear-resnet50/imagenet_sup.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+path_to_imagenet_data=datasets/imagenet1k/
+path_to_model=checkpoints/13awtq23_0.0051_8192_1024_imagenet_0.1_resnet50.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python evaluate_imagenet.py ${path_to_imagenet_data} ${path_to_model} --lr-classifier 0.3^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet50/stl10.sh b/scripts-linear-resnet50/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b73cf9ef17a88b0815c03e9fb645269822001
--- /dev/null
+++ b/scripts-linear-resnet50/stl10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=stl10
+arch=resnet50
+batch_size=512
+model_path=checkpoints/pbknx38b_0.0078125_1024_256_stl10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet50/tinyimagenet.sh b/scripts-linear-resnet50/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b05f986baae3984c23173a9bb910f2f289385fbf
--- /dev/null
+++ b/scripts-linear-resnet50/tinyimagenet.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=tiny_imagenet
+arch=resnet50
+batch_size=512
+model_path=checkpoints/kxlkigsv_0.0009765_1024_256_tiny_imagenet_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet18/cifar10.sh b/scripts-pretrain-resnet18/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8b62fe4a45191e83106e6e6ed1b94a27bc63f670
--- /dev/null
+++ b/scripts-pretrain-resnet18/cifar10.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# default: https://wandb.ai/cha-yas/Barlow-Twins-MixUp-cifar10-resnet18/runs/4wdhbpcf/overview?workspace=user-wgcban
+gpu=0
+dataset=cifar10
+arch=resnet18
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 4 = 0.03125
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet18/cifar100.sh b/scripts-pretrain-resnet18/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..35f81e0629a4c96869adfb20744dee877e440e71
--- /dev/null
+++ b/scripts-pretrain-resnet18/cifar100.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=cifar100
+arch=resnet18
+feature_dim=2048
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 4 = 0.03125
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet18/stl10.sh b/scripts-pretrain-resnet18/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..02e21d439cd2b28277942f59d2a0289d55a678c1
--- /dev/null
+++ b/scripts-pretrain-resnet18/stl10.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=stl10
+arch=resnet18
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=2.0 # scale w.r.t. lambda: 0.0078125 * 2 = 0.015625
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet18/tinyimagenet.sh b/scripts-pretrain-resnet18/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d279973f9696d626c4f6737507f87ec9c2d702e0
--- /dev/null
+++ b/scripts-pretrain-resnet18/tinyimagenet.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=tiny_imagenet
+arch=resnet18
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t. lambda
+lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc)
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet50/cifar10.sh b/scripts-pretrain-resnet50/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2019fe16d6cf20d0e8d7b0b7d78134cb9c663de6
--- /dev/null
+++ b/scripts-pretrain-resnet50/cifar10.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=cifar10
+arch=resnet50
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=1000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 4 = 0.03125
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet50/cifar100.sh b/scripts-pretrain-resnet50/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2d06e1191dc5214047c9693eb052d2ee58f2357a
--- /dev/null
+++ b/scripts-pretrain-resnet50/cifar100.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=cifar100
+arch=resnet50
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=1000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 4 = 0.03125
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet50/imagenet.sh b/scripts-pretrain-resnet50/imagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..742ba06e7fdf4d9cb15516ee3ced55126aeb5177
--- /dev/null
+++ b/scripts-pretrain-resnet50/imagenet.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+is_mixup=true
+batch_size=1024 # 128 per GPU works
+lr_w=0.2 #0.2
+lr_b=0.0048 #0.0048
+lambda_mixup=1.0
+
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main_imagenet.py /data/wbandar1/datasets/imagenet1k/ --is_mixup ${is_mixup} --batch-size ${batch_size} --learning-rate-weights ${lr_w} --learning-rate-biases ${lr_b} --lambda_mixup ${lambda_mixup}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet50/stl10.sh b/scripts-pretrain-resnet50/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7a9da099e851737cfb7cb3ff45b0f9f97fb7acad
--- /dev/null
+++ b/scripts-pretrain-resnet50/stl10.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=stl10
+arch=resnet50
+feature_dim=4096
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=2.0 # scale w.r.t. lambda: 0.0078125 * 2 = 0.015625
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet50/tinyimagenet.sh b/scripts-pretrain-resnet50/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f9abde5f52d5db3c7f66fb0c9a3e0e1c96a08f11
--- /dev/null
+++ b/scripts-pretrain-resnet50/tinyimagenet.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=tiny_imagenet
+arch=resnet50
+feature_dim=4096
+is_mixup=false # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t. lambda
+lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc)
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-transfer-resnet18/cifar10-to-x.sh b/scripts-transfer-resnet18/cifar10-to-x.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eff60ca6732af6aa9661adaf94cf7b78f77f349a
--- /dev/null
+++ b/scripts-transfer-resnet18/cifar10-to-x.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+gpu=0
+dataset=cifar10
+arch=resnet18
+batch_size=128
+wandb_group='best-mbt'
+model_path=checkpoints/4wdhbpcf_0.0078125_1024_256_cifar10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+transfer_dataset='dtd'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='mnist'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='fashionmnist'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='cu_birds'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='vgg_flower'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='traffic_sign'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='aircraft'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-transfer-resnet18/cifar100-to-x.sh b/scripts-transfer-resnet18/cifar100-to-x.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8822ec7c62e11718772c9ea57708d2b735990b01
--- /dev/null
+++ b/scripts-transfer-resnet18/cifar100-to-x.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+gpu=0
+dataset=cifar100
+arch=resnet18
+batch_size=128
+wandb_group='mbt'
+model_path=checkpoints/76kk7scz_0.0078125_1024_256_cifar100_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+transfer_dataset='dtd'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='mnist'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='fashionmnist'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='cu_birds'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='vgg_flower'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='traffic_sign'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='aircraft'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-transfer-resnet18/stl10-to-x-bt.sh b/scripts-transfer-resnet18/stl10-to-x-bt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..15e90293bd5f205f1f6bd22817d1d817e7052674
--- /dev/null
+++ b/scripts-transfer-resnet18/stl10-to-x-bt.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+gpu=0
+dataset=stl10
+arch=resnet18
+batch_size=128
+wandb_group='mbt'
+model_path=checkpoints/i7det4xq_0.0078125_1024_256_stl10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+transfer_dataset='dtd'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='mnist'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='fashionmnist'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='cu_birds'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='vgg_flower'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='traffic_sign'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+transfer_dataset='aircraft'
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/setup.sh b/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9ac4c3ae340199ff54cef3755d8636b25126bb54
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+mkdir -p ssl-aug
+mkdir -p Barlow-Twins-HSIC
+mkdir -p /data/wbandar1/projects/ssl-aug-artifacts/results
+git clone https://github.com/wgcban/ssl-aug.git Barlow-Twins-HSIC
+
+cd Barlow-Twins-HSIC
+conda env create -f environment.yml
+# enable `conda activate` in this non-interactive shell
+source "$(conda info --base)/etc/profile.d/conda.sh"
+conda activate ssl-aug
+
+
diff --git a/ssl-sota/README.md b/ssl-sota/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..da2ef91c8428dc6c8bf241e73dec8834319225b6
--- /dev/null
+++ b/ssl-sota/README.md
@@ -0,0 +1,87 @@
+# Self-Supervised Representation Learning
+
+Official repository of the paper **Whitening for Self-Supervised Representation Learning**
+
+ICML 2021 | [arXiv:2007.06346](https://arxiv.org/abs/2007.06346)
+
+It includes 3 types of losses:
+- W-MSE [arXiv](https://arxiv.org/abs/2007.06346)
+- Contrastive [SimCLR arXiv](https://arxiv.org/abs/2002.05709)
+- BYOL [arXiv](https://arxiv.org/abs/2006.07733)
+
+And 5 datasets:
+- CIFAR-10 and CIFAR-100
+- STL-10
+- Tiny ImageNet
+- ImageNet-100
+
+Checkpoints are stored in `data` every 100 epochs during training.
+
+The implementation is optimized for a single GPU, although multiple are also supported. It includes fast evaluation: we pre-compute embeddings for the entire dataset and then train a classifier on top. The evaluation of the ResNet-18 encoder takes about one minute.
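+
+As a minimal sketch of this two-stage protocol (hypothetical `encoder` and loader names; the real logic lives in `eval/get_data.py`, `eval/knn.py`, and `eval/sgd.py`):
+
+```python
+import torch
+
+@torch.no_grad()
+def embed(encoder, loader, device="cuda"):
+    """Encode the whole dataset once; classifiers then train on cached features."""
+    xs, ys = [], []
+    for x, y in loader:
+        xs.append(encoder(x.to(device)))
+        ys.append(y.to(device))
+    return torch.cat(xs), torch.cat(ys)
+
+# x_train, y_train = embed(encoder, clf_loader)
+# x_test, y_test = embed(encoder, test_loader)
+# fit a linear head or k-NN on the cached embeddings; no further encoder passes needed
+```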
+
+## Installation
+
+The implementation is based on PyTorch, and logging is done via [wandb.ai](https://wandb.ai/). See `docker/Dockerfile`.
+
+#### ImageNet-100
+To get this dataset, take the original ImageNet and keep only [this subset of classes](https://github.com/HobbitLong/CMC/blob/master/imagenet100.txt). We do not use augmentations during testing, and loading large images with on-the-fly resizing is slow, so the classifier-training and test images can be preprocessed once. We recommend [mogrify](https://imagemagick.org/script/mogrify.php) for this: first resize to 256 (just like `torchvision.transforms.Resize(256)`), then center-crop to 224 (like `torchvision.transforms.CenterCrop(224)`). Finally, put the original images in `train`, and the resized ones in `clf` and `test`.
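+
+The same offline preprocessing can also be scripted in Python instead of mogrify (a sketch; the `src`/`dst` paths are placeholders):
+
+```python
+from pathlib import Path
+from PIL import Image
+from torchvision import transforms
+
+# mirror torchvision's Resize(256) + CenterCrop(224) on disk
+t = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
+src, dst = Path("IN100/test_orig"), Path("IN100/test")
+for p in src.rglob("*.JPEG"):
+    out = dst / p.relative_to(src)
+    out.parent.mkdir(parents=True, exist_ok=True)
+    t(Image.open(p).convert("RGB")).save(out)
+```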
+
+## Usage
+
+The default settings work well; to see all options:
+```
+python -m train --help
+python -m test --help
+```
+
+To reproduce the results from [Table 1](https://arxiv.org/abs/2007.06346):
+#### W-MSE 4
+```
+python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --num_samples 4 --bs 256 --emb 64 --w_size 128
+python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --num_samples 4 --bs 256 --emb 64 --w_size 128
+python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256
+python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256
+```
+
+#### W-MSE 2
+```
+python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --w_size 128
+python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --w_size 128
+python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --w_size 256 --w_iter 4
+python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --w_size 256 --w_iter 4
+```
+
+#### Contrastive
+```
+python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --method contrastive --arch resnet50
+python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --method contrastive --arch resnet50
+python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --method contrastive --arch resnet50
+python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --method contrastive --arch resnet50
+```
+
+#### BYOL
+```
+python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --method byol
+python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --method byol
+python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --method byol
+python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --method byol
+```
+
+#### ImageNet-100
+```
+python -m train --dataset imagenet --epoch 240 --lr 2e-3 --emb 128 --w_size 256 --crop_s0 0.08 --cj0 0.8 --cj1 0.8 --cj2 0.8 --cj3 0.2 --gs_p 0.2
+python -m train --dataset imagenet --epoch 240 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256 --crop_s0 0.08 --cj0 0.8 --cj1 0.8 --cj2 0.8 --cj3 0.2 --gs_p 0.2
+```
+
+Use `--no_norm` to disable normalization (for Euclidean distance).
+
+## Citation
+```
+@inproceedings{ermolov2021whitening,
+ title={Whitening for self-supervised representation learning},
+ author={Ermolov, Aleksandr and Siarohin, Aliaksandr and Sangineto, Enver and Sebe, Nicu},
+ booktitle={International Conference on Machine Learning},
+ pages={3015--3024},
+ year={2021},
+ organization={PMLR}
+}
+```
diff --git a/ssl-sota/cfg.py b/ssl-sota/cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a920cde219f5003f26e06e3c0f3e0375aaf284d
--- /dev/null
+++ b/ssl-sota/cfg.py
@@ -0,0 +1,152 @@
+from functools import partial
+import argparse
+from torchvision import models
+import multiprocessing
+from datasets import DS_LIST
+from methods import METHOD_LIST
+
+
+def get_cfg():
+ """ generates configuration from user input in console """
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument(
+ "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type",
+ )
+ parser.add_argument(
+ "--wandb",
+ type=str,
+ default="ssl-sota",
+ help="name of the project for logging at https://wandb.ai",
+ )
+ parser.add_argument(
+ "--byol_tau", type=float, default=0.99, help="starting tau for byol loss"
+ )
+ parser.add_argument(
+ "--num_samples",
+ type=int,
+ default=2,
+ help="number of samples (d) generated from each image",
+ )
+
+ addf = partial(parser.add_argument, type=float)
+ addf("--cj0", default=0.4, help="color jitter brightness")
+ addf("--cj1", default=0.4, help="color jitter contrast")
+ addf("--cj2", default=0.4, help="color jitter saturation")
+ addf("--cj3", default=0.1, help="color jitter hue")
+ addf("--cj_p", default=0.8, help="color jitter probability")
+ addf("--gs_p", default=0.1, help="grayscale probability")
+ addf("--crop_s0", default=0.2, help="crop size from")
+ addf("--crop_s1", default=1.0, help="crop size to")
+ addf("--crop_r0", default=0.75, help="crop ratio from")
+ addf("--crop_r1", default=(4 / 3), help="crop ratio to")
+ addf("--hf_p", default=0.5, help="horizontal flip probability")
+
+ parser.add_argument(
+ "--no_lr_warmup",
+ dest="lr_warmup",
+ action="store_false",
+ help="do not use learning rate warmup",
+ )
+ parser.add_argument(
+ "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head"
+ )
+ parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier")
+ parser.add_argument("--fname", type=str, help="load model from file")
+ parser.add_argument(
+ "--lr_step",
+ type=str,
+ choices=["cos", "step", "none"],
+ default="step",
+ help="learning rate schedule type",
+ )
+ parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
+ parser.add_argument(
+ "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)"
+ )
+ parser.add_argument(
+ "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)"
+ )
+ parser.add_argument("--T0", type=int, help="period (for --lr_step cos)")
+ parser.add_argument(
+ "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)"
+ )
+ parser.add_argument(
+ "--w_eps", type=float, default=1e-4, help="eps for stability for whitening"
+ )
+ parser.add_argument(
+ "--head_layers", type=int, default=2, help="number of FC layers in head"
+ )
+ parser.add_argument(
+ "--head_size", type=int, default=1024, help="size of FC layers in head"
+ )
+
+ parser.add_argument(
+ "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss"
+ )
+ parser.add_argument(
+ "--w_iter",
+ type=int,
+ default=1,
+ help="iterations for whitening matrix estimation",
+ )
+
+ parser.add_argument(
+ "--no_norm", dest="norm", action="store_false", help="don't normalize latents",
+ )
+ parser.add_argument(
+ "--tau", type=float, default=0.5, help="contrastive loss temperature"
+ )
+
+ parser.add_argument("--epoch", type=int, default=200, help="total epoch number")
+ parser.add_argument(
+ "--eval_every_drop",
+ type=int,
+ default=5,
+ help="how often to evaluate after learning rate drop",
+ )
+ parser.add_argument(
+ "--eval_every", type=int, default=20, help="how often to evaluate"
+ )
+ parser.add_argument("--emb", type=int, default=64, help="embedding size")
+ parser.add_argument(
+ "--bs", type=int, default=384, help="number of original images in batch N",
+ )
+ parser.add_argument(
+ "--drop",
+ type=int,
+ nargs="*",
+ default=[50, 25],
+ help="milestones for learning rate decay (0 = last epoch)",
+ )
+ parser.add_argument(
+ "--drop_gamma",
+ type=float,
+ default=0.2,
+ help="multiplicative factor of learning rate decay",
+ )
+ parser.add_argument(
+ "--arch",
+ type=str,
+ choices=[x for x in dir(models) if "resn" in x],
+ default="resnet18",
+ help="encoder architecture",
+ )
+ parser.add_argument("--dataset", type=str, choices=DS_LIST, default="cifar10")
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=0,
+ help="dataset workers number",
+ )
+ parser.add_argument(
+ "--clf",
+ type=str,
+ default="sgd",
+ choices=["sgd", "knn", "lbfgs"],
+ help="classifier for test.py",
+ )
+ parser.add_argument(
+ "--eval_head", action="store_true", help="eval head output instead of model",
+ )
+ parser.add_argument("--imagenet_path", type=str, default="~/IN100/")
+ return parser.parse_args()
diff --git a/ssl-sota/datasets/__init__.py b/ssl-sota/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a72e00b827c4cbba69c3f69f6d7d8ed513bfb660
--- /dev/null
+++ b/ssl-sota/datasets/__init__.py
@@ -0,0 +1,22 @@
+from .cifar10 import CIFAR10
+from .cifar100 import CIFAR100
+from .stl10 import STL10
+from .tiny_in import TinyImageNet
+from .imagenet import ImageNet
+
+
+DS_LIST = ["cifar10", "cifar100", "stl10", "tinyimagenet", "imagenet"]
+
+
+def get_ds(name):
+ assert name in DS_LIST
+ if name == "cifar10":
+ return CIFAR10
+ elif name == "cifar100":
+ return CIFAR100
+ elif name == "stl10":
+ return STL10
+ elif name == "tinyimagenet":
+ return TinyImageNet
+ elif name == "imagenet":
+ return ImageNet
diff --git a/ssl-sota/datasets/base.py b/ssl-sota/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a34efb94d7b3d4146b279b88e1bf57cbf9bbea
--- /dev/null
+++ b/ssl-sota/datasets/base.py
@@ -0,0 +1,67 @@
+from abc import ABCMeta, abstractmethod
+from functools import lru_cache
+from torch.utils.data import DataLoader
+
+
+class BaseDataset(metaclass=ABCMeta):
+ """
+ base class for datasets, it includes 3 types:
+ - for self-supervised training,
+ - for classifier training for evaluation,
+ - for testing
+ """
+
+ def __init__(
+ self, bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000,
+ ):
+ self.aug_cfg = aug_cfg
+ self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test
+ self.num_workers = num_workers
+
+ @abstractmethod
+ def ds_train(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def ds_clf(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def ds_test(self):
+ raise NotImplementedError
+
+ @property
+ @lru_cache()
+ def train(self):
+ return DataLoader(
+ dataset=self.ds_train(),
+ batch_size=self.bs_train,
+ shuffle=True,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ drop_last=True,
+ )
+
+ @property
+ @lru_cache()
+ def clf(self):
+ return DataLoader(
+ dataset=self.ds_clf(),
+ batch_size=self.bs_clf,
+ shuffle=True,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ drop_last=True,
+ )
+
+ @property
+ @lru_cache()
+ def test(self):
+ return DataLoader(
+ dataset=self.ds_test(),
+ batch_size=self.bs_test,
+ shuffle=False,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ drop_last=False,
+ )
diff --git a/ssl-sota/datasets/cifar10.py b/ssl-sota/datasets/cifar10.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef8aafae9087b611381467a16132825896a8b77f
--- /dev/null
+++ b/ssl-sota/datasets/cifar10.py
@@ -0,0 +1,26 @@
+from torchvision.datasets import CIFAR10 as C10
+import torchvision.transforms as T
+from .transforms import MultiSample, aug_transform
+from .base import BaseDataset
+
+
+def base_transform():
+ return T.Compose(
+ [T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
+ )
+
+
+class CIFAR10(BaseDataset):
+ def ds_train(self):
+ t = MultiSample(
+ aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
+ )
+ return C10(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t)
+
+ def ds_clf(self):
+ t = base_transform()
+ return C10(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t)
+
+ def ds_test(self):
+ t = base_transform()
+ return C10(root="/mnt/store/wbandar1/datasets/", train=False, download=True, transform=t)
diff --git a/ssl-sota/datasets/cifar100.py b/ssl-sota/datasets/cifar100.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e4173e705ac4b9443768713a9025b108ccc155
--- /dev/null
+++ b/ssl-sota/datasets/cifar100.py
@@ -0,0 +1,26 @@
+from torchvision.datasets import CIFAR100 as C100
+import torchvision.transforms as T
+from .transforms import MultiSample, aug_transform
+from .base import BaseDataset
+
+
+def base_transform():
+ return T.Compose(
+ [T.ToTensor(), T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))]
+ )
+
+
+class CIFAR100(BaseDataset):
+ def ds_train(self):
+ t = MultiSample(
+ aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
+ )
+ return C100(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t,)
+
+ def ds_clf(self):
+ t = base_transform()
+ return C100(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t)
+
+ def ds_test(self):
+ t = base_transform()
+ return C100(root="/mnt/store/wbandar1/datasets/", train=False, download=True, transform=t)
diff --git a/ssl-sota/datasets/imagenet.py b/ssl-sota/datasets/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5677db1d4f69c9eb0a3f12dd7ff45a4982061e49
--- /dev/null
+++ b/ssl-sota/datasets/imagenet.py
@@ -0,0 +1,41 @@
+import random
+from torchvision.datasets import ImageFolder
+import torchvision.transforms as T
+from PIL import ImageFilter
+from .transforms import MultiSample, aug_transform
+from .base import BaseDataset
+
+
+class RandomBlur:
+ def __init__(self, r0, r1):
+ self.r0, self.r1 = r0, r1
+
+ def __call__(self, image):
+ r = random.uniform(self.r0, self.r1)
+ return image.filter(ImageFilter.GaussianBlur(radius=r))
+
+
+def base_transform():
+ return T.Compose(
+ [T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
+ )
+
+
+class ImageNet(BaseDataset):
+ def ds_train(self):
+ aug_with_blur = aug_transform(
+ 224,
+ base_transform,
+ self.aug_cfg,
+ extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)],
+ )
+ t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples)
+ return ImageFolder(root=self.aug_cfg.imagenet_path + "train", transform=t)
+
+ def ds_clf(self):
+ t = base_transform()
+ return ImageFolder(root=self.aug_cfg.imagenet_path + "clf", transform=t)
+
+ def ds_test(self):
+ t = base_transform()
+ return ImageFolder(root=self.aug_cfg.imagenet_path + "test", transform=t)
diff --git a/ssl-sota/datasets/stl10.py b/ssl-sota/datasets/stl10.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f2c50f5cd242bb58652fbf67f15e79a9423aba
--- /dev/null
+++ b/ssl-sota/datasets/stl10.py
@@ -0,0 +1,32 @@
+from torchvision.datasets import STL10 as S10
+import torchvision.transforms as T
+from .transforms import MultiSample, aug_transform
+from .base import BaseDataset
+
+
+def base_transform():
+ return T.Compose(
+ [T.ToTensor(), T.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))]
+ )
+
+
+def test_transform():
+ return T.Compose(
+ [T.Resize(70, interpolation=3), T.CenterCrop(64), base_transform()]
+ )
+
+
+class STL10(BaseDataset):
+ def ds_train(self):
+ t = MultiSample(
+ aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
+ )
+ return S10(root="/mnt/store/wbandar1/datasets/", split="train+unlabeled", download=True, transform=t)
+
+ def ds_clf(self):
+ t = test_transform()
+ return S10(root="/mnt/store/wbandar1/datasets/", split="train", download=True, transform=t)
+
+ def ds_test(self):
+ t = test_transform()
+ return S10(root="/mnt/store/wbandar1/datasets/", split="test", download=True, transform=t)
diff --git a/ssl-sota/datasets/tiny_in.py b/ssl-sota/datasets/tiny_in.py
new file mode 100644
index 0000000000000000000000000000000000000000..42bc9ae0b819c5ce9fbaffffb545abaf4253aad9
--- /dev/null
+++ b/ssl-sota/datasets/tiny_in.py
@@ -0,0 +1,26 @@
+from torchvision.datasets import ImageFolder
+import torchvision.transforms as T
+from .transforms import MultiSample, aug_transform
+from .base import BaseDataset
+
+
+def base_transform():
+ return T.Compose(
+ [T.ToTensor(), T.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))]
+ )
+
+
+class TinyImageNet(BaseDataset):
+ def ds_train(self):
+ t = MultiSample(
+ aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
+ )
+ return ImageFolder(root="/mnt/store/wbandar1/datasets/tiny-imagenet-200/train", transform=t)
+
+ def ds_clf(self):
+ t = base_transform()
+ return ImageFolder(root="/mnt/store/wbandar1/datasets/tiny-imagenet-200/train", transform=t)
+
+ def ds_test(self):
+ t = base_transform()
+ return ImageFolder(root="/mnt/store/wbandar1/datasets/tiny-imagenet-200/val", transform=t)
diff --git a/ssl-sota/datasets/transforms.py b/ssl-sota/datasets/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..743deb7024276f8510b6ee7db9e2898c933a43e6
--- /dev/null
+++ b/ssl-sota/datasets/transforms.py
@@ -0,0 +1,33 @@
+import torchvision.transforms as T
+
+
+def aug_transform(crop, base_transform, cfg, extra_t=[]):
+ """ augmentation transform generated from config """
+ return T.Compose(
+ [
+ T.RandomApply(
+ [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)], p=cfg.cj_p
+ ),
+ T.RandomGrayscale(p=cfg.gs_p),
+ T.RandomResizedCrop(
+ crop,
+ scale=(cfg.crop_s0, cfg.crop_s1),
+ ratio=(cfg.crop_r0, cfg.crop_r1),
+ interpolation=3,
+ ),
+ T.RandomHorizontalFlip(p=cfg.hf_p),
+ *extra_t,
+ base_transform(),
+ ]
+ )
+
+
+class MultiSample:
+ """ generates n samples with augmentation """
+
+ def __init__(self, transform, n=2):
+ self.transform = transform
+ self.num = n
+
+ def __call__(self, x):
+ return tuple(self.transform(x) for _ in range(self.num))
diff --git a/ssl-sota/docker/Dockerfile b/ssl-sota/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..32c18303107f1ac4c563576df4a428f853b7230d
--- /dev/null
+++ b/ssl-sota/docker/Dockerfile
@@ -0,0 +1,6 @@
+FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
+RUN pip install scikit-learn opencv-python
+RUN pip install matplotlib
+RUN pip install wandb
+RUN pip install ipdb
+ENTRYPOINT wandb login $WANDB_KEY && /bin/bash
diff --git a/ssl-sota/eval/get_data.py b/ssl-sota/eval/get_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1d4e6e29957a41f963b4fa98de7bafeb23542a
--- /dev/null
+++ b/ssl-sota/eval/get_data.py
@@ -0,0 +1,17 @@
+import torch
+
+
+def get_data(model, loader, output_size, device):
+ """ encodes whole dataset into embeddings """
+ xs = torch.empty(
+ len(loader), loader.batch_size, output_size, dtype=torch.float32, device=device
+ )
+ ys = torch.empty(len(loader), loader.batch_size, dtype=torch.long, device=device)
+ with torch.no_grad():
+ for i, (x, y) in enumerate(loader):
+ x = x.cuda()
+ xs[i] = model(x).to(device)
+ ys[i] = y.to(device)
+ xs = xs.view(-1, output_size)
+ ys = ys.view(-1)
+ return xs, ys
diff --git a/ssl-sota/eval/knn.py b/ssl-sota/eval/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b26709abf74b146eb7538a297ea0ea68c56b3663
--- /dev/null
+++ b/ssl-sota/eval/knn.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def eval_knn(x_train, y_train, x_test, y_test, k=200):
+ """ k-nearest neighbors classifier accuracy """
+ d = torch.cdist(x_test, x_train)
+ topk = torch.topk(d, k=k, dim=1, largest=False)
+ labels = y_train[topk.indices]
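+    # majority vote over the k nearest neighbours of each test point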
+ pred = torch.empty_like(y_test)
+ for i in range(len(labels)):
+ x = labels[i].unique(return_counts=True)
+ pred[i] = x[0][x[1].argmax()]
+
+ acc = (pred == y_test).float().mean().cpu().item()
+ del d, topk, labels, pred
+ return acc
diff --git a/ssl-sota/eval/lbfgs.py b/ssl-sota/eval/lbfgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b4d7eb949f4c90550b3b09ad50aac0ee932b7b
--- /dev/null
+++ b/ssl-sota/eval/lbfgs.py
@@ -0,0 +1,12 @@
+import torch
+from sklearn.linear_model import LogisticRegression
+
+
+def eval_lbfgs(x_train, y_train, x_test, y_test):
+ """ linear classifier accuracy (lbfgs method) """
+ clf = LogisticRegression(
+ random_state=1337, solver="lbfgs", max_iter=1000, n_jobs=-1
+ )
+ clf.fit(x_train, y_train)
+ pred = clf.predict(x_test)
+ return (torch.tensor(pred) == y_test).float().mean()
diff --git a/ssl-sota/eval/sgd.py b/ssl-sota/eval/sgd.py
new file mode 100644
index 0000000000000000000000000000000000000000..908405288f768142f11526a42ee4b8fef4ef7f24
--- /dev/null
+++ b/ssl-sota/eval/sgd.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+
+def eval_sgd(x_train, y_train, x_test, y_test, topk=[1, 5], epoch=500):
+ """ linear classifier accuracy (sgd) """
+ lr_start, lr_end = 1e-2, 1e-6
+ gamma = (lr_end / lr_start) ** (1 / epoch)
+ output_size = x_train.shape[1]
+ num_class = y_train.max().item() + 1
+ clf = nn.Linear(output_size, num_class)
+ clf.cuda()
+ clf.train()
+ optimizer = optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6)
+ scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
+ criterion = nn.CrossEntropyLoss()
+
+ for ep in range(epoch):
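+        # shuffle into mini-batches of 1000 (assumes len(x_train) is divisible by 1000)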
+ perm = torch.randperm(len(x_train)).view(-1, 1000)
+ for idx in perm:
+ optimizer.zero_grad()
+ criterion(clf(x_train[idx]), y_train[idx]).backward()
+ optimizer.step()
+ scheduler.step()
+
+ clf.eval()
+ with torch.no_grad():
+ y_pred = clf(x_test)
+ pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices
+ acc = {
+ t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item()
+ for t in topk
+ }
+ del clf
+ return acc
diff --git a/ssl-sota/methods/__init__.py b/ssl-sota/methods/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7781711381539a4e7e29ae504b6073eb05434f9d
--- /dev/null
+++ b/ssl-sota/methods/__init__.py
@@ -0,0 +1,16 @@
+from .contrastive import Contrastive
+from .w_mse import WMSE
+from .byol import BYOL
+
+
+METHOD_LIST = ["contrastive", "w_mse", "byol"]
+
+
+def get_method(name):
+ assert name in METHOD_LIST
+ if name == "contrastive":
+ return Contrastive
+ elif name == "w_mse":
+ return WMSE
+ elif name == "byol":
+ return BYOL
diff --git a/ssl-sota/methods/base.py b/ssl-sota/methods/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ad54a5ea7b985789b2c5fd09580919c0feba121
--- /dev/null
+++ b/ssl-sota/methods/base.py
@@ -0,0 +1,61 @@
+import torch.nn as nn
+from model import get_model, get_head
+from eval.sgd import eval_sgd
+from eval.knn import eval_knn
+from eval.get_data import get_data
+import torch
+import tqdm
+
+class BaseMethod(nn.Module):
+ """
+ Base class for self-supervised loss implementation.
+ It includes encoder and head for training, evaluation function.
+ """
+
+ def __init__(self, cfg):
+ super().__init__()
+ self.model, self.out_size = get_model(cfg.arch, cfg.dataset)
+ self.head = get_head(self.out_size, cfg)
+ self.knn = cfg.knn
+ self.num_pairs = cfg.num_samples * (cfg.num_samples - 1) // 2
+ self.eval_head = cfg.eval_head
+ self.emb_size = cfg.emb
+
+ def forward(self, samples):
+ raise NotImplementedError
+
+ def get_acc(self, ds_clf, ds_test):
+ self.eval()
+ if self.eval_head:
+ model = lambda x: self.head(self.model(x))
+ out_size = self.emb_size
+ else:
+ model, out_size = self.model, self.out_size
+ # torch.cuda.empty_cache()
+ x_train, y_train = get_data(model, ds_clf, out_size, "cuda")
+ x_test, y_test = get_data(model, ds_test, out_size, "cuda")
+
+ acc_knn = eval_knn(x_train, y_train, x_test, y_test, self.knn)
+ acc_linear = eval_sgd(x_train, y_train, x_test, y_test)
+ del x_train, y_train, x_test, y_test
+ self.train()
+ return acc_knn, acc_linear
+
+ def get_acc_knn(self, ds_clf, ds_test):
+ self.eval()
+ if self.eval_head:
+ model = lambda x: self.head(self.model(x))
+ out_size = self.emb_size
+ else:
+ model, out_size = self.model, self.out_size
+ # torch.cuda.empty_cache()
+ x_train, y_train = get_data(model, ds_clf, out_size, "cuda")
+ x_test, y_test = get_data(model, ds_test, out_size, "cuda")
+
+ acc_knn = eval_knn(x_train, y_train, x_test, y_test, self.knn)
+ del x_train, y_train, x_test, y_test
+ self.train()
+ return acc_knn
+
+ def step(self, progress):
+ pass
diff --git a/ssl-sota/methods/byol.py b/ssl-sota/methods/byol.py
new file mode 100644
index 0000000000000000000000000000000000000000..de8c15ece6c69ab120e431e6e8039e094562abb5
--- /dev/null
+++ b/ssl-sota/methods/byol.py
@@ -0,0 +1,53 @@
+from itertools import chain
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model import get_model, get_head
+from .base import BaseMethod
+from .norm_mse import norm_mse_loss
+
+
+class BYOL(BaseMethod):
+ """ implements BYOL loss https://arxiv.org/abs/2006.07733 """
+
+ def __init__(self, cfg):
+ """ init additional target and predictor networks """
+ super().__init__(cfg)
+ self.pred = nn.Sequential(
+ nn.Linear(cfg.emb, cfg.head_size),
+ nn.BatchNorm1d(cfg.head_size),
+ nn.ReLU(),
+ nn.Linear(cfg.head_size, cfg.emb),
+ )
+ self.model_t, _ = get_model(cfg.arch, cfg.dataset)
+ self.head_t = get_head(self.out_size, cfg)
+ for param in chain(self.model_t.parameters(), self.head_t.parameters()):
+ param.requires_grad = False
+ self.update_target(0)
+ self.byol_tau = cfg.byol_tau
+ self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss
+
+ def update_target(self, tau):
+ """ copy parameters from main network to target """
+ for t, s in zip(self.model_t.parameters(), self.model.parameters()):
+ t.data.copy_(t.data * tau + s.data * (1.0 - tau))
+ for t, s in zip(self.head_t.parameters(), self.head.parameters()):
+ t.data.copy_(t.data * tau + s.data * (1.0 - tau))
+
+ def forward(self, samples):
+ z = [self.pred(self.head(self.model(x))) for x in samples]
+ with torch.no_grad():
+ zt = [self.head_t(self.model_t(x)) for x in samples]
+
+ loss = 0
+ for i in range(len(samples) - 1):
+ for j in range(i + 1, len(samples)):
+ loss += self.loss_f(z[i], zt[j]) + self.loss_f(z[j], zt[i])
+ loss /= self.num_pairs
+ return loss
+
+ def step(self, progress):
+ """ update target network with cosine increasing schedule """
+ tau = 1 - (1 - self.byol_tau) * (math.cos(math.pi * progress) + 1) / 2
+ self.update_target(tau)
diff --git a/ssl-sota/methods/contrastive.py b/ssl-sota/methods/contrastive.py
new file mode 100644
index 0000000000000000000000000000000000000000..043afe829d06377e955ceb16f29040f100d3f3db
--- /dev/null
+++ b/ssl-sota/methods/contrastive.py
@@ -0,0 +1,46 @@
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .base import BaseMethod
+
+
+def contrastive_loss(x0, x1, tau, norm):
+ # https://github.com/google-research/simclr/blob/master/objective.py
+ bsize = x0.shape[0]
+ target = torch.arange(bsize).cuda()
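+    # subtracting 1e9 on the diagonal excludes self-similarities from the softmax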
+ eye_mask = torch.eye(bsize).cuda() * 1e9
+ if norm:
+ x0 = F.normalize(x0, p=2, dim=1)
+ x1 = F.normalize(x1, p=2, dim=1)
+ logits00 = x0 @ x0.t() / tau - eye_mask
+ logits11 = x1 @ x1.t() / tau - eye_mask
+ logits01 = x0 @ x1.t() / tau
+ logits10 = x1 @ x0.t() / tau
+ return (
+ F.cross_entropy(torch.cat([logits01, logits00], dim=1), target)
+ + F.cross_entropy(torch.cat([logits10, logits11], dim=1), target)
+ ) / 2
+
+
+class Contrastive(BaseMethod):
+ """ implements contrastive loss https://arxiv.org/abs/2002.05709 """
+
+ def __init__(self, cfg):
+ """ init additional BN used after head """
+ super().__init__(cfg)
+ self.bn_last = nn.BatchNorm1d(cfg.emb)
+ self.loss_f = partial(contrastive_loss, tau=cfg.tau, norm=cfg.norm)
+
+ def forward(self, samples):
+ bs = len(samples[0])
+ h = [self.model(x.cuda(non_blocking=True)) for x in samples]
+ h = self.bn_last(self.head(torch.cat(h)))
+ loss = 0
+ for i in range(len(samples) - 1):
+ for j in range(i + 1, len(samples)):
+ x0 = h[i * bs : (i + 1) * bs]
+ x1 = h[j * bs : (j + 1) * bs]
+ loss += self.loss_f(x0, x1)
+ loss /= self.num_pairs
+ return loss
diff --git a/ssl-sota/methods/norm_mse.py b/ssl-sota/methods/norm_mse.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa303b1f0b2d6fbef440664bf89a373bafe65f60
--- /dev/null
+++ b/ssl-sota/methods/norm_mse.py
@@ -0,0 +1,7 @@
+import torch.nn.functional as F
+
+
+def norm_mse_loss(x0, x1):
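+    # for L2-normalized vectors, ||x0 - x1||^2 = 2 - 2 * <x0, x1>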
+ x0 = F.normalize(x0)
+ x1 = F.normalize(x1)
+ return 2 - 2 * (x0 * x1).sum(dim=-1).mean()
diff --git a/ssl-sota/methods/w_mse.py b/ssl-sota/methods/w_mse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7312766e66b1e8f1bc7e1c59fe71948b7d37bdb
--- /dev/null
+++ b/ssl-sota/methods/w_mse.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn.functional as F
+from .whitening import Whitening2d
+from .base import BaseMethod
+from .norm_mse import norm_mse_loss
+
+
+class WMSE(BaseMethod):
+ """ implements W-MSE loss """
+
+ def __init__(self, cfg):
+ """ init whitening transform """
+ super().__init__(cfg)
+ self.whitening = Whitening2d(cfg.emb, eps=cfg.w_eps, track_running_stats=False)
+ self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss
+ self.w_iter = cfg.w_iter
+ self.w_size = cfg.bs if cfg.w_size is None else cfg.w_size
+
+ def forward(self, samples):
+ bs = len(samples[0])
+ h = [self.model(x.cuda(non_blocking=True)) for x in samples]
+ h = self.head(torch.cat(h))
+ loss = 0
+ for _ in range(self.w_iter):
+ z = torch.empty_like(h)
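+            # whiten random sub-batches of size w_size (assumes bs is divisible by w_size)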
+ perm = torch.randperm(bs).view(-1, self.w_size)
+ for idx in perm:
+ for i in range(len(samples)):
+ z[idx + i * bs] = self.whitening(h[idx + i * bs])
+ for i in range(len(samples) - 1):
+ for j in range(i + 1, len(samples)):
+ x0 = z[i * bs : (i + 1) * bs]
+ x1 = z[j * bs : (j + 1) * bs]
+ loss += self.loss_f(x0, x1)
+ loss /= self.w_iter * self.num_pairs
+ return loss
diff --git a/ssl-sota/methods/whitening.py b/ssl-sota/methods/whitening.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e965310af8ccf3708bf2c78e3fe4ab2e14b994
--- /dev/null
+++ b/ssl-sota/methods/whitening.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn as nn
+from torch.nn.functional import conv2d
+
+
+class Whitening2d(nn.Module):
+ def __init__(self, num_features, momentum=0.01, track_running_stats=True, eps=0):
+ super(Whitening2d, self).__init__()
+ self.num_features = num_features
+ self.momentum = momentum
+ self.track_running_stats = track_running_stats
+ self.eps = eps
+
+ if self.track_running_stats:
+ self.register_buffer(
+ "running_mean", torch.zeros([1, self.num_features, 1, 1])
+ )
+ self.register_buffer("running_variance", torch.eye(self.num_features))
+
+ def forward(self, x):
+ x = x.unsqueeze(2).unsqueeze(3)
+ m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1)
+ if not self.training and self.track_running_stats: # for inference
+ m = self.running_mean
+ xn = x - m
+
+ T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1)
+ f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)
+
+ eye = torch.eye(self.num_features).type(f_cov.type())
+
+ if not self.training and self.track_running_stats: # for inference
+ f_cov = self.running_variance
+
+ f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye
+
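+        # Cholesky-based whitening: with C = L L^T, applying L^{-1} gives
+        # Cov(L^{-1} x) = L^{-1} C L^{-T} = I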
+ inv_sqrt = torch.linalg.solve_triangular(
+ torch.linalg.cholesky(f_cov_shrinked),
+ eye,
+ upper=False
+ )
+
+ inv_sqrt = inv_sqrt.contiguous().view(
+ self.num_features, self.num_features, 1, 1
+ )
+
+ decorrelated = conv2d(xn, inv_sqrt)
+
+ if self.training and self.track_running_stats:
+ self.running_mean = torch.add(
+ self.momentum * m.detach(),
+ (1 - self.momentum) * self.running_mean,
+ out=self.running_mean,
+ )
+ self.running_variance = torch.add(
+ self.momentum * f_cov.detach(),
+ (1 - self.momentum) * self.running_variance,
+ out=self.running_variance,
+ )
+
+ return decorrelated.squeeze(2).squeeze(2)
+
+ def extra_repr(self):
+ return "features={}, eps={}, momentum={}".format(
+ self.num_features, self.eps, self.momentum
+ )
diff --git a/ssl-sota/model.py b/ssl-sota/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..184e70dac2a325b7f65e2e77f42ca02b752a9e01
--- /dev/null
+++ b/ssl-sota/model.py
@@ -0,0 +1,29 @@
+import torch.nn as nn
+from torchvision import models
+
+
+def get_head(out_size, cfg):
+ """ creates projection head g() from config """
+ x = []
+ in_size = out_size
+ for _ in range(cfg.head_layers - 1):
+ x.append(nn.Linear(in_size, cfg.head_size))
+ if cfg.add_bn:
+ x.append(nn.BatchNorm1d(cfg.head_size))
+ x.append(nn.ReLU())
+ in_size = cfg.head_size
+ x.append(nn.Linear(in_size, cfg.emb))
+ return nn.Sequential(*x)
+
+
+def get_model(arch, dataset):
+ """ creates encoder E() by name and modifies it for dataset """
+ model = getattr(models, arch)(pretrained=False)
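+    # small-image variants: replace the 7x7/stride-2 stem with a 3x3/stride-1 conv;
+    # CIFAR additionally drops the initial max-pool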
+ if dataset != "imagenet":
+ model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+ if dataset == "cifar10" or dataset == "cifar100":
+ model.maxpool = nn.Identity()
+ out_size = model.fc.in_features
+ model.fc = nn.Identity()
+
+ return nn.DataParallel(model), out_size
diff --git a/ssl-sota/scripts-pretrain-resnet50/cifar10.sh b/ssl-sota/scripts-pretrain-resnet50/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2f228d2fe76c49ddc349a1c61489d7bbdb1e5d69
--- /dev/null
+++ b/ssl-sota/scripts-pretrain-resnet50/cifar10.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+gpu=7
+dataset=cifar10
+arch=resnet50
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # "step" or "cosine"
+mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625
+
+lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc)
+lmbda=0.0078125 # override 1/feature_dim; found to work better
+echo ${lmbda}
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain-resnet50/cifar100.sh b/ssl-sota/scripts-pretrain-resnet50/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f4d167c8dc87ecde663692d63bbd9fdbda324c82
--- /dev/null
+++ b/ssl-sota/scripts-pretrain-resnet50/cifar100.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+gpu=9
+dataset=cifar100
+arch=resnet50
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # "step" or "cosine"
+mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625
+
+lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc) #0.0078125
+lmbda=0.0078125
+echo ${lmbda}
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain-resnet50/imagenet.sh b/ssl-sota/scripts-pretrain-resnet50/imagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bf95107598978e96047aea309944ae36c7c99542
--- /dev/null
+++ b/ssl-sota/scripts-pretrain-resnet50/imagenet.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+is_mixup=true
+batch_size=1024 #128/gpu works
+lr_w=0.2 #0.2
+lr_b=0.0048 #0.0048
+lambda_mixup=0.004
+
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main_imagenet.py /mnt/store/wbandar1/imagenet1k/ --is_mixup ${is_mixup} --batch-size ${batch_size} --learning-rate-weights ${lr_w} --learning-rate-biases ${lr_b} --lambda_mixup ${lambda_mixup}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain-resnet50/stl10.sh b/ssl-sota/scripts-pretrain-resnet50/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9840f87429d3bdb3de8884c9af98cdf6236efb15
--- /dev/null
+++ b/ssl-sota/scripts-pretrain-resnet50/stl10.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+gpu=2
+dataset=stl10
+arch=resnet50
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # "step" or "cosine"
+mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625
+
+lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc)
+# lmbda=0.0078125 # found that this works better
+echo ${lmbda}
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain-resnet50/tinyimagenet.sh b/ssl-sota/scripts-pretrain-resnet50/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..08b27d3b7cfbbc5f9a6862ed3c6329c02aab614d
--- /dev/null
+++ b/ssl-sota/scripts-pretrain-resnet50/tinyimagenet.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+gpu=1
+dataset=tiny_imagenet
+arch=resnet50
+feature_dim=2048
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # "step" or "cosine"
+mixup_loss_scale=4 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625
+
+lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc)
+# lmbda=0.0078125 # found that 1/feature_dim works fine for tiny_imagenet
+echo ${lmbda}
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/byol/cifar10.sh b/ssl-sota/scripts-pretrain/byol/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4dec6ce237f5729d06d7efa0d6d283244070e457
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/byol/cifar10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=9
+dataset=cifar10
+method=byol
+model=resnet50
+eval_every=5
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/byol/cifar100.sh b/ssl-sota/scripts-pretrain/byol/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6b98ada9476e0328d59d3d84d9f93c76a0ca9a92
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/byol/cifar100.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=9
+dataset=cifar100
+method=byol
+model=resnet50
+eval_every=5
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/byol/stl10.sh b/ssl-sota/scripts-pretrain/byol/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..df7e1f163a8a4c932bf74d8d70902ad6c099e23a
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/byol/stl10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=9
+dataset=stl10
+method=byol
+model=resnet50
+eval_every=5
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/byol/tinyimagenet.sh b/ssl-sota/scripts-pretrain/byol/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e2de18464f2c6bc560f594a833808bf0158d81fd
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/byol/tinyimagenet.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+gpu=9
+dataset=tinyimagenet
+method=byol
+model=resnet50
+eval_every=5
+batch_size=320
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every} --bs ${batch_size}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/contrastive/cifar10.sh b/ssl-sota/scripts-pretrain/contrastive/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..80b74ab494ab8795e30c54f461d0a31587f331d1
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/contrastive/cifar10.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=cifar10
+method=contrastive
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/contrastive/cifar100.sh b/ssl-sota/scripts-pretrain/contrastive/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1666824e99ccbeaf6f41ae454ccf5e3611829561
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/contrastive/cifar100.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=cifar100
+method=contrastive
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/contrastive/stl10.sh b/ssl-sota/scripts-pretrain/contrastive/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3dab325f3d73fdca698469250b5a07f2d160295c
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/contrastive/stl10.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=stl10
+method=contrastive
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/contrastive/tinyimagenet.sh b/ssl-sota/scripts-pretrain/contrastive/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..492e1d8a0011d98a5bc19ea24895965b118e2792
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/contrastive/tinyimagenet.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=tinyimagenet
+method=contrastive
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/w-mse/cifar10.sh b/ssl-sota/scripts-pretrain/w-mse/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1569e03efa58933dd4890765d8f392dda64c4a0d
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/w-mse/cifar10.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=cifar10
+method=w_mse
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --bs 256 --emb 64 --w_size 128 --w_eps 10e-3 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/w-mse/cifar100.sh b/ssl-sota/scripts-pretrain/w-mse/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..871230238774dbeb3674f9d1a2be102cc67f4fb3
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/w-mse/cifar100.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=cifar100
+method=w_mse
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --lr 3e-3 --bs 256 --emb 64 --w_size 128 --w_eps 10e-3 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/w-mse/stl10.sh b/ssl-sota/scripts-pretrain/w-mse/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5e72a0bd82cd38fcf336cbf0de5217db2f1e353d
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/w-mse/stl10.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=stl10
+method=w_mse
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 2e-3 --bs 256 --emb 128 --w_size 256 --w_eps 10e-3 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/scripts-pretrain/w-mse/tinyimagenet.sh b/ssl-sota/scripts-pretrain/w-mse/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5240f85622738d3caa9cc6ed32e8a10d7e2727d4
--- /dev/null
+++ b/ssl-sota/scripts-pretrain/w-mse/tinyimagenet.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+gpu=9
+dataset=tinyimagenet
+method=w_mse
+model=resnet50
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 2e-3 --bs 256 --emb 128 --w_size 256 --w_eps 10e-3 --method ${method} --arch ${model}^M"
+screen -S "$session_name" -X detachs
\ No newline at end of file
diff --git a/ssl-sota/test.py b/ssl-sota/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9330a5cc95f64aeb60d89a3b5d77e898d4228f5
--- /dev/null
+++ b/ssl-sota/test.py
@@ -0,0 +1,39 @@
+import torch
+from datasets import get_ds
+from cfg import get_cfg
+from methods import get_method
+
+from eval.sgd import eval_sgd
+from eval.knn import eval_knn
+from eval.lbfgs import eval_lbfgs
+from eval.get_data import get_data
+
+
+if __name__ == "__main__":
+ cfg = get_cfg()
+
+ model_full = get_method(cfg.method)(cfg)
+ model_full.cuda().eval()
+ if cfg.fname is None:
+ print("evaluating random model")
+ else:
+ model_full.load_state_dict(torch.load(cfg.fname))
+
+ ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers)
+ device = "cpu" if cfg.clf == "lbfgs" else "cuda"
+ if cfg.eval_head:
+ model = lambda x: model_full.head(model_full.model(x))
+ out_size = cfg.emb
+ else:
+ model = model_full.model
+ out_size = model_full.out_size
+ x_train, y_train = get_data(model, ds.clf, out_size, device)
+ x_test, y_test = get_data(model, ds.test, out_size, device)
+
+ if cfg.clf == "sgd":
+ acc = eval_sgd(x_train, y_train, x_test, y_test)
+ if cfg.clf == "knn":
+ acc = eval_knn(x_train, y_train, x_test, y_test)
+ elif cfg.clf == "lbfgs":
+ acc = eval_lbfgs(x_train, y_train, x_test, y_test)
+ print(acc)
diff --git a/ssl-sota/tf2/README.md b/ssl-sota/tf2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5a326b3a20289bc051b79fefd79b8762b86b02ab
--- /dev/null
+++ b/ssl-sota/tf2/README.md
@@ -0,0 +1,8 @@
+`w_mse_loss()` from `whitening.py` is a W-MSE loss implementation for TensorFlow 2;
+it can be used alongside other popular implementations, e.g. [SimCLRv2](https://github.com/google-research/simclr/tree/master/tf2).
+
+The method uses the global flags mechanism, as in SimCLRv2:
+- `FLAGS.num_samples` - number of samples (d) generated from each image
+- `FLAGS.train_batch_size`
+- `FLAGS.proj_out_dim`
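+
+A minimal usage sketch, assuming the three flags above are already defined (e.g., via `flags.DEFINE_integer`; the shapes below are illustrative):
+
+```python
+import tensorflow.compat.v2 as tf
+
+from whitening import w_mse_loss
+
+# z: projector outputs for all views, stacked as
+# (num_samples * train_batch_size, proj_out_dim)
+z = tf.random.normal((2 * 256, 128))
+loss = w_mse_loss(z)  # scalar W-MSE loss over all positive pairs
+```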
diff --git a/ssl-sota/tf2/whitening.py b/ssl-sota/tf2/whitening.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14650a7de184732326e2cd1df678905ef32ec64
--- /dev/null
+++ b/ssl-sota/tf2/whitening.py
@@ -0,0 +1,45 @@
+import tensorflow.compat.v2 as tf
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+
+class Whitening1D(tf.keras.layers.Layer):
+ def __init__(self, eps=0, **kwargs):
+ super(Whitening1D, self).__init__(**kwargs)
+ self.eps = eps
+
+ def call(self, x):
+ bs, c = x.shape
+ x_t = tf.transpose(x, (1, 0))
+ m = tf.reduce_mean(x_t, axis=1, keepdims=True)
+ f = x_t - m
+ ff_apr = tf.matmul(f, f, transpose_b=True) / (tf.cast(bs, tf.float32) - 1.0)
+ ff_apr_shrinked = (1 - self.eps) * ff_apr + tf.eye(c) * self.eps
+ sqrt = tf.linalg.cholesky(ff_apr_shrinked)
+ inv_sqrt = tf.linalg.triangular_solve(sqrt, tf.eye(c))
+ f_hat = tf.matmul(inv_sqrt, f)
+        decorrelated = tf.transpose(f_hat, (1, 0))
+        return decorrelated
+
+
+def w_mse_loss(x):
+ """ input x shape = (batch size * num_samples, proj_out_dim) """
+
+ w = Whitening1D()
+ num_samples = FLAGS.num_samples
+ num_slice = num_samples * FLAGS.train_batch_size // (2 * FLAGS.proj_out_dim)
+ x_split = tf.split(x, num_slice, 0)
+ for i in range(num_slice):
+ x_split[i] = w(x_split[i])
+ x = tf.concat(x_split, 0)
+ x = tf.math.l2_normalize(x, -1)
+
+ x_split = tf.split(x, num_samples, 0)
+ loss = 0
+ for i in range(num_samples - 1):
+ for j in range(i + 1, num_samples):
+ v = x_split[i] * x_split[j]
+ loss += 2 - 2 * tf.reduce_mean(tf.reduce_sum(v, -1))
+ loss /= num_samples * (num_samples - 1) // 2
+ return loss
diff --git a/ssl-sota/train.py b/ssl-sota/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f86af7eeb12ae9e6de85b3b7e0853bf4ab858a50
--- /dev/null
+++ b/ssl-sota/train.py
@@ -0,0 +1,93 @@
+from tqdm import trange, tqdm
+import numpy as np
+import wandb
+import torch
+import torch.optim as optim
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts
+import torch.backends.cudnn as cudnn
+import os
+
+from cfg import get_cfg
+from datasets import get_ds
+from methods import get_method
+
+
+def get_scheduler(optimizer, cfg):
+ if cfg.lr_step == "cos":
+ return CosineAnnealingWarmRestarts(
+ optimizer,
+ T_0=cfg.epoch if cfg.T0 is None else cfg.T0,
+ T_mult=cfg.Tmult,
+ eta_min=cfg.eta_min,
+ )
+ elif cfg.lr_step == "step":
+ m = [cfg.epoch - a for a in cfg.drop]
+ return MultiStepLR(optimizer, milestones=m, gamma=cfg.drop_gamma)
+ else:
+ return None
+
+
+if __name__ == "__main__":
+ cfg = get_cfg()
+ wandb.init(project=f"ssl-sota-{cfg.method}-{cfg.dataset}", config=cfg, dir='/mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs/')
+ run_id = wandb.run.id
+ # if not os.path.exists('../results'):
+ # os.mkdir('../results')
+ run_id_dir = os.path.join('/mnt/store/wbandar1/projects/ssl-aug-artifacts/', run_id)
+ if not os.path.exists(run_id_dir):
+ print('Creating directory {}'.format(run_id_dir))
+ os.mkdir(run_id_dir)
+
+ ds = get_ds(cfg.dataset)(cfg.bs, cfg, cfg.num_workers)
+ model = get_method(cfg.method)(cfg)
+ model.cuda().train()
+ if cfg.fname is not None:
+ model.load_state_dict(torch.load(cfg.fname))
+
+ optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.adam_l2)
+ scheduler = get_scheduler(optimizer, cfg)
+
+ eval_every = cfg.eval_every
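+    # linear LR warmup over the first 500 iterations; starting the counter at 500 skips it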
+ lr_warmup = 0 if cfg.lr_warmup else 500
+ cudnn.benchmark = True
+
+ for ep in trange(cfg.epoch, position=0):
+ loss_ep = []
+ iters = len(ds.train)
+ for n_iter, (samples, _) in enumerate(tqdm(ds.train, position=1)):
+ if lr_warmup < 500:
+ lr_scale = (lr_warmup + 1) / 500
+ for pg in optimizer.param_groups:
+ pg["lr"] = cfg.lr * lr_scale
+ lr_warmup += 1
+
+ optimizer.zero_grad()
+ loss = model(samples)
+ loss.backward()
+ optimizer.step()
+ loss_ep.append(loss.item())
+ model.step(ep / cfg.epoch)
+ if cfg.lr_step == "cos" and lr_warmup >= 500:
+ scheduler.step(ep + n_iter / iters)
+
+ if cfg.lr_step == "step":
+ scheduler.step()
+
+ if len(cfg.drop) and ep == (cfg.epoch - cfg.drop[0]):
+ eval_every = cfg.eval_every_drop
+
+ if (ep + 1) % eval_every == 0:
+ # acc_knn, acc = model.get_acc(ds.clf, ds.test)
+ # wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn}, commit=False)
+ acc_knn = model.get_acc_knn(ds.clf, ds.test)
+ wandb.log({"acc_knn": acc_knn}, commit=False)
+
+ if (ep + 1) % 100 == 0:
+ fname = f"/mnt/store/wbandar1/projects/ssl-aug-artifacts/{run_id}/{cfg.method}_{cfg.dataset}_{ep}.pt"
+ torch.save(model.state_dict(), fname)
+ wandb.log({"loss": np.mean(loss_ep), "ep": ep})
+
+ acc_knn, acc = model.get_acc(ds.clf, ds.test)
+    print('Final linear-acc: {}, knn-acc: {}'.format(acc[1], acc_knn))
+ wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn}, commit=False)
+ wandb.finish()
\ No newline at end of file
diff --git a/transfer_datasets/README.md b/transfer_datasets/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e7ba755b6947e6c02ec5ccfd929f69cb41deaf83
--- /dev/null
+++ b/transfer_datasets/README.md
@@ -0,0 +1,19 @@
+## Datasets
+Download the transfer learning datasets from the links below,
+
+* **ImageNet-1k**: https://image-net.org
+
+* **DTD**: https://www.robots.ox.ac.uk/~vgg/data/dtd/
+
+* **CUBirds**: http://www.vision.caltech.edu/datasets/cub_200_2011/ (the older page http://www.vision.caltech.edu/visipedia/CUB-200-2011.html is no longer available)
+
+* **VGG Flower**: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
+
+* **Traffic Signs**: https://benchmark.ini.rub.de/gtsdb_dataset.html
+
+* **Aircraft**: https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/
+
+and put them under their respective paths, e.g., '../data/DTD'.
+
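+The dataset classes in this repo default to the following layout (see the
+`DATA_ROOTS` constant in each file under `transfer_datasets/`), relative to
+the training working directory; MNIST and FashionMNIST are downloaded
+automatically into `data/`:
+
+```
+data/
+├── Aircraft/
+├── CUBirds/
+├── DTD/
+├── TrafficSign/
+└── VGGFlower/
+```
+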
+## Training and linear evaluation
+All training and linear evaluation commands are provided in `main.sh` in the folder corresponding to the model.
\ No newline at end of file
diff --git a/transfer_datasets/__init__.py b/transfer_datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f07629070b6c5a60702ac83f4aa7a7e75bd9b434
--- /dev/null
+++ b/transfer_datasets/__init__.py
@@ -0,0 +1,16 @@
+from .aircraft import Aircraft
+from .cu_birds import CUBirds
+from .dtd import DTD
+from .fashionmnist import FashionMNIST
+from .mnist import MNIST
+from .traffic_sign import TrafficSign
+from .vgg_flower import VGGFlower
+
+TRANSFER_DATASET = {
+ 'aircraft': Aircraft,
+ 'cu_birds': CUBirds,
+ 'dtd': DTD,
+ 'fashionmnist': FashionMNIST,
+ 'mnist': MNIST,
+ 'traffic_sign': TrafficSign,
+ 'vgg_flower': VGGFlower}
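+
+# Usage sketch (hypothetical paths; each class also has its own DATA_ROOTS
+# default):
+#   from transfer_datasets import TRANSFER_DATASET
+#   ds = TRANSFER_DATASET['dtd'](root='data/DTD', train=True)
+#   image, label = ds[0]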
\ No newline at end of file
diff --git a/transfer_datasets/aircraft.py b/transfer_datasets/aircraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8dee6fa617fd6b73d983c25fb1aaa555636dd82
--- /dev/null
+++ b/transfer_datasets/aircraft.py
@@ -0,0 +1,74 @@
+import os
+import numpy as np
+from PIL import Image
+from os.path import join
+from collections import defaultdict
+import torch.utils.data as data
+
+DATA_ROOTS = 'data/Aircraft'
+
+# url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
+# wget http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz
+# python
+# from torchvision.datasets.utils import extract_archive
+# extract_archive("fgvc-aircraft-2013b.tar.gz")
+# Download and preprocess: https://github.com/lvyilin/pytorch-fgvc-dataset/blob/master/aircraft.py
+
+# class_types = ('variant', 'family', 'manufacturer')
+# splits = ('train', 'val', 'trainval', 'test')
+# img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')
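+# Annotation formats assumed by the loader below: each line of
+# images_variant_{split}.txt is "<image_id> <variant name>" (split on the
+# first space), and each line of images_box.txt is
+# "<image_id> <xmin> <ymin> <xmax> <ymax>".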
+
+class Aircraft(data.Dataset):
+ def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+ super().__init__()
+ self.root = root
+ self.train = train
+ self.image_transforms = image_transforms
+ paths, bboxes, labels = self.load_images()
+ self.paths = paths
+ self.bboxes = bboxes
+ self.labels = labels
+
+ def load_images(self):
+ split = 'trainval' if self.train else 'test'
+ variant_path = os.path.join(self.root, 'data', 'images_variant_%s.txt'%split)
+ with open(variant_path, 'r') as f:
+ names_to_variants = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+ names_to_variants = dict(names_to_variants)
+ variants_to_names = defaultdict(list)
+ for name, variant in names_to_variants.items():
+ variants_to_names[variant].append(name)
+ variants = sorted(list(set(variants_to_names.keys())))
+
+ names_to_bboxes = self.get_bounding_boxes()
+ split_files, split_labels, split_bboxes = [], [], []
+ for variant_id, variant in enumerate(variants):
+ class_files = [join(self.root, 'data', 'images', '%s.jpg'%filename) for filename in sorted(variants_to_names[variant])]
+ bboxes = [names_to_bboxes[name] for name in sorted(variants_to_names[variant])]
+ labels = [variant_id] * len(class_files)
+ split_files += class_files
+ split_labels += labels
+ split_bboxes += bboxes
+ return split_files, split_bboxes, split_labels
+
+ def get_bounding_boxes(self):
+ bboxes_path = os.path.join(self.root, 'data', 'images_box.txt')
+ with open(bboxes_path, 'r') as f:
+ names_to_bboxes = [line.split('\n')[0].split(' ') for line in f.readlines()]
+ names_to_bboxes = dict((name, list(map(int, (xmin, ymin, xmax, ymax)))) for name, xmin, ymin, xmax, ymax in names_to_bboxes)
+ return names_to_bboxes
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ path = self.paths[index]
+ bbox = tuple(self.bboxes[index])
+ label = self.labels[index]
+
+ image = Image.open(path).convert(mode='RGB')
+ image = image.crop(bbox)
+
+ if self.image_transforms:
+ image = self.image_transforms(image)
+ return image, label
\ No newline at end of file
diff --git a/transfer_datasets/cu_birds.py b/transfer_datasets/cu_birds.py
new file mode 100644
index 0000000000000000000000000000000000000000..353ae453dee70051ef832d6ec51f67c2e99b7353
--- /dev/null
+++ b/transfer_datasets/cu_birds.py
@@ -0,0 +1,61 @@
+import os
+import numpy as np
+from PIL import Image
+import torch.utils.data as data
+
+DATA_ROOTS = 'data/CUBirds'
+
+# wget https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
+# tar -xvzf CUB_200_2011.tgz
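+
+# All three metadata files below are "<id> <value>" pairs split on the first
+# space: images.txt maps ids to relative image paths, image_class_labels.txt
+# to 1-indexed class labels (hence the "- 1" in __getitem__), and
+# train_test_split.txt to 1 (train) or 0 (test).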
+
+class CUBirds(data.Dataset):
+ def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+ super().__init__()
+ self.root = root
+ self.train = train
+ self.image_transforms = image_transforms
+ paths, labels = self.load_images()
+ self.paths, self.labels = paths, labels
+
+ def load_images(self):
+ image_info_path = os.path.join(self.root, 'images.txt')
+ with open(image_info_path, 'r') as f:
+ image_info = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+ image_info = dict(image_info)
+
+ # load image to label information
+ label_info_path = os.path.join(self.root, 'image_class_labels.txt')
+ with open(label_info_path, 'r') as f:
+ label_info = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+ label_info = dict(label_info)
+
+ # load train test split
+ train_test_info_path = os.path.join(self.root, 'train_test_split.txt')
+ with open(train_test_info_path, 'r') as f:
+ train_test_info = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+ train_test_info = dict(train_test_info)
+
+ all_paths, all_labels = [], []
+ for index, image_path in image_info.items():
+ label = label_info[index]
+ split = int(train_test_info[index])
+ if self.train:
+ if split == 1:
+ all_paths.append(image_path)
+ all_labels.append(label)
+ else:
+ if split == 0:
+ all_paths.append(image_path)
+ all_labels.append(label)
+ return all_paths, all_labels
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ path = os.path.join(self.root, 'images', self.paths[index])
+ label = int(self.labels[index]) - 1
+ image = Image.open(path).convert(mode='RGB')
+ if self.image_transforms:
+ image = self.image_transforms(image)
+ return image, label
\ No newline at end of file
diff --git a/transfer_datasets/dtd.py b/transfer_datasets/dtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c52439d1bc14643be43ce56490d840856bf76c
--- /dev/null
+++ b/transfer_datasets/dtd.py
@@ -0,0 +1,69 @@
+import os
+import copy
+import numpy as np
+from PIL import Image
+from os.path import join
+from itertools import chain
+from collections import defaultdict
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms
+
+DATA_ROOTS = 'data/DTD'
+
+# wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz
+# tar -xvzf dtd-r1.0.1.tar.gz
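+
+# labels/train1.txt, val1.txt and test1.txt (the first of DTD's ten
+# predefined splits) list one "<category>/<file>.jpg" path per line; the
+# train set below merges train1 and val1, and labels are indices into the
+# sorted set of category names.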
+
+class DTD(data.Dataset):
+ def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+ super().__init__()
+ self.root = root
+ self.train = train
+ self.image_transforms = image_transforms
+ paths, labels = self.load_images()
+ self.paths, self.labels = paths, labels
+
+ def load_images(self):
+ if self.train:
+ train_info_path = os.path.join(self.root, 'labels', 'train1.txt')
+ with open(train_info_path, 'r') as f:
+ train_info = [line.split('\n')[0] for line in f.readlines()]
+
+ val_info_path = os.path.join(self.root, 'labels', 'val1.txt')
+ with open(val_info_path, 'r') as f:
+ val_info = [line.split('\n')[0] for line in f.readlines()]
+ split_info = train_info + val_info
+
+ else:
+ test_info_path = os.path.join(self.root, 'labels', 'test1.txt')
+ with open(test_info_path, 'r') as f:
+ split_info = [line.split('\n')[0] for line in f.readlines()]
+
+ # pull out categories from paths
+ categories = []
+ for row in split_info:
+ image_path = row
+ category = image_path.split('/')[0]
+ categories.append(category)
+ categories = sorted(list(set(categories)))
+
+ all_paths, all_labels = [], []
+ for row in split_info:
+ image_path = row
+ category = image_path.split('/')[0]
+ label = categories.index(category)
+ all_paths.append(join(self.root, 'images', image_path))
+ all_labels.append(label)
+ return all_paths, all_labels
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ path = self.paths[index]
+ label = self.labels[index]
+ image = Image.open(path).convert(mode='RGB')
+ if self.image_transforms:
+ image = self.image_transforms(image)
+ return image, label
\ No newline at end of file
diff --git a/transfer_datasets/fashionmnist.py b/transfer_datasets/fashionmnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1455ef43dbdbbb388240d92f7a121bd06f7d75
--- /dev/null
+++ b/transfer_datasets/fashionmnist.py
@@ -0,0 +1,28 @@
+import os
+import copy
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms, datasets
+
+DATA_ROOTS = 'data'
+
+class FashionMNIST(data.Dataset):
+ def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+ super().__init__()
+ if not os.path.isdir(root):
+ os.makedirs(root)
+ self.image_transforms = image_transforms
+ self.dataset = datasets.mnist.FashionMNIST(root, train=train, download=True)
+
+ def __getitem__(self, index):
+ img, target = self.dataset.data[index], int(self.dataset.targets[index])
+ img = Image.fromarray(img.numpy(), mode='L').convert('RGB')
+ if self.image_transforms is not None:
+ img = self.image_transforms(img)
+ return img, target
+
+ def __len__(self):
+ return len(self.dataset)
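+
+# Note: the 1-channel images are converted to 3-channel RGB in __getitem__
+# so the same encoders used for the other (RGB) transfer datasets apply
+# unchanged.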
\ No newline at end of file
diff --git a/transfer_datasets/mnist.py b/transfer_datasets/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab88a5ef8c3d82d7115feb4d4f74402cb7287d2
--- /dev/null
+++ b/transfer_datasets/mnist.py
@@ -0,0 +1,28 @@
+import os
+import copy
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms, datasets
+
+DATA_ROOTS = 'data'
+
+class MNIST(data.Dataset):
+ def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+ super().__init__()
+ if not os.path.isdir(root):
+ os.makedirs(root)
+ self.image_transforms = image_transforms
+ self.dataset = datasets.mnist.MNIST(root, train=train, download=True)
+
+ def __getitem__(self, index):
+ img, target = self.dataset.data[index], int(self.dataset.targets[index])
+ img = Image.fromarray(img.numpy(), mode='L').convert('RGB')
+ if self.image_transforms is not None:
+ img = self.image_transforms(img)
+ return img, target
+
+ def __len__(self):
+ return len(self.dataset)
\ No newline at end of file
diff --git a/transfer_datasets/traffic_sign.py b/transfer_datasets/traffic_sign.py
new file mode 100644
index 0000000000000000000000000000000000000000..708e989a8e2abe3f01c46ec273e2dba0c461734f
--- /dev/null
+++ b/transfer_datasets/traffic_sign.py
@@ -0,0 +1,65 @@
+import os
+import copy
+import json
+import operator
+import numpy as np
+from PIL import Image
+from glob import glob
+from os.path import join
+from itertools import chain
+from scipy.io import loadmat
+from collections import defaultdict
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms
+
+DATA_ROOTS = 'data/TrafficSign'
+
+# wget https://sid.erda.dk/public/archives/ff17dc924eba88d5d01a807357d6614c/FullIJCNN2013.zip
+# unzip FullIJCNN2013.zip
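+
+# Layout assumed below: 43 two-digit class folders (00/ ... 42/) of .ppm
+# crops directly under DATA_ROOTS (the commented-out path shows the
+# alternative 'Final_Training/Images' layout). Train/test is a
+# deterministic 80/20 per-class split via the fixed RandomState(42).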
+
+class TrafficSign(data.Dataset):
+ NUM_CLASSES = 43
+ def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+ super().__init__()
+ self.root = root
+ self.train = train
+ self.image_transforms = image_transforms
+ paths, labels = self.load_images()
+ self.paths, self.labels = paths, labels
+
+ def load_images(self):
+ split = 'Final_Training'  # kept for reference; the loader reads class folders directly under root
+ rs = np.random.RandomState(42)
+ all_filepaths, all_labels = [], []
+ for class_i in range(self.NUM_CLASSES):
+ # class_dir_i = join(self.root, split, 'Images', '{:05d}'.format(class_i))
+ class_dir_i = join(self.root, '{:02d}'.format(class_i))
+ image_paths = glob(join(class_dir_i, "*.ppm"))
+ # train test splitting
+ image_paths = np.array(image_paths)
+ num = len(image_paths)
+ indexer = np.arange(num)
+ rs.shuffle(indexer)
+ image_paths = image_paths[indexer].tolist()
+ if self.train:
+ image_paths = image_paths[:int(0.8 * num)]
+ else:
+ image_paths = image_paths[int(0.8 * num):]
+ labels = [class_i] * len(image_paths)
+ all_filepaths.extend(image_paths)
+ all_labels.extend(labels)
+
+ return all_filepaths, all_labels
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ path = self.paths[index]
+ label = self.labels[index]
+ image = Image.open(path).convert(mode='RGB')
+ if self.image_transforms:
+ image = self.image_transforms(image)
+ return image, label
\ No newline at end of file
diff --git a/transfer_datasets/vgg_flower.py b/transfer_datasets/vgg_flower.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fe055a354ab094b943f0ac46494dd9b15d4267d
--- /dev/null
+++ b/transfer_datasets/vgg_flower.py
@@ -0,0 +1,73 @@
+import os
+import copy
+import json
+import operator
+import numpy as np
+from PIL import Image
+from os.path import join
+from itertools import chain
+from scipy.io import loadmat
+from collections import defaultdict
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms
+
+DATA_ROOTS = 'data/VGGFlower'
+
+# wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
+# tar -xvzf 102flowers.tgz
+# rename the extracted 'jpg' folder to VGGFlower
+# cd VGGFlower
+# wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat
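+
+# imagelabels.mat holds one 1-indexed label per image_{:05d}.jpg (hence the
+# "- 1" in __getitem__); train/test is the same deterministic 80/20
+# per-class split seeded with RandomState(42).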
+
+
+class VGGFlower(data.Dataset):
+ def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+ super().__init__()
+ self.root = root
+ self.train = train
+ self.image_transforms = image_transforms
+ paths, labels = self.load_images()
+ self.paths, self.labels = paths, labels
+
+ def load_images(self):
+ rs = np.random.RandomState(42)
+ imagelabels_path = os.path.join(self.root, 'imagelabels.mat')
+ with open(imagelabels_path, 'rb') as f:
+ labels = loadmat(f)['labels'][0]
+
+ all_filepaths = defaultdict(list)
+ for i, label in enumerate(labels):
+ # all_filepaths[label].append(os.path.join(self.root, 'jpg', 'image_{:05d}.jpg'.format(i+1)))
+ all_filepaths[label].append(os.path.join(self.root, 'image_{:05d}.jpg'.format(i+1)))
+ # train test split
+ split_filepaths, split_labels = [], []
+ for label, paths in all_filepaths.items():
+ num = len(paths)
+ paths = np.array(paths)
+ indexer = np.arange(num)
+ rs.shuffle(indexer)
+ paths = paths[indexer].tolist()
+
+ if self.train:
+ paths = paths[:int(0.8 * num)]
+ else:
+ paths = paths[int(0.8 * num):]
+
+ labels = [label] * len(paths)
+ split_filepaths.extend(paths)
+ split_labels.extend(labels)
+
+ return split_filepaths, split_labels
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ path = self.paths[index]
+ label = int(self.labels[index]) - 1
+ image = Image.open(path).convert(mode='RGB')
+ if self.image_transforms:
+ image = self.image_transforms(image)
+ return image, label
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32570c63b3c9fd39631390c8c4637303f44872c1
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,102 @@
+from PIL import Image
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+from augmentations.augmentations_cifar import aug_cifar
+from augmentations.augmentations_tiny import aug_tiny
+from augmentations.augmentations_stl import aug_stl
+
+# for cifar10 / cifar100 (32x32)
+class CifarPairTransform:
+ def __init__(self, train_transform = True, pair_transform = True):
+ if train_transform is True:
+ self.transform = transforms.Compose([
+ transforms.RandomResizedCrop(32),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.ToTensor(),
+ transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
+ else:
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
+ self.pair_transform = pair_transform
+ def __call__(self, x):
+ if self.pair_transform is True:
+ y1 = self.transform(x)
+ y2 = self.transform(x)
+ return y1, y2
+ else:
+ return self.transform(x)
+
+# for tiny_imagenet (64x64)
+class TinyImageNetPairTransform:
+ def __init__(self, train_transform = True, pair_transform = True):
+ if train_transform is True:
+ self.transform = transforms.Compose([
+ transforms.RandomApply(
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4,
+ saturation=0.4, hue=0.1)],
+ p=0.8
+ ),
+ transforms.RandomGrayscale(p=0.1),
+ transforms.RandomResizedCrop(
+ 64,
+ scale=(0.2, 1.0),
+ ratio=(0.75, (4 / 3)),
+ interpolation=Image.BICUBIC,
+ ),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.ToTensor(),
+ transforms.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))
+ ])
+ else:
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))
+ ])
+ self.pair_transform = pair_transform
+ def __call__(self, x):
+ if self.pair_transform is True:
+ y1 = self.transform(x)
+ y2 = self.transform(x)
+ return y1, y2
+ else:
+ return self.transform(x)
+
+# for stl10 (native 96x96; resized/cropped to 64x64 below)
+class StlPairTransform:
+ def __init__(self, train_transform = True, pair_transform = True):
+ if train_transform is True:
+ self.transform = transforms.Compose([
+ transforms.RandomApply(
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4,
+ saturation=0.4, hue=0.1)],
+ p=0.8
+ ),
+ transforms.RandomGrayscale(p=0.1),
+ transforms.RandomResizedCrop(
+ 64,
+ scale=(0.2, 1.0),
+ ratio=(0.75, (4 / 3)),
+ interpolation=Image.BICUBIC,
+ ),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.ToTensor(),
+ transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))
+ ])
+ else:
+ self.transform = transforms.Compose([
+ transforms.Resize(70, interpolation=Image.BICUBIC),
+ transforms.CenterCrop(64),
+ transforms.ToTensor(),
+ transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))
+ ])
+ self.pair_transform = pair_transform
+ def __call__(self, x):
+ if self.pair_transform is True:
+ y1 = self.transform(x)
+ y2 = self.transform(x)
+ return y1, y2
+ else:
+ return self.transform(x)
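+
+# Usage sketch (CIFAR10 is imported above; torchvision applies the transform
+# to each PIL image, so indexing yields two augmented views per sample):
+#   train_data = CIFAR10(root='data', train=True, download=True,
+#                        transform=CifarPairTransform(train_transform=True))
+#   (y1, y2), target = train_data[0]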
\ No newline at end of file