diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..15d0543cce6babec7fd594a74653b1d65f95dfaf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +figs/in-loss-reg.png filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..22655ba5065b54e57670ea6ef465138cc91245a5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Wele Gedara Chaminda Bandara + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md new file mode 100644 index 0000000000000000000000000000000000000000..f8a5479f28fd299fa88a5f2e951f85ece5738a4f --- /dev/null +++ b/MODEL_ZOO.md @@ -0,0 +1,20 @@ +The following links provide pre-trained models: +# ResNet-18 Pre-trained Models +| Dataset | d | Lambda_BT | Lambda_Reg | Path to Pretrained Model | KNN Acc. | Linear Acc. | +| ---------- | --- | ---------- | ---------- | ------------------------ | -------- | ----------- | +| CIFAR-10 | 1024 | 0.0078125 | 4.0 | 4wdhbpcf_0.0078125_1024_256_cifar10_model.pth | 90.52 | 92.58 | +| CIFAR-100 | 1024 | 0.0078125 | 4.0 | 76kk7scz_0.0078125_1024_256_cifar100_model.pth | 61.25 | 69.31 | +| TinyImageNet | 1024 | 0.0009765 | 4.0 | 02azq6fs_0.0009765_1024_256_tiny_imagenet_model.pth | 38.11 | 51.67 | +| STL-10 | 1024 | 0.0078125 | 2.0 | i7det4xq_0.0078125_1024_256_stl10_model.pth | 88.94 | 91.02 | + + + +# ResNet-50 Pre-trained Models +| Dataset | d | Lambda_BT | Lambda_Reg | Path to Pretrained Model | KNN Acc. | Linear Acc. 
| +| ---------- | --- | ---------- | ---------- | ------------------------ | -------- | ----------- | +| CIFAR-10 | 1024 | 0.0078125 | 4.0 | v3gwgusq_0.0078125_1024_256_cifar10_model.pth | 91.39 | 93.89 | +| CIFAR-100 | 1024 | 0.0078125 | 4.0 | z6ngefw7_0.0078125_1024_256_cifar100_model_2000.pth | 64.32 | 72.51 | +| TinyImageNet | 1024 | 0.0009765 | 4.0 | kxlkigsv_0.0009765_1024_256_tiny_imagenet_model_2000.pth | 42.21 | 51.84 | +| STL-10 | 1024 | 0.0078125 | 2.0 | pbknx38b_0.0078125_1024_256_stl10_model.pth | 87.79 | 91.70 | +| ImageNet | 1024 | 0.0051 | 0.1 | 13awtq23_0.0051_8192_1024_imagenet_0.1_resnet50.pth | - | 72.1 | + diff --git a/README.md b/README.md index 79e17098b2528fb1d64eb7ce558a0dcdb529ddf5..ea025e9477a50efa105be88e27a6f8d175614ee9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,3 @@ ---- -license: mit ---- # Mixed Barlow Twins [**Guarding Barlow Twins Against Overfitting with Mixed Samples**](https://arxiv.org/abs/2312.02151)
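The checkpoints listed in MODEL_ZOO.md above can be restored into this repo's encoder before k-NN or linear evaluation. Below is a minimal loading sketch, assuming the `Model` wrapper from this repo's `model.py` (instantiated the same way in `linear.py`) and a locally downloaded copy of the ResNet-18 CIFAR-10 checkpoint from the table:

```python
import torch
from model import Model  # this repo's encoder + projector wrapper

# Architecture and dataset must match the checkpoint row (here: ResNet-18, CIFAR-10).
net = Model(dataset='cifar10', arch='resnet18')

# strict=False mirrors linear.py: the checkpoint stores encoder/projector weights,
# while any downstream head (e.g. a linear classifier) is initialized separately.
state = torch.load('4wdhbpcf_0.0078125_1024_256_cifar10_model.pth', map_location='cpu')
net.load_state_dict(state, strict=False)
net.eval()
```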
diff --git a/augmentations/augmentations_cifar.py b/augmentations/augmentations_cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ec6c8a30602612846409fdffb8c8297cb95da1 --- /dev/null +++ b/augmentations/augmentations_cifar.py @@ -0,0 +1,190 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base augmentations operators.""" + +import numpy as np +from PIL import Image, ImageOps, ImageEnhance + +# ImageNet code should change this value +IMAGE_SIZE = 32 +import torch +from torchvision import transforms + + +def int_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval . + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + An int that results from scaling `maxval` according to `level`. + """ + return int(level * maxval / 10) + + +def float_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval. + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + A float that results from scaling `maxval` according to `level`. + """ + return float(level) * maxval / 10. 
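+# Worked example of the scaling above: float_parameter(5, 0.3) = 0.15 and
+# int_parameter(5, 30) = 15, i.e. a mid-range severity level requests half of an
+# operation's maximum magnitude; sample_level below jitters the level first.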
+ + +def sample_level(n): + return np.random.uniform(low=0.1, high=n) + + +def autocontrast(pil_img, _): + return ImageOps.autocontrast(pil_img) + + +def equalize(pil_img, _): + return ImageOps.equalize(pil_img) + + +def posterize(pil_img, level): + level = int_parameter(sample_level(level), 4) + return ImageOps.posterize(pil_img, 4 - level) + + +def rotate(pil_img, level): + degrees = int_parameter(sample_level(level), 30) + if np.random.uniform() > 0.5: + degrees = -degrees + return pil_img.rotate(degrees, resample=Image.BILINEAR) + + +def solarize(pil_img, level): + level = int_parameter(sample_level(level), 256) + return ImageOps.solarize(pil_img, 256 - level) + + +def shear_x(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, level, 0, 0, 1, 0), + resample=Image.BILINEAR) + + +def shear_y(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, level, 1, 0), + resample=Image.BILINEAR) + + +def translate_x(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, level, 0, 1, 0), + resample=Image.BILINEAR) + + +def translate_y(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, 0, 1, level), + resample=Image.BILINEAR) + + +# operation that overlaps with ImageNet-C's test set +def color(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Color(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def contrast(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Contrast(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def brightness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Brightness(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def sharpness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Sharpness(pil_img).enhance(level) + +def random_resized_crop(pil_img, level): + return transforms.RandomResizedCrop(32)(pil_img) + +def random_flip(pil_img, level): + return transforms.RandomHorizontalFlip(p=0.5)(pil_img) + +def grayscale(pil_img, level): + return transforms.Grayscale(num_output_channels=3)(pil_img) + +augmentations = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y, grayscale #random_resized_crop, random_flip +] + +augmentations_all = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y, color, contrast, brightness, sharpness, grayscale #, random_resized_crop, random_flip +] + +def aug_cifar(image, preprocess, mixture_width=3, mixture_depth=-1, aug_severity=3): + """Perform AugMix augmentations and compute mixture. + + Args: + image: PIL.Image input image + preprocess: Preprocessing function which should return a torch tensor. + + Returns: + mixed: Augmented and mixed image. 
+ """ + aug_list = augmentations_all + # if args.all_ops: + # aug_list = augmentations.augmentations_all + + ws = np.float32(np.random.dirichlet([1] * mixture_width)) + m = np.float32(np.random.beta(1, 1)) + + mix = torch.zeros_like(preprocess(image)) + for i in range(mixture_width): + image_aug = image.copy() + depth = mixture_depth if mixture_depth > 0 else np.random.randint( + 1, 4) + for _ in range(depth): + op = np.random.choice(aug_list) + image_aug = op(image_aug, aug_severity) + # Preprocessing commutes since all coefficients are convex + mix += ws[i] * preprocess(image_aug) + + # mixed = (1 - m) * preprocess(image) + m * mix + return mix \ No newline at end of file diff --git a/augmentations/augmentations_stl.py b/augmentations/augmentations_stl.py new file mode 100644 index 0000000000000000000000000000000000000000..fb62c61fe8036cfe6c006c2dc62fa6173cf45bc1 --- /dev/null +++ b/augmentations/augmentations_stl.py @@ -0,0 +1,190 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base augmentations operators.""" + +import numpy as np +from PIL import Image, ImageOps, ImageEnhance + +# ImageNet code should change this value +IMAGE_SIZE = 64 +import torch +from torchvision import transforms + + +def int_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval . + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + An int that results from scaling `maxval` according to `level`. + """ + return int(level * maxval / 10) + + +def float_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval. + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + A float that results from scaling `maxval` according to `level`. + """ + return float(level) * maxval / 10. 
+ + +def sample_level(n): + return np.random.uniform(low=0.1, high=n) + + +def autocontrast(pil_img, _): + return ImageOps.autocontrast(pil_img) + + +def equalize(pil_img, _): + return ImageOps.equalize(pil_img) + + +def posterize(pil_img, level): + level = int_parameter(sample_level(level), 4) + return ImageOps.posterize(pil_img, 4 - level) + + +def rotate(pil_img, level): + degrees = int_parameter(sample_level(level), 30) + if np.random.uniform() > 0.5: + degrees = -degrees + return pil_img.rotate(degrees, resample=Image.BILINEAR) + + +def solarize(pil_img, level): + level = int_parameter(sample_level(level), 256) + return ImageOps.solarize(pil_img, 256 - level) + + +def shear_x(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, level, 0, 0, 1, 0), + resample=Image.BILINEAR) + + +def shear_y(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, level, 1, 0), + resample=Image.BILINEAR) + + +def translate_x(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, level, 0, 1, 0), + resample=Image.BILINEAR) + + +def translate_y(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, 0, 1, level), + resample=Image.BILINEAR) + + +# operation that overlaps with ImageNet-C's test set +def color(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Color(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def contrast(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Contrast(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def brightness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Brightness(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def sharpness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Sharpness(pil_img).enhance(level) + +def random_resized_crop(pil_img, level): + return transforms.RandomResizedCrop(32)(pil_img) + +def random_flip(pil_img, level): + return transforms.RandomHorizontalFlip(p=0.5)(pil_img) + +def grayscale(pil_img, level): + return transforms.Grayscale(num_output_channels=3)(pil_img) + +augmentations = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y, grayscale #random_resized_crop, random_flip +] + +augmentations_all = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y, color, contrast, brightness, sharpness, grayscale #, random_resized_crop, random_flip +] + +def aug_stl(image, preprocess, mixture_width=3, mixture_depth=-1, aug_severity=3): + """Perform AugMix augmentations and compute mixture. + + Args: + image: PIL.Image input image + preprocess: Preprocessing function which should return a torch tensor. + + Returns: + mixed: Augmented and mixed image. 
+ """ + aug_list = augmentations + # if args.all_ops: + # aug_list = augmentations.augmentations_all + + ws = np.float32(np.random.dirichlet([1] * mixture_width)) + m = np.float32(np.random.beta(1, 1)) + + mix = torch.zeros_like(preprocess(image)) + for i in range(mixture_width): + image_aug = image.copy() + depth = mixture_depth if mixture_depth > 0 else np.random.randint( + 1, 4) + for _ in range(depth): + op = np.random.choice(aug_list) + image_aug = op(image_aug, aug_severity) + # Preprocessing commutes since all coefficients are convex + mix += ws[i] * preprocess(image_aug) + + mixed = (1 - m) * preprocess(image) + m * mix + return mixed \ No newline at end of file diff --git a/augmentations/augmentations_tiny.py b/augmentations/augmentations_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..0dde30244e7a4daa7c451fb5e3ad9b54e62a4404 --- /dev/null +++ b/augmentations/augmentations_tiny.py @@ -0,0 +1,190 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base augmentations operators.""" + +import numpy as np +from PIL import Image, ImageOps, ImageEnhance + +# ImageNet code should change this value +IMAGE_SIZE = 64 +import torch +from torchvision import transforms + + +def int_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval . + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + An int that results from scaling `maxval` according to `level`. + """ + return int(level * maxval / 10) + + +def float_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval. + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + A float that results from scaling `maxval` according to `level`. + """ + return float(level) * maxval / 10. 
+ + +def sample_level(n): + return np.random.uniform(low=0.1, high=n) + + +def autocontrast(pil_img, _): + return ImageOps.autocontrast(pil_img) + + +def equalize(pil_img, _): + return ImageOps.equalize(pil_img) + + +def posterize(pil_img, level): + level = int_parameter(sample_level(level), 4) + return ImageOps.posterize(pil_img, 4 - level) + + +def rotate(pil_img, level): + degrees = int_parameter(sample_level(level), 30) + if np.random.uniform() > 0.5: + degrees = -degrees + return pil_img.rotate(degrees, resample=Image.BILINEAR) + + +def solarize(pil_img, level): + level = int_parameter(sample_level(level), 256) + return ImageOps.solarize(pil_img, 256 - level) + + +def shear_x(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, level, 0, 0, 1, 0), + resample=Image.BILINEAR) + + +def shear_y(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, level, 1, 0), + resample=Image.BILINEAR) + + +def translate_x(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, level, 0, 1, 0), + resample=Image.BILINEAR) + + +def translate_y(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, 0, 1, level), + resample=Image.BILINEAR) + + +# operation that overlaps with ImageNet-C's test set +def color(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Color(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def contrast(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Contrast(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def brightness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Brightness(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def sharpness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Sharpness(pil_img).enhance(level) + +def random_resized_crop(pil_img, level): + return transforms.RandomResizedCrop(32)(pil_img) + +def random_flip(pil_img, level): + return transforms.RandomHorizontalFlip(p=0.5)(pil_img) + +def grayscale(pil_img, level): + return transforms.Grayscale(num_output_channels=3)(pil_img) + +augmentations = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y, grayscale #random_resized_crop, random_flip +] + +augmentations_all = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y, color, contrast, brightness, sharpness, grayscale #, random_resized_crop, random_flip +] + +def aug_tiny(image, preprocess, mixture_width=3, mixture_depth=-1, aug_severity=3): + """Perform AugMix augmentations and compute mixture. + + Args: + image: PIL.Image input image + preprocess: Preprocessing function which should return a torch tensor. + + Returns: + mixed: Augmented and mixed image. 
+ """ + aug_list = augmentations + # if args.all_ops: + # aug_list = augmentations.augmentations_all + + ws = np.float32(np.random.dirichlet([1] * mixture_width)) + m = np.float32(np.random.beta(1, 1)) + + mix = torch.zeros_like(preprocess(image)) + for i in range(mixture_width): + image_aug = image.copy() + depth = mixture_depth if mixture_depth > 0 else np.random.randint( + 1, 4) + for _ in range(depth): + op = np.random.choice(aug_list) + image_aug = op(image_aug, aug_severity) + # Preprocessing commutes since all coefficients are convex + mix += ws[i] * preprocess(image_aug) + + mixed = (1 - m) * preprocess(image) + m * mix + return mixed \ No newline at end of file diff --git a/data_statistics.py b/data_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..ae212dc093d48169255e0b7ee3358d7219881c0c --- /dev/null +++ b/data_statistics.py @@ -0,0 +1,61 @@ +def get_data_mean_and_stdev(dataset): + if dataset == 'CIFAR10' or dataset == 'CIFAR100': + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + elif dataset == 'STL-10': + mean = [0.491, 0.482, 0.447] + std = [0.247, 0.244, 0.262] + elif dataset == 'ImageNet': + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + elif dataset == 'aircraft': + mean = [0.486, 0.507, 0.525] + std = [0.266, 0.260, 0.276] + elif dataset == 'cu_birds': + mean = [0.483, 0.491, 0.424] + std = [0.228, 0.224, 0.259] + elif dataset == 'dtd': + mean = [0.533, 0.474, 0.426] + std = [0.261, 0.250, 0.259] + elif dataset == 'fashionmnist': + mean = [0.348, 0.348, 0.348] + std = [0.347, 0.347, 0.347] + elif dataset == 'mnist': + mean = [0.170, 0.170, 0.170] + std = [0.320, 0.320, 0.320] + elif dataset == 'traffic_sign': + mean = [0.335, 0.291, 0.295] + std = [0.267, 0.249, 0.251] + elif dataset == 'vgg_flower': + mean = [0.518, 0.410, 0.329] + std = [0.296, 0.249, 0.285] + else: + raise Exception('Dataset %s not supported.'%dataset) + return mean, std + +def get_data_nclass(dataset): + if dataset == 'cifar10': + nclass = 10 + elif dataset == 'cifar100cifar10': + nclass = 100 + elif dataset == 'stl-10': + nclass = 10 + elif dataset == 'ImageNet': + nclass = 1000 + elif dataset == 'aircraft': + nclass = 102 + elif dataset == 'cu_birds': + nclass = 200 + elif dataset == 'dtd': + nclass = 47 + elif dataset == 'fashionmnist': + nclass = 10 + elif dataset == 'mnist': + nclass = 10 + elif dataset == 'traffic_sign': + nclass = 43 + elif dataset == 'vgg_flower': + nclass = 102 + else: + raise Exception('Dataset %s not supported.'%dataset) + return nclass \ No newline at end of file diff --git a/download_imagenet.sh b/download_imagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b0237336a36713a219486f3a4084d653a748898 --- /dev/null +++ b/download_imagenet.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4 +cd /mnt/store/wbandar1/datasets +wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar --no-check-certificate +wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar --no-check-certificate + +# +# script to extract ImageNet dataset +# ILSVRC2012_img_train.tar (about 138 GB) +# ILSVRC2012_img_val.tar (about 6.3 GB) +# make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar in your current directory +# +# https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md +# +# train/ +# ├── n01440764 +# │ ├── n01440764_10026.JPEG +# │ ├── n01440764_10027.JPEG +# │ ├── ...... +# ├── ...... 
+# val/ +# ├── n01440764 +# │ ├── ILSVRC2012_val_00000293.JPEG +# │ ├── ILSVRC2012_val_00002138.JPEG +# │ ├── ...... +# ├── ...... +# +# +# Extract the training data: +# +mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train +tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar +find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done +cd .. +# +# Extract the validation data and move images to subfolders: +# +mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar +wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash +# +# Check total files after extract +# +# $ find train/ -name "*.JPEG" | wc -l +# 1281167 +# $ find val/ -name "*.JPEG" | wc -l +# 50000 +# \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..c493d4a4a5ab8876cff5cde6a849db6f293b5cd1 --- /dev/null +++ b/environment.yml @@ -0,0 +1,188 @@ +name: ssl-aug +channels: + - pytorch + - anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - bottleneck=1.3.4=py38hce1f21e_0 + - brotlipy=0.7.0=py38h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2022.6.15=ha878542_0 + - cairo=1.16.0=hcf35c78_1003 + - certifi=2022.6.15=py38h578d9bd_0 + - cffi=1.15.0=py38h7f8727e_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cryptography=37.0.1=py38h9ce1e76_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - dataclasses=0.8=pyh6d0b6a4_7 + - dbus=1.13.18=hb2f20db_0 + - expat=2.4.8=h27087fc_0 + - ffmpeg=4.3.2=hca11adc_0 + - fontconfig=2.14.0=h8e229c2_0 + - freetype=2.11.0=h70c0345_0 + - fvcore=0.1.5.post20220512=pyhd8ed1ab_0 + - gettext=0.19.8.1=hd7bead4_3 + - gh=2.12.1=ha8f183a_0 + - giflib=5.2.1=h7b6447c_0 + - glib=2.66.3=h58526e2_0 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - graphite2=1.3.14=h295c915_1 + - gst-plugins-base=1.14.5=h0935bb2_2 + - gstreamer=1.14.5=h36ae1b5_2 + - harfbuzz=2.4.0=h9f30f68_3 + - hdf5=1.10.6=hb1b8bf9_0 + - icu=64.2=he1b5a44_1 + - idna=3.3=pyhd3eb1b0_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - iopath=0.1.9=pyhd8ed1ab_0 + - jasper=1.900.1=hd497a04_4 + - jpeg=9e=h7f8727e_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - libblas=3.9.0=12_linux64_mkl + - libcblas=3.9.0=12_linux64_mkl + - libclang=9.0.1=default_hb4e5071_5 + - libedit=3.1.20210910=h7f8727e_0 + - libffi=3.2.1=hf484d3e_1007 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=7.5.0=ha8ba4b0_17 + - libgfortran4=7.5.0=ha8ba4b0_17 + - libglib=2.66.3=hbe7bbb4_0 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - liblapack=3.9.0=12_linux64_mkl + - liblapacke=3.9.0=12_linux64_mkl + - libllvm9=9.0.1=h4a3c616_1 + - libopencv=4.5.1=py38h703c3c0_0 + - libpng=1.6.37=hbc83047_0 + - libprotobuf=3.15.8=h780b84a_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.2.0=h2818925_1 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=2.32.1=h7f98852_1000 + - libuv=1.40.0=h7b6447c_0 + - libwebp=1.2.2=h55f646e_0 + - libwebp-base=1.2.2=h7f8727e_0 + - libxcb=1.15=h7f8727e_0 + - libxkbcommon=0.10.0=he1b5a44_0 + - libxml2=2.9.9=hea5a465_1 + - lz4-c=1.9.3=h295c915_1 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.1=py38hd3c417c_0 + - mkl_random=1.2.2=py38h51133e4_0 + - ncurses=6.3=h7f8727e_2 + - nettle=3.7.3=hbbd107a_1 + - nspr=4.33=h295c915_0 
+ - nss=3.46.1=hab99668_0 + - numexpr=2.8.1=py38h6abb31d_0 + - numpy=1.22.3=py38he7a7128_0 + - numpy-base=1.22.3=py38hf524024_0 + - opencv=4.5.1=py38h578d9bd_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1o=h166bdaf_0 + - packaging=21.3=pyhd3eb1b0_0 + - pandas=1.4.2=py38h295c915_0 + - pcre=8.45=h295c915_0 + - pillow=9.0.1=py38h22f2fdc_0 + - pip=21.2.4=py38h06a4308_0 + - pixman=0.38.0=h7b6447c_0 + - portalocker=2.3.0=py38h06a4308_0 + - protobuf=3.15.8=py38h709712a_0 + - py-opencv=4.5.1=py38h81c977d_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=22.0.0=pyhd3eb1b0_0 + - pyparsing=3.0.9=pyhd8ed1ab_0 + - pysocks=1.7.1=py38h06a4308_0 + - python=3.8.0=h0371630_2 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - python_abi=3.8=2_cp38 + - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0 + - pytorch-mutex=1.0=cuda + - pytz=2021.3=pyhd3eb1b0_0 + - pyyaml=6.0=py38h7f8727e_1 + - qt=5.12.5=hd8c4c69_1 + - readline=7.0=h7b6447c_5 + - requests=2.27.1=pyhd3eb1b0_0 + - setuptools=61.2.0=py38h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.33.0=h62c20be_0 + - tabulate=0.8.9=py38h06a4308_0 + - tensorboardx=2.5.1=pyhd8ed1ab_0 + - termcolor=1.1.0=py38h06a4308_1 + - tk=8.6.12=h1ccaba5_0 + - torchvision=0.12.0=py38_cu113 + - tqdm=4.64.0=py38h06a4308_0 + - typing_extensions=4.1.1=pyh06a4308_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - x264=1!161.3030=h7f98852_1 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.0.10=h7f98852_0 + - xorg-libsm=1.2.3=hd9c2040_1000 + - xorg-libx11=1.7.2=h7f98852_0 + - xorg-libxext=1.3.4=h7f98852_1 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h7f98852_1002 + - xorg-xproto=7.0.31=h27cfd23_1007 + - xz=5.2.5=h7f8727e_1 + - yacs=0.1.6=pyhd3eb1b0_1 + - yaml=0.2.5=h7b6447c_0 + - zip=3.0=h7f98852_1 + - zlib=1.2.12=h7f8727e_2 + - zstd=1.5.2=ha4553b6_0 + - pip: + - absl-py==1.1.0 + - appdirs==1.4.4 + - cachetools==5.2.0 + - click==8.1.7 + - contourpy==1.0.6 + - cycler==0.11.0 + - decord==0.6.0 + - deepspeed==0.5.8 + - docker-pycreds==0.4.0 + - einops==0.4.1 + - filelock==3.7.1 + - fonttools==4.38.0 + - future==0.18.2 + - gitdb==4.0.10 + - gitpython==3.1.33 + - google-auth==2.7.0 + - google-auth-oauthlib==0.4.6 + - grpcio==1.46.3 + - hjson==3.0.2 + - imageio==2.22.2 + - importlib-metadata==4.11.4 + - kiwisolver==1.4.4 + - markdown==3.3.7 + - matplotlib==3.6.1 + - ninja==1.10.2.3 + - oauthlib==3.2.0 + - pathtools==0.1.2 + - psutil==5.9.1 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - requests-oauthlib==1.3.1 + - rsa==4.8 + - scipy==1.9.0 + - sentry-sdk==1.30.0 + - setproctitle==1.3.2 + - smmap==5.0.0 + - tensorboard==2.9.1 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - thop==0.1.1-2209072238 + - timm==0.4.12 + - triton==1.1.1 + - urllib3==1.26.16 + - wandb==0.15.9 + - werkzeug==2.1.2 + - zipp==3.8.0 +prefix: /home/wbandar1/anaconda3/envs/ssl-aug diff --git a/evaluate_imagenet.py b/evaluate_imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..59d72294a5e58759de9fbeebfda505db8d163462 --- /dev/null +++ b/evaluate_imagenet.py @@ -0,0 +1,289 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from pathlib import Path +import argparse +import json +import os +import random +import signal +import sys +import time +import urllib + +from torch import nn, optim +from torchvision import models, datasets, transforms +import torch +import torchvision +import wandb + +parser = argparse.ArgumentParser(description='Evaluate resnet50 features on ImageNet') +parser.add_argument('data', type=Path, metavar='DIR', + help='path to dataset') +parser.add_argument('pretrained', type=Path, metavar='FILE', + help='path to pretrained model') +parser.add_argument('--weights', default='freeze', type=str, + choices=('finetune', 'freeze'), + help='finetune or freeze resnet weights') +parser.add_argument('--train-percent', default=100, type=int, + choices=(100, 10, 1), + help='size of training set in percent') +parser.add_argument('--workers', default=8, type=int, metavar='N', + help='number of data loader workers') +parser.add_argument('--epochs', default=100, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--batch-size', default=256, type=int, metavar='N', + help='mini-batch size') +parser.add_argument('--lr-backbone', default=0.0, type=float, metavar='LR', + help='backbone base learning rate') +parser.add_argument('--lr-classifier', default=0.3, type=float, metavar='LR', + help='classifier base learning rate') +parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W', + help='weight decay') +parser.add_argument('--print-freq', default=100, type=int, metavar='N', + help='print frequency') +parser.add_argument('--checkpoint-dir', default='/mnt/store/wbandar1/projects/ssl-aug-artifacts/', type=Path, + metavar='DIR', help='path to checkpoint directory') + + +def main(): + args = parser.parse_args() + if args.train_percent in {1, 10}: + args.train_files = urllib.request.urlopen(f'https://raw.githubusercontent.com/google-research/simclr/master/imagenet_subsets/{args.train_percent}percent.txt').readlines() + args.ngpus_per_node = torch.cuda.device_count() + if 'SLURM_JOB_ID' in os.environ: + signal.signal(signal.SIGUSR1, handle_sigusr1) + signal.signal(signal.SIGTERM, handle_sigterm) + # single-node distributed training + args.rank = 0 + args.dist_url = f'tcp://localhost:{random.randrange(49152, 65535)}' + args.world_size = args.ngpus_per_node + torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node) + + +def main_worker(gpu, args): + args.rank += gpu + torch.distributed.init_process_group( + backend='nccl', init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + # initializing wandb + if args.rank == 0: + run = wandb.init(project="bt-in1k-eval", config=args, dir='/mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs/') + run_id = wandb.run.id + args.checkpoint_dir=Path(os.path.join(args.checkpoint_dir, run_id)) + + if args.rank == 0: + args.checkpoint_dir.mkdir(parents=True, exist_ok=True) + stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1) + print(' '.join(sys.argv)) + print(' '.join(sys.argv), file=stats_file) + + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + + model = models.resnet50().cuda(gpu) + state_dict = torch.load(args.pretrained, map_location='cpu') + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + assert missing_keys == ['fc.weight', 'fc.bias'] and unexpected_keys == [] + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + if args.weights == 'freeze': + model.requires_grad_(False) +
model.fc.requires_grad_(True) + classifier_parameters, model_parameters = [], [] + for name, param in model.named_parameters(): + if name in {'fc.weight', 'fc.bias'}: + classifier_parameters.append(param) + else: + model_parameters.append(param) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) + + criterion = nn.CrossEntropyLoss().cuda(gpu) + + param_groups = [dict(params=classifier_parameters, lr=args.lr_classifier)] + if args.weights == 'finetune': + param_groups.append(dict(params=model_parameters, lr=args.lr_backbone)) + optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=args.weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) + + # automatically resume from checkpoint if it exists + if (args.checkpoint_dir / 'checkpoint.pth').is_file(): + ckpt = torch.load(args.checkpoint_dir / 'checkpoint.pth', + map_location='cpu') + start_epoch = ckpt['epoch'] + best_acc = ckpt['best_acc'] + model.load_state_dict(ckpt['model']) + optimizer.load_state_dict(ckpt['optimizer']) + scheduler.load_state_dict(ckpt['scheduler']) + else: + start_epoch = 0 + best_acc = argparse.Namespace(top1=0, top5=0) + + # Data loading code + traindir = args.data / 'train' + valdir = args.data / 'val' + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder(traindir, transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + if args.train_percent in {1, 10}: + train_dataset.samples = [] + for fname in args.train_files: + fname = fname.decode().strip() + cls = fname.split('_')[0] + train_dataset.samples.append( + (traindir / cls / fname, train_dataset.class_to_idx[cls])) + + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + kwargs = dict(batch_size=args.batch_size // args.world_size, num_workers=args.workers, pin_memory=True) + train_loader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **kwargs) + val_loader = torch.utils.data.DataLoader(val_dataset, **kwargs) + + start_time = time.time() + for epoch in range(start_epoch, args.epochs): + # train + if args.weights == 'finetune': + model.train() + elif args.weights == 'freeze': + model.eval() + else: + assert False + train_sampler.set_epoch(epoch) + for step, (images, target) in enumerate(train_loader, start=epoch * len(train_loader)): + output = model(images.cuda(gpu, non_blocking=True)) + loss = criterion(output, target.cuda(gpu, non_blocking=True)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if step % args.print_freq == 0: + torch.distributed.reduce(loss.div_(args.world_size), 0) + if args.rank == 0: + pg = optimizer.param_groups + lr_classifier = pg[0]['lr'] + lr_backbone = pg[1]['lr'] if len(pg) == 2 else 0 + stats = dict(epoch=epoch, step=step, lr_backbone=lr_backbone, + lr_classifier=lr_classifier, loss=loss.item(), + time=int(time.time() - start_time)) + print(json.dumps(stats)) + print(json.dumps(stats), file=stats_file) + run.log( + { + "epoch": epoch, + "step": step, + "lr_backbone": lr_backbone, + "lr_classifier": lr_classifier, + "loss": loss.item(), + "time": int(time.time() - start_time), + } + ) + + # evaluate + model.eval() + if args.rank == 0: + top1 = AverageMeter('Acc@1') + top5 = 
AverageMeter('Acc@5') + with torch.no_grad(): + for images, target in val_loader: + output = model(images.cuda(gpu, non_blocking=True)) + acc1, acc5 = accuracy(output, target.cuda(gpu, non_blocking=True), topk=(1, 5)) + top1.update(acc1[0].item(), images.size(0)) + top5.update(acc5[0].item(), images.size(0)) + best_acc.top1 = max(best_acc.top1, top1.avg) + best_acc.top5 = max(best_acc.top5, top5.avg) + stats = dict(epoch=epoch, acc1=top1.avg, acc5=top5.avg, best_acc1=best_acc.top1, best_acc5=best_acc.top5) + print(json.dumps(stats)) + print(json.dumps(stats), file=stats_file) + run.log( + { + "epoch": epoch, + "eval_acc1": top1.avg, + "eval_acc5": top5.avg, + "eval_best_acc1": best_acc.top1, + "eval_best_acc5": best_acc.top5, + } + ) + + # sanity check + if args.weights == 'freeze': + reference_state_dict = torch.load(args.pretrained, map_location='cpu') + model_state_dict = model.module.state_dict() + for k in reference_state_dict: + assert torch.equal(model_state_dict[k].cpu(), reference_state_dict[k]), k + + scheduler.step() + if args.rank == 0: + state = dict( + epoch=epoch + 1, best_acc=best_acc, model=model.state_dict(), + optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict()) + torch.save(state, args.checkpoint_dir / 'checkpoint.pth') + wandb.finish() + + +def handle_sigusr1(signum, frame): + os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}') + exit() + + +def handle_sigterm(signum, frame): + pass + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/evaluate_transfer.py b/evaluate_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..18146e526027f4e521499ecac981b06f4e6d2350 --- /dev/null +++ b/evaluate_transfer.py @@ -0,0 +1,168 @@ +import argparse + +import pandas as pd +import torch +import torch.nn as nn +import torch.optim as optim +from thop import profile, clever_format +from torch.utils.data import DataLoader +from transfer_datasets import TRANSFER_DATASET +import torchvision.transforms as transforms +from data_statistics import get_data_mean_and_stdev, get_data_nclass +from tqdm import tqdm + +import utils + +import wandb + +import torchvision + +def load_transform(dataset, size=32): + mean, std = get_data_mean_and_stdev(dataset) + transform = transforms.Compose([ + transforms.Resize((size, size)), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std)]) + return transform + +class Net(nn.Module): + def __init__(self, num_class, pretrained_path, dataset, arch): + super(Net, self).__init__() + + if 
arch=='resnet18': + embedding_size = 512 + elif arch=='resnet50': + embedding_size = 2048 + else: + raise NotImplementedError + + # encoder + from model import Model + self.f = Model(dataset=dataset, arch=arch).f + # classifier + self.fc = nn.Linear(embedding_size, num_class, bias=True) + self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False) + + def forward(self, x): + x = self.f(x) + feature = torch.flatten(x, start_dim=1) + out = self.fc(feature) + return out + +# train or test for one epoch +def train_val(net, data_loader, train_optimizer): + is_train = train_optimizer is not None + net.train() if is_train else net.eval() + + total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader) + with (torch.enable_grad() if is_train else torch.no_grad()): + for data, target in data_bar: + data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) + out = net(data) + loss = loss_criterion(out, target) + + if is_train: + train_optimizer.zero_grad() + loss.backward() + train_optimizer.step() + + total_num += data.size(0) + total_loss += loss.item() * data.size(0) + prediction = torch.argsort(out, dim=-1, descending=True) + total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + + data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}% model: {}' + .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num, + total_correct_1 / total_num * 100, total_correct_5 / total_num * 100, + model_path.split('/')[-1])) + + return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Linear Evaluation') + parser.add_argument('--dataset', default='cifar10', type=str, help='Pre-trained dataset.', choices=['cifar10', 'cifar100', 'stl10', 'tiny_imagenet']) + parser.add_argument('--transfer_dataset', default='cifar10', type=str, help='Transfer dataset (i.e., testing dataset)', choices=['cifar10', 'cifar100', 'stl-10', 'aircraft', 'cu_birds', 'dtd', 'fashionmnist', 'mnist', 'traffic_sign', 'vgg_flower']) + parser.add_argument('--arch', default='resnet50', type=str, help='Backbone architecture for experiments', choices=['resnet50', 'resnet18']) + parser.add_argument('--model_path', type=str, default='results/Barlow_Twins/0.005_64_128_model.pth', + help='The base string of the pretrained model path') + parser.add_argument('--batch_size', type=int, default=128, help='Number of images in each mini-batch') + parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train') + parser.add_argument('--screen', type=str, help='screen session id') + # wandb related args + parser.add_argument('--wandb_group', type=str, help='group for wandb') + + args = parser.parse_args() + + wandb.init(project=f"Barlow-Twins-MixUp-TransferLearn-[{args.dataset}-to-X]-{args.arch}", config=args, dir='/data/wbandar1/projects/ssl-aug-artifacts/wandb_logs/', group=args.wandb_group, name=f'{args.transfer_dataset}') + run_id = wandb.run.id + + model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs + dataset = args.dataset + transfer_dataset = args.transfer_dataset + + if dataset in ['cifar10', 'cifar100']: + print("reshaping data into 32x32") + resize = 32 + else: + 
print("reshaping data into 64x64") + resize = 64 + + train_data = TRANSFER_DATASET[args.transfer_dataset](train=True, image_transforms=load_transform(args.transfer_dataset, resize)) + test_data = TRANSFER_DATASET[args.transfer_dataset](train=False, image_transforms=load_transform(args.transfer_dataset, resize)) + + train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True) + test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True) + + model = Net(num_class=get_data_nclass(args.transfer_dataset), pretrained_path=model_path, dataset=dataset, arch=args.arch).cuda() + for param in model.f.parameters(): + param.requires_grad = False + + # optimizer with lr sheduler + # lr_start, lr_end = 1e-2, 1e-6 + # gamma = (lr_end / lr_start) ** (1 / epochs) + # optimizer = optim.Adam(model.fc.parameters(), lr=lr_start, weight_decay=5e-6) + # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) + + # adpoted from + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 80], gamma=0.1) + + # optimizer with no sheuduler + # optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6) + + loss_criterion = nn.CrossEntropyLoss() + results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [], + 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []} + + save_name = model_path.split('.pth')[0] + '_linear.csv' + + best_acc = 0.0 + for epoch in range(1, epochs + 1): + train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer) + results['train_loss'].append(train_loss) + results['train_acc@1'].append(train_acc_1) + results['train_acc@5'].append(train_acc_5) + test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None) + results['test_loss'].append(test_loss) + results['test_acc@1'].append(test_acc_1) + results['test_acc@5'].append(test_acc_5) + # save statistics + # data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) + # data_frame.to_csv(save_name, index_label='epoch') + if test_acc_1 > best_acc: + best_acc = test_acc_1 + wandb.log( + { + "train_loss": train_loss, + "train_acc@1": train_acc_1, + "train_acc@5": train_acc_5, + "test_loss": test_loss, + "test_acc@1": test_acc_1, + "test_acc@5": test_acc_5, + "best_acc": best_acc + } + ) + scheduler.step() + wandb.finish() diff --git a/figs/in-linear.png b/figs/in-linear.png new file mode 100644 index 0000000000000000000000000000000000000000..609d71b77872f85a995b77109e4cd73ec89e1eb7 Binary files /dev/null and b/figs/in-linear.png differ diff --git a/figs/in-loss-bt.png b/figs/in-loss-bt.png new file mode 100644 index 0000000000000000000000000000000000000000..3c69f3023d39f453ffc78602ae9dc837a7af557e Binary files /dev/null and b/figs/in-loss-bt.png differ diff --git a/figs/in-loss-reg.png b/figs/in-loss-reg.png new file mode 100644 index 0000000000000000000000000000000000000000..78df7c38cb303634f9ed55deb8d61f650026aa3d --- /dev/null +++ b/figs/in-loss-reg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab2e3e99017cd134a3f49878929bce151abcfa917cb8ceca436e401e2caeed4e +size 1268898 diff --git a/figs/mix-bt.jpg b/figs/mix-bt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8cae7967077e5182c809ec9333e50fe939dd8848 Binary files /dev/null and b/figs/mix-bt.jpg differ diff --git a/figs/mix-bt.svg b/figs/mix-bt.svg new file mode 100644 index 
0000000000000000000000000000000000000000..7efd4e5e5a7b08969d0c28db44083d9d3dd288f7 --- /dev/null +++ b/figs/mix-bt.svg @@ -0,0 +1,3 @@ + + +
[figs/mix-bt.svg — architecture figure; recoverable text labels: views $Y^A$, $Y^B$; encoder $f_e$; embeddings $Z^A$, $Z^B$; cross-correlation $C^{AB} = (Z^A)^T \ldots$ with a $(d \times d)$ target matrix of ones on the diagonal and zeros elsewhere; blocks "Linear Mixing", "Shuffle", "MixUp Regularization"; loss $\mathcal{L}_{reg} = \| C^{MA} - C^{MA}_{gt} \ldots$; panels (a) and (b).]
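The regularizer sketched in this figure is the one implemented in `main.py` later in this diff. Written out from the figure labels and that code (with $Z^A$, $Z^B$ the batch-normalized embeddings of the two views, $Z^M$ that of the mixed view, $N$ the batch size, $\pi$ the batch shuffle `index`, and $\alpha \sim \mathrm{Beta}(1, 1)$):

$$C^{MA} = \tfrac{1}{N} (Z^M)^\top Z^A, \qquad C^{MA}_{gt} = \tfrac{\alpha}{N} (Z^A)^\top Z^A + \tfrac{1-\alpha}{N} (Z^B_{\pi})^\top Z^A$$

$$\mathcal{L}_{reg} = \lambda_{reg} \, \lambda_{BT} \left( \| C^{MA} - C^{MA}_{gt} \|_F^2 + \| C^{MB} - C^{MB}_{gt} \|_F^2 \right)$$

where $\lambda_{BT}$ is `lmbda` and $\lambda_{reg}$ is `--mixup_loss_scale` in `main.py` (the Lambda_Reg column in MODEL_ZOO.md); $C^{MB}$ and $C^{MB}_{gt}$ are defined analogously against $Z^B$.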
\ No newline at end of file diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d45cbcd336cecc7775fd9f2ec0acbd11b3024d9 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchvision.models.resnet import resnet50 as _resnet50 + +dependencies = ['torch', 'torchvision'] + + +def resnet50(pretrained=True, **kwargs): + model = _resnet50(pretrained=False, **kwargs) + if pretrained: + url = 'https://dl.fbaipublicfiles.com/barlowtwins/ep1000_bs2048_lrw0.2_lrb0.0048_lambd0.0051/resnet50.pth' + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu') + model.load_state_dict(state_dict, strict=False) + return model \ No newline at end of file diff --git a/linear.py b/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..228110d04104a69629ec3df3c008d55479bb77cb --- /dev/null +++ b/linear.py @@ -0,0 +1,166 @@ +import argparse + +import pandas as pd +import torch +import torch.nn as nn +import torch.optim as optim +from thop import profile, clever_format +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10, CIFAR100 +from tqdm import tqdm + +import utils + +import wandb + +import torchvision + +class Net(nn.Module): + def __init__(self, num_class, pretrained_path, dataset, arch): + super(Net, self).__init__() + + if arch=='resnet18': + embedding_size = 512 + elif arch=='resnet50': + embedding_size = 2048 + else: + raise NotImplementedError + + # encoder + from model import Model + self.f = Model(dataset=dataset, arch=arch).f + # classifier + self.fc = nn.Linear(embedding_size, num_class, bias=True) + self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False) + + def forward(self, x): + x = self.f(x) + feature = torch.flatten(x, start_dim=1) + out = self.fc(feature) + return out + +# train or test for one epoch +def train_val(net, data_loader, train_optimizer): + is_train = train_optimizer is not None + net.train() if is_train else net.eval() + + total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader) + with (torch.enable_grad() if is_train else torch.no_grad()): + for data, target in data_bar: + data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) + out = net(data) + loss = loss_criterion(out, target) + + if is_train: + train_optimizer.zero_grad() + loss.backward() + train_optimizer.step() + + total_num += data.size(0) + total_loss += loss.item() * data.size(0) + prediction = torch.argsort(out, dim=-1, descending=True) + total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + + data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}% model: {}' + .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num, + total_correct_1 / total_num * 100, total_correct_5 / total_num * 100, + model_path.split('/')[-1])) + + return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Linear Evaluation') + parser.add_argument('--dataset', default='cifar10', 
type=str, help='Dataset: cifar10, cifar100, tiny_imagenet, or stl10') + parser.add_argument('--arch', default='resnet50', type=str, help='Backbone architecture for experiments', choices=['resnet50', 'resnet18']) + parser.add_argument('--model_path', type=str, default='results/Barlow_Twins/0.005_64_128_model.pth', + help='The base string of the pretrained model path') + parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch') + parser.add_argument('--epochs', type=int, default=200, help='Number of sweeps over the dataset to train') + + args = parser.parse_args() + + wandb.init(project=f"Barlow-Twins-MixUp-Linear-{args.dataset}-{args.arch}", config=args, dir='/data/wbandar1/projects/ssl-aug-artifacts/wandb_logs/') + run_id = wandb.run.id + + model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs + dataset = args.dataset + if dataset == 'cifar10': + train_data = CIFAR10(root='data', train=True,\ + transform=utils.CifarPairTransform(train_transform = True, pair_transform=False), download=True) + test_data = CIFAR10(root='data', train=False,\ + transform=utils.CifarPairTransform(train_transform = False, pair_transform=False), download=True) + elif dataset == 'cifar100': + train_data = CIFAR100(root='data', train=True,\ + transform=utils.CifarPairTransform(train_transform = True, pair_transform=False), download=True) + test_data = CIFAR100(root='data', train=False,\ + transform=utils.CifarPairTransform(train_transform = False, pair_transform=False), download=True) + elif dataset == 'stl10': + train_data = torchvision.datasets.STL10(root='data', split="train", \ + transform=utils.StlPairTransform(train_transform = True, pair_transform=False), download=True) + test_data = torchvision.datasets.STL10(root='data', split="test", \ + transform=utils.StlPairTransform(train_transform = False, pair_transform=False), download=True) + elif dataset == 'tiny_imagenet': + train_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/train', \ + utils.TinyImageNetPairTransform(train_transform=True, pair_transform=False)) + test_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/val', \ + utils.TinyImageNetPairTransform(train_transform = False, pair_transform=False)) + + train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True) + test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True) + + model = Net(num_class=len(train_data.classes), pretrained_path=model_path, dataset=dataset, arch=args.arch).cuda() + for param in model.f.parameters(): + param.requires_grad = False + + if dataset == 'cifar10' or dataset == 'cifar100': + flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),)) + elif dataset == 'tiny_imagenet' or dataset == 'stl10': + flops, params = profile(model, inputs=(torch.randn(1, 3, 64, 64).cuda(),)) + flops, params = clever_format([flops, params]) + print('# Model Params: {} FLOPs: {}'.format(params, flops)) + + # optimizer with lr scheduler + lr_start, lr_end = 1e-2, 1e-6 + gamma = (lr_end / lr_start) ** (1 / epochs) + optimizer = optim.Adam(model.fc.parameters(), lr=lr_start, weight_decay=5e-6) + scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) + + # optimizer with no scheduler + # optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6) + + loss_criterion = nn.CrossEntropyLoss() + results = {'train_loss': [], 'train_acc@1': [], 
'train_acc@5': [], + 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []} + + save_name = model_path.split('.pth')[0] + '_linear.csv' + + best_acc = 0.0 + for epoch in range(1, epochs + 1): + train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer) + scheduler.step() + results['train_loss'].append(train_loss) + results['train_acc@1'].append(train_acc_1) + results['train_acc@5'].append(train_acc_5) + test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None) + results['test_loss'].append(test_loss) + results['test_acc@1'].append(test_acc_1) + results['test_acc@5'].append(test_acc_5) + # save statistics + # data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) + # data_frame.to_csv(save_name, index_label='epoch') + #if test_acc_1 > best_acc: + # best_acc = test_acc_1 + # torch.save(model.state_dict(), 'results/linear_model.pth') + wandb.log( + { + "train_loss": train_loss, + "train_acc@1": train_acc_1, + "train_acc@5": train_acc_5, + "test_loss": test_loss, + "test_acc@1": test_acc_1, + "test_acc@5": test_acc_5 + } + ) + wandb.finish() diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..695920d9951ba5dd15bd905953a3f97299551150 --- /dev/null +++ b/main.py @@ -0,0 +1,271 @@ +import argparse +import os + +import pandas as pd +import torch +import numpy as np +import torch.optim as optim +import torch.nn.functional as F +from thop import profile, clever_format +from torch.utils.data import DataLoader +from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts +from tqdm import tqdm + +import utils +from model import Model +import math + +import torchvision + +import wandb + +if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + +def off_diagonal(x): + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + +def adjust_learning_rate(args, optimizer, loader, step): + max_steps = args.epochs * len(loader) + warmup_steps = 10 * len(loader) + base_lr = args.batch_size / 256 + if step < warmup_steps: + lr = base_lr * step / warmup_steps + else: + step -= warmup_steps + max_steps -= warmup_steps + q = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + end_lr = base_lr * 0.001 + lr = base_lr * q + end_lr * (1 - q) + optimizer.param_groups[0]['lr'] = lr * args.lr + +def train(args, epoch, net, data_loader, train_optimizer): + net.train() + total_loss, total_loss_bt, total_loss_mix, total_num, train_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader) + for step, data_tuple in enumerate(train_bar, start=epoch * len(train_bar)): + if args.lr_shed == "cosine": + adjust_learning_rate(args, train_optimizer, data_loader, step) + (pos_1, pos_2), _ = data_tuple + pos_1, pos_2 = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True) + _, out_1 = net(pos_1) + _, out_2 = net(pos_2) + + out_1_norm = (out_1 - out_1.mean(dim=0)) / out_1.std(dim=0) + out_2_norm = (out_2 - out_2.mean(dim=0)) / out_2.std(dim=0) + c = torch.matmul(out_1_norm.T, out_2_norm) / batch_size + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + loss_bt = on_diag + lmbda * off_diag + + ## MixUp (Our Contribution) ## + if args.is_mixup.lower() == 'true': + index = torch.randperm(batch_size).cuda(non_blocking=True) + alpha = np.random.beta(1.0, 1.0) + pos_m = alpha * pos_1 + (1 - alpha) * pos_2[index, :] + + _, out_m = net(pos_m) + out_m_norm = (out_m - out_m.mean(dim=0)) / out_m.std(dim=0) + + cc_m_1 = torch.matmul(out_m_norm.T, 
out_1_norm) / batch_size + cc_m_1_gt = alpha*torch.matmul(out_1_norm.T, out_1_norm) / batch_size + \ + (1-alpha)*torch.matmul(out_2_norm[index,:].T, out_1_norm) / batch_size + + cc_m_2 = torch.matmul(out_m_norm.T, out_2_norm) / batch_size + cc_m_2_gt = alpha*torch.matmul(out_1_norm.T, out_2_norm) / batch_size + \ + (1-alpha)*torch.matmul(out_2_norm[index,:].T, out_2_norm) / batch_size + + loss_mix = args.mixup_loss_scale*lmbda*((cc_m_1-cc_m_1_gt).pow_(2).sum() + (cc_m_2-cc_m_2_gt).pow_(2).sum()) + else: + loss_mix = torch.zeros(1).cuda() + ## MixUp (Our Contribution) ## + + loss = loss_bt + loss_mix + train_optimizer.zero_grad() + loss.backward() + train_optimizer.step() + + total_num += batch_size + total_loss += loss.item() * batch_size + total_loss_bt += loss_bt.item() * batch_size + total_loss_mix += loss_mix.item() * batch_size + + train_bar.set_description('Train Epoch: [{}/{}] lr: {:.3f}x10-3 Loss: {:.4f} lmbda:{:.4f} bsz:{} f_dim:{} dataset: {}'.format(\ + epoch, epochs, train_optimizer.param_groups[0]['lr'] * 1000, total_loss / total_num, lmbda, batch_size, feature_dim, dataset)) + return total_loss_bt / total_num, total_loss_mix / total_num, total_loss / total_num + + +def test(net, memory_data_loader, test_data_loader): + net.eval() + total_top1, total_top5, total_num, feature_bank, target_bank = 0.0, 0.0, 0, [], [] + with torch.no_grad(): + # generate feature bank and target bank + for data_tuple in tqdm(memory_data_loader, desc='Feature extracting'): + (data, _), target = data_tuple + target_bank.append(target) + feature, out = net(data.cuda(non_blocking=True)) + feature_bank.append(feature) + # [D, N] + feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() + # [N] + feature_labels = torch.cat(target_bank, dim=0).contiguous().to(feature_bank.device) + # loop test data to predict the label by weighted knn search + test_bar = tqdm(test_data_loader) + for data_tuple in test_bar: + (data, _), target = data_tuple + data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) + feature, out = net(data) + + total_num += data.size(0) + # compute cos similarity between each feature vector and feature bank ---> [B, N] + sim_matrix = torch.mm(feature, feature_bank) + # [B, K] + sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1) + # [B, K] + sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices) + sim_weight = (sim_weight / temperature).exp() + + # counts for each class + one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device) + # [B*K, C] + one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) + # weighted score ---> [B, C] + pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1) + + pred_labels = pred_scores.argsort(dim=-1, descending=True) + total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%' + .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100)) + return total_top1 / total_num * 100, total_top5 / total_num * 100 + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Training Barlow Twins') + parser.add_argument('--dataset', default='cifar10', type=str, help='Dataset: cifar10, cifar100, tiny_imagenet, stl10', choices=['cifar10', 'cifar100', 
'tiny_imagenet', 'stl10']) + parser.add_argument('--arch', default='resnet50', type=str, help='Backbone architecture', choices=['resnet50', 'resnet18']) + parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for embedding vector') + parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax (kNN evaluation)') + parser.add_argument('--k', default=200, type=int, help='Top k most similar images used to predict the label') + parser.add_argument('--batch_size', default=512, type=int, help='Number of images in each mini-batch') + parser.add_argument('--epochs', default=1000, type=int, help='Number of sweeps over the dataset to train') + parser.add_argument('--lr', default=1e-3, type=float, help='Base learning rate') + parser.add_argument('--lr_shed', default="step", choices=["step", "cosine"], type=str, help='Learning rate scheduler: step / cosine') + + # for barlow twins + parser.add_argument('--lmbda', default=0.005, type=float, help='Lambda that controls the on- and off-diagonal terms') + parser.add_argument('--corr_neg_one', dest='corr_neg_one', action='store_true') + parser.add_argument('--corr_zero', dest='corr_neg_one', action='store_false') + parser.set_defaults(corr_neg_one=False) + + # for mixup + parser.add_argument('--is_mixup', dest='is_mixup', type=str, default='false', choices=['true', 'false']) + parser.add_argument('--mixup_loss_scale', dest='mixup_loss_scale', type=float, default=5.0) + + # GPU id (just for record) + parser.add_argument('--gpu', dest='gpu', type=int, default=0) + + args = parser.parse_args() + is_mixup = args.is_mixup.lower() == 'true' + + wandb.init(project=f"Barlow-Twins-MixUp-{args.dataset}-{args.arch}", config=args, dir='results/wandb_logs/') + run_id = wandb.run.id + dataset = args.dataset + feature_dim, temperature, k = args.feature_dim, args.temperature, args.k + batch_size, epochs = args.batch_size, args.epochs + lmbda = args.lmbda + corr_neg_one = args.corr_neg_one + + if dataset == 'cifar10': + train_data = torchvision.datasets.CIFAR10(root='/data/wbandar1/datasets', train=True, \ + transform=utils.CifarPairTransform(train_transform = True), download=True) + memory_data = torchvision.datasets.CIFAR10(root='/data/wbandar1/datasets', train=True, \ + transform=utils.CifarPairTransform(train_transform = False), download=True) + test_data = torchvision.datasets.CIFAR10(root='/data/wbandar1/datasets', train=False, \ + transform=utils.CifarPairTransform(train_transform = False), download=True) + elif dataset == 'cifar100': + train_data = torchvision.datasets.CIFAR100(root='/data/wbandar1/datasets', train=True, \ + transform=utils.CifarPairTransform(train_transform = True), download=True) + memory_data = torchvision.datasets.CIFAR100(root='/data/wbandar1/datasets', train=True, \ + transform=utils.CifarPairTransform(train_transform = False), download=True) + test_data = torchvision.datasets.CIFAR100(root='/data/wbandar1/datasets', train=False, \ + transform=utils.CifarPairTransform(train_transform = False), download=True) + elif dataset == 'stl10': + train_data = torchvision.datasets.STL10(root='/data/wbandar1/datasets', split="train+unlabeled", \ + transform=utils.StlPairTransform(train_transform = True), download=True) + memory_data = torchvision.datasets.STL10(root='/data/wbandar1/datasets', split="train", \ + transform=utils.StlPairTransform(train_transform = False), download=True) + test_data = torchvision.datasets.STL10(root='/data/wbandar1/datasets', split="test", \ + 
transform=utils.StlPairTransform(train_transform = False), download=True)
+    elif dataset == 'tiny_imagenet':
+        # the dataset must be prepared beforehand (see preprocess_datasets/preprocess_tinyimagenet.sh)
+        if not os.path.isdir('/data/wbandar1/datasets/tiny-imagenet-200'):
+            raise ValueError("First preprocess the tinyimagenet dataset...")
+
+        train_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/train', \
+            utils.TinyImageNetPairTransform(train_transform = True))
+        memory_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/train', \
+            utils.TinyImageNetPairTransform(train_transform = False))
+        test_data = torchvision.datasets.ImageFolder('/data/wbandar1/datasets/tiny-imagenet-200/val', \
+            utils.TinyImageNetPairTransform(train_transform = False))
+
+    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True,
+                              drop_last=True)
+    memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
+    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
+
+    # model setup and optimizer config
+    model = Model(feature_dim, dataset, args.arch).cuda()
+    if dataset == 'cifar10' or dataset == 'cifar100':
+        flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
+    elif dataset == 'tiny_imagenet' or dataset == 'stl10':
+        flops, params = profile(model, inputs=(torch.randn(1, 3, 64, 64).cuda(),))
+    flops, params = clever_format([flops, params])
+    print('# Model Params: {} FLOPs: {}'.format(params, flops))
+
+    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
+    if args.lr_shed == "step":
+        m = [args.epochs - a for a in [50, 25]]
+        scheduler = MultiStepLR(optimizer, milestones=m, gamma=0.2)
+    c = len(memory_data.classes)
+
+    results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []}
+    save_name_pre = '{}_{}_{}_{}_{}'.format(run_id, lmbda, feature_dim, batch_size, dataset)
+    run_id_dir = os.path.join('results/', run_id)
+    if not os.path.exists(run_id_dir):
+        print('Creating directory {}'.format(run_id_dir))
+        os.mkdir(run_id_dir)
+
+    best_acc = 0.0
+    for epoch in range(1, epochs + 1):
+        loss_bt, loss_mix, train_loss = train(args, epoch, model, train_loader, optimizer)
+        if args.lr_shed == "step":
+            scheduler.step()
+        wandb.log(
+            {
+            "epoch": epoch,
+            "lr": optimizer.param_groups[0]['lr'],
+            "loss_bt": loss_bt,
+            "loss_mix": loss_mix,
+            "train_loss": train_loss}
+        )
+        if epoch % 5 == 0:
+            test_acc_1, test_acc_5 = test(model, memory_loader, test_loader)
+
+            results['train_loss'].append(train_loss)
+            results['test_acc@1'].append(test_acc_1)
+            results['test_acc@5'].append(test_acc_5)
+            data_frame = pd.DataFrame(data=results, index=range(5, epoch + 1, 5))
+            data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
+            wandb.log(
+                {
+                "test_acc@1": test_acc_1,
+                "test_acc@5": test_acc_5
+                }
+            )
+            if test_acc_1 > best_acc:
+                best_acc = test_acc_1
+                torch.save(model.state_dict(), 'results/{}/{}_model.pth'.format(run_id, save_name_pre))
+        if epoch % 50 == 0:
+            torch.save(model.state_dict(), 'results/{}/{}_model_{}.pth'.format(run_id, save_name_pre, epoch))
+    wandb.finish()
diff --git a/main_imagenet.py b/main_imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..085545c612bdc7ab4341215c5f769a6c0427b3b1
--- /dev/null
+++ b/main_imagenet.py
@@ -0,0 +1,463 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +import argparse +import json +import math +import os +import random +import signal +import subprocess +import sys +import time +import numpy as np +import wandb + +from PIL import Image, ImageOps, ImageFilter +from torch import nn, optim +import torch +import torchvision +import torchvision.transforms as transforms + +parser = argparse.ArgumentParser(description='Barlow Twins Training') +parser.add_argument('data', type=Path, metavar='DIR', + help='path to dataset') +parser.add_argument('--workers', default=8, type=int, metavar='N', + help='number of data loader workers') +parser.add_argument('--epochs', default=300, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--batch-size', default=512, type=int, metavar='N', + help='mini-batch size') +parser.add_argument('--learning-rate-weights', default=0.2, type=float, metavar='LR', + help='base learning rate for weights') +parser.add_argument('--learning-rate-biases', default=0.0048, type=float, metavar='LR', + help='base learning rate for biases and batch norm parameters') +parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W', + help='weight decay') +parser.add_argument('--lambd', default=0.0051, type=float, metavar='L', + help='weight on off-diagonal terms') +parser.add_argument('--projector', default='8192-8192-8192', type=str, + metavar='MLP', help='projector MLP') +parser.add_argument('--print-freq', default=1, type=int, metavar='N', + help='print frequency') +parser.add_argument('--checkpoint-dir', default='/mnt/store/wbandar1/projects/ssl-aug-artifacts/', type=Path, + metavar='DIR', help='path to checkpoint directory') +parser.add_argument('--is_mixup', default='false', type=str, + metavar='L', help='mixup regularization', choices=['true', 'false']) +parser.add_argument('--lambda_mixup', default=0.1, type=float, metavar='L', + help='Hyperparamter for the regularization loss') + +def main(): + args = parser.parse_args() + args.is_mixup = args.is_mixup.lower() == 'true' + args.ngpus_per_node = torch.cuda.device_count() + + run = wandb.init(project="Barlow-Twins-MixUp-ImageNet", config=args, dir='/mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs/') + run_id = wandb.run.id + args.checkpoint_dir=Path(os.path.join(args.checkpoint_dir, run_id)) + + if 'SLURM_JOB_ID' in os.environ: + # single-node and multi-node distributed training on SLURM cluster + # requeue job on SLURM preemption + signal.signal(signal.SIGUSR1, handle_sigusr1) + signal.signal(signal.SIGTERM, handle_sigterm) + # find a common host name on all nodes + # assume scontrol returns hosts in the same order on all nodes + cmd = 'scontrol show hostnames ' + os.getenv('SLURM_JOB_NODELIST') + stdout = subprocess.check_output(cmd.split()) + host_name = stdout.decode().splitlines()[0] + args.rank = int(os.getenv('SLURM_NODEID')) * args.ngpus_per_node + args.world_size = int(os.getenv('SLURM_NNODES')) * args.ngpus_per_node + args.dist_url = f'tcp://{host_name}:58472' + else: + # single-node distributed training + args.rank = 0 + args.dist_url = 'tcp://localhost:58472' + args.world_size = args.ngpus_per_node + torch.multiprocessing.spawn(main_worker, (args,run,), args.ngpus_per_node) + wandb.finish() + + +def main_worker(gpu, args, run): + args.rank += gpu + torch.distributed.init_process_group( + backend='nccl', init_method=args.dist_url, + world_size=args.world_size, 
rank=args.rank) + + if args.rank == 0: + args.checkpoint_dir.mkdir(parents=True, exist_ok=True) + stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1) + print(' '.join(sys.argv)) + print(' '.join(sys.argv), file=stats_file) + + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + + model = BarlowTwins(args).cuda(gpu) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + param_weights = [] + param_biases = [] + for param in model.parameters(): + if param.ndim == 1: + param_biases.append(param) + else: + param_weights.append(param) + parameters = [{'params': param_weights}, {'params': param_biases}] + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) + optimizer = LARS(parameters, lr=0, weight_decay=args.weight_decay, + weight_decay_filter=True, + lars_adaptation_filter=True) + + # automatically resume from checkpoint if it exists + if (args.checkpoint_dir / 'checkpoint.pth').is_file(): + ckpt = torch.load(args.checkpoint_dir / 'checkpoint.pth', + map_location='cpu') + start_epoch = ckpt['epoch'] + model.load_state_dict(ckpt['model']) + optimizer.load_state_dict(ckpt['optimizer']) + else: + start_epoch = 0 + + dataset = torchvision.datasets.ImageFolder(args.data / 'train', Transform()) + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + assert args.batch_size % args.world_size == 0 + per_device_batch_size = args.batch_size // args.world_size + loader = torch.utils.data.DataLoader( + dataset, batch_size=per_device_batch_size, num_workers=args.workers, + pin_memory=True, sampler=sampler) + + start_time = time.time() + scaler = torch.cuda.amp.GradScaler(growth_interval=100, enabled=True) + for epoch in range(start_epoch, args.epochs): + sampler.set_epoch(epoch) + for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)): + y1 = y1.cuda(gpu, non_blocking=True) + y2 = y2.cuda(gpu, non_blocking=True) + adjust_learning_rate(args, optimizer, loader, step) + mixup_loss_scale = adjust_mixup_scale(loader, step, args.lambda_mixup) + optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=True): + loss_bt, loss_reg = model(y1, y2, args.is_mixup) + loss_regs = mixup_loss_scale * loss_reg + loss = loss_bt + loss_regs + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + if step % args.print_freq == 0: + if args.rank == 0: + stats = dict(epoch=epoch, step=step, + lr_weights=optimizer.param_groups[0]['lr'], + lr_biases=optimizer.param_groups[1]['lr'], + loss=loss.item(), + time=int(time.time() - start_time)) + print(json.dumps(stats)) + print(json.dumps(stats), file=stats_file) + if args.is_mixup: + run.log( + { + "epoch": epoch, + "step": step, + "lr_weights": optimizer.param_groups[0]['lr'], + "lr_biases": optimizer.param_groups[1]['lr'], + "loss": loss.item(), + "loss_bt": loss_bt.item(), + "loss_reg(unscaled)": loss_reg.item(), + "reg_scale": mixup_loss_scale, + "loss_reg(scaled)": loss_regs.item(), + "time": int(time.time() - start_time)} + ) + else: + run.log( + { + "epoch": epoch, + "step": step, + "lr_weights": optimizer.param_groups[0]['lr'], + "lr_biases": optimizer.param_groups[1]['lr'], + "loss": loss.item(), + "loss_bt": loss.item(), + "loss_reg(unscaled)": 0., + "reg_scale": 0., + "loss_reg(scaled)": 0., + "time": int(time.time() - start_time)} + ) + if args.rank == 0: + # save checkpoint + state = dict(epoch=epoch + 1, model=model.state_dict(), + optimizer=optimizer.state_dict()) + torch.save(state, args.checkpoint_dir / 'checkpoint.pth') + if args.rank == 0: + # save 
final model + print("Saving final model ...") + torch.save(model.module.backbone.state_dict(), + args.checkpoint_dir / 'resnet50.pth') + print("Finished saving final model ...") + + +def adjust_learning_rate(args, optimizer, loader, step): + max_steps = args.epochs * len(loader) + warmup_steps = 10 * len(loader) + base_lr = args.batch_size / 256 + if step < warmup_steps: + lr = base_lr * step / warmup_steps + else: + step -= warmup_steps + max_steps -= warmup_steps + q = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + end_lr = base_lr * 0.001 + lr = base_lr * q + end_lr * (1 - q) + optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights + optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases + +def adjust_mixup_scale(loader, step, lambda_mixup): + warmup_steps = 10 * len(loader) + if step < warmup_steps: + return lambda_mixup * step / warmup_steps + else: + return lambda_mixup + +def handle_sigusr1(signum, frame): + os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}') + exit() + + +def handle_sigterm(signum, frame): + pass + + +def off_diagonal(x): + # return a flattened view of the off-diagonal elements of a square matrix + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + +class BarlowTwins(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.backbone = torchvision.models.resnet50(zero_init_residual=True) + self.backbone.fc = nn.Identity() + + # projector + sizes = [2048] + list(map(int, args.projector.split('-'))) + layers = [] + for i in range(len(sizes) - 2): + layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False)) + layers.append(nn.BatchNorm1d(sizes[i + 1])) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False)) + self.projector = nn.Sequential(*layers) + + # normalization layer for the representations z1 and z2 + # self.bn = nn.BatchNorm1d(sizes[-1], affine=False) + + # def forward(self, y1, y2): + # z1 = self.projector(self.backbone(y1)) + # z2 = self.projector(self.backbone(y2)) + + # # empirical cross-correlation matrix + # c = self.bn(z1).T @ self.bn(z2) + + # # sum the cross-correlation matrix between all gpus + # c.div_(self.args.batch_size) + # torch.distributed.all_reduce(c) + + # on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + # off_diag = off_diagonal(c).pow_(2).sum() + # loss = on_diag + self.args.lambd * off_diag + # return loss + + def forward(self, y1, y2, is_mixup): + batch_size = y1.shape[0] + + ### original barlow twins ### + z1 = self.projector(self.backbone(y1)) + z2 = self.projector(self.backbone(y2)) + + # normilization + z1 = (z1 - z1.mean(dim=0)) / z1.std(dim=0) + z2 = (z2 - z2.mean(dim=0)) / z2.std(dim=0) + + # empirical cross-correlation matrix + c = z1.T @ z2 + + # sum the cross-correlation matrix between all gpus + c.div_(self.args.batch_size) + torch.distributed.all_reduce(c) + + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + loss = on_diag + self.args.lambd * off_diag + + if is_mixup: + ############################################## + ### mixup regularization: Implementation 1 ### + ############################################## + + # index = torch.randperm(batch_size).cuda(non_blocking=True) + # alpha = np.random.beta(1.0, 1.0) + # ym = alpha * y1 + (1. 
- alpha) * y2[index, :]
+            # zm = self.projector(self.backbone(ym))
+
+            # # normalization
+            # zm = (zm - zm.mean(dim=0)) / zm.std(dim=0)
+
+            # # cc
+            # cc_m_1 = zm.T @ z1
+            # cc_m_1.div_(batch_size)
+            # cc_m_1_gt = alpha*(z1.T @ z1) + (1.-alpha)*(z2[index,:].T @ z1)
+            # cc_m_1_gt.div_(batch_size)
+
+            # cc_m_2 = zm.T @ z2
+            # cc_m_2.div_(batch_size)
+            # cc_m_2_gt = alpha*(z2.T @ z2) + (1.-alpha)*(z2[index,:].T @ z2)
+            # cc_m_2_gt.div_(batch_size)
+
+            # # mixup reg. loss
+            # lossm = 0.5*self.args.lambd*((cc_m_1-cc_m_1_gt).pow_(2).sum() + (cc_m_2-cc_m_2_gt).pow_(2).sum())
+
+            ##############################################
+            ### mixup regularization: Implementation 2 ###
+            ##############################################
+            index = torch.randperm(batch_size).cuda(non_blocking=True)
+            alpha = np.random.beta(1.0, 1.0)
+            ym = alpha * y1 + (1. - alpha) * y2[index, :]
+            zm = self.projector(self.backbone(ym))
+
+            # normalization
+            zm = (zm - zm.mean(dim=0)) / zm.std(dim=0)
+
+            # cc
+            cc_m_1 = zm.T @ z1
+            cc_m_1.div_(self.args.batch_size)
+            cc_m_1_gt = alpha*(z1.T @ z1) + (1.-alpha)*(z2[index,:].T @ z1)
+            cc_m_1_gt.div_(self.args.batch_size)
+
+            cc_m_2 = zm.T @ z2
+            cc_m_2.div_(self.args.batch_size)
+            cc_m_2_gt = alpha*(z2.T @ z2) + (1.-alpha)*(z2[index,:].T @ z2)
+            cc_m_2_gt.div_(self.args.batch_size)
+
+            # gathering all cc
+            torch.distributed.all_reduce(cc_m_1)
+            torch.distributed.all_reduce(cc_m_1_gt)
+            torch.distributed.all_reduce(cc_m_2)
+            torch.distributed.all_reduce(cc_m_2_gt)
+
+            # mixup reg. loss
+            lossm = 0.5*self.args.lambd*((cc_m_1-cc_m_1_gt).pow_(2).sum() + (cc_m_2-cc_m_2_gt).pow_(2).sum())
+        else:
+            # keep the dummy loss on the same device as the BT loss so loss_bt + loss_regs is valid
+            lossm = torch.zeros(1, device=z1.device)
+        return loss, lossm
+
+class LARS(optim.Optimizer):
+    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
+                 weight_decay_filter=False, lars_adaptation_filter=False):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
+                        eta=eta, weight_decay_filter=weight_decay_filter,
+                        lars_adaptation_filter=lars_adaptation_filter)
+        super().__init__(params, defaults)
+
+
+    def exclude_bias_and_norm(self, p):
+        return p.ndim == 1
+
+    @torch.no_grad()
+    def step(self):
+        for g in self.param_groups:
+            for p in g['params']:
+                dp = p.grad
+
+                if dp is None:
+                    continue
+
+                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
+                    dp = dp.add(p, alpha=g['weight_decay'])
+
+                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
+                    param_norm = torch.norm(p)
+                    update_norm = torch.norm(dp)
+                    one = torch.ones_like(param_norm)
+                    q = torch.where(param_norm > 0.,
+                                    torch.where(update_norm > 0,
+                                                (g['eta'] * param_norm / update_norm), one), one)
+                    dp = dp.mul(q)
+
+                param_state = self.state[p]
+                if 'mu' not in param_state:
+                    param_state['mu'] = torch.zeros_like(p)
+                mu = param_state['mu']
+                mu.mul_(g['momentum']).add_(dp)
+
+                p.add_(mu, alpha=-g['lr'])
+
+
+
+class GaussianBlur(object):
+    def __init__(self, p):
+        self.p = p
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            sigma = random.random() * 1.9 + 0.1
+            return img.filter(ImageFilter.GaussianBlur(sigma))
+        else:
+            return img
+
+
+class Solarization(object):
+    def __init__(self, p):
+        self.p = p
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return ImageOps.solarize(img)
+        else:
+            return img
+
+
+class Transform:
+    def __init__(self):
+        self.transform = transforms.Compose([
+            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomApply(
+                [transforms.ColorJitter(brightness=0.4,
contrast=0.4, + saturation=0.2, hue=0.1)], + p=0.8 + ), + transforms.RandomGrayscale(p=0.2), + GaussianBlur(p=1.0), + Solarization(p=0.0), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.transform_prime = transforms.Compose([ + transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomApply( + [transforms.ColorJitter(brightness=0.4, contrast=0.4, + saturation=0.2, hue=0.1)], + p=0.8 + ), + transforms.RandomGrayscale(p=0.2), + GaussianBlur(p=0.1), + Solarization(p=0.2), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + def __call__(self, x): + y1 = self.transform(x) + y2 = self.transform_prime(x) + return y1, y2 + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b3cd4165bdce3853b8fe42557a0f22bce8554132 --- /dev/null +++ b/model.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models.resnet import resnet50, resnet18 + + +class Model(nn.Module): + def __init__(self, feature_dim=128, dataset='cifar10', arch='resnet50'): + super(Model, self).__init__() + + self.f = [] + if arch == 'resnet18': + temp_model = resnet18().named_children() + embedding_size = 512 + elif arch == 'resnet50': + temp_model = resnet50().named_children() + embedding_size = 2048 + else: + raise NotImplementedError + + for name, module in temp_model: + if name == 'conv1': + module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + if dataset == 'cifar10' or dataset == 'cifar100': + if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d): + self.f.append(module) + elif dataset == 'tiny_imagenet' or dataset == 'stl10': + if not isinstance(module, nn.Linear): + self.f.append(module) + # encoder + self.f = nn.Sequential(*self.f) + # projection head + self.g = nn.Sequential(nn.Linear(embedding_size, 512, bias=False), nn.BatchNorm1d(512), + nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True)) + + def forward(self, x): + x = self.f(x) + feature = torch.flatten(x, start_dim=1) + out = self.g(feature) + return F.normalize(feature, dim=-1), F.normalize(out, dim=-1) diff --git a/preprocess_datasets/preprocess_tinyimagenet.sh b/preprocess_datasets/preprocess_tinyimagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..cdbe83cc60f8d9d05763d39804376b8c5446eba7 --- /dev/null +++ b/preprocess_datasets/preprocess_tinyimagenet.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# download and unzip dataset +cd /data/wbandar1/datasets +wget http://cs231n.stanford.edu/tiny-imagenet-200.zip +unzip tiny-imagenet-200.zip + +current="$(pwd)/tiny-imagenet-200" + +# training data +cd $current/train +for DIR in $(ls); do + cd $DIR + rm *.txt + mv images/* . + rm -r images + cd .. 
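+    # each class folder now contains its images at the top level, which is the
+    # flat layout that torchvision.datasets.ImageFolder expects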
+done
+
+# validation data
+cd $current/val
+annotate_file="val_annotations.txt"
+length=$(cat $annotate_file | wc -l)
+for i in $(seq 1 $length); do
+    # fetch the i-th annotation line
+    line=$(sed -n ${i}p $annotate_file)
+    # get file name and directory name
+    file=$(echo $line | cut -f1 -d" " )
+    directory=$(echo $line | cut -f2 -d" ")
+    mkdir -p $directory
+    mv images/$file $directory
+done
+rm -r images
+echo "done"
\ No newline at end of file
diff --git a/scripts-linear-resnet18/cifar10.sh b/scripts-linear-resnet18/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9b2b573b6b386a9ad2a322e8eacc10012d9456e9
--- /dev/null
+++ b/scripts-linear-resnet18/cifar10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=cifar10
+arch=resnet18
+batch_size=512
+model_path=checkpoints/4wdhbpcf_0.0078125_1024_256_cifar10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet18/cifar100.sh b/scripts-linear-resnet18/cifar100.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5ab2adcef428cf40981dd7390e93affced82026c
--- /dev/null
+++ b/scripts-linear-resnet18/cifar100.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=cifar100
+arch=resnet18
+batch_size=512
+model_path=checkpoints/76kk7scz_0.0078125_1024_256_cifar100_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet18/stl10.sh b/scripts-linear-resnet18/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..92bdff0e86b9312186142632dda2eea4835453e4
--- /dev/null
+++ b/scripts-linear-resnet18/stl10.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=stl10
+arch=resnet18
+batch_size=512
+model_path=checkpoints/i7det4xq_0.0078125_1024_256_stl10_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-linear-resnet18/tinyimagenet.sh b/scripts-linear-resnet18/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f14e204eece17499542f6c657c83697b0e6a2d11
--- /dev/null
+++ b/scripts-linear-resnet18/tinyimagenet.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+gpu=0
+dataset=tiny_imagenet
+arch=resnet18
+batch_size=512
+model_path=checkpoints/02azq6fs_0.0009765_1024_256_tiny_imagenet_model.pth
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M"
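+# "stuff" types the quoted string into the screen session; the trailing ^M is a
+# carriage return, so the command is executed inside the detached session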
"CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/scripts-linear-resnet50/cifar10.sh b/scripts-linear-resnet50/cifar10.sh new file mode 100644 index 0000000000000000000000000000000000000000..79b2b086c04960a94de1e0313fdd059a5872f078 --- /dev/null +++ b/scripts-linear-resnet50/cifar10.sh @@ -0,0 +1,14 @@ +#!/bin/bash +gpu=0 +dataset=cifar10 +arch=resnet50 +batch_size=512 +model_path=checkpoints/v3gwgusq_0.0078125_1024_256_cifar10_model.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/scripts-linear-resnet50/cifar100.sh b/scripts-linear-resnet50/cifar100.sh new file mode 100644 index 0000000000000000000000000000000000000000..92a265541ed38ac68adb456b3cfbf4f905bdcd2a --- /dev/null +++ b/scripts-linear-resnet50/cifar100.sh @@ -0,0 +1,14 @@ +#!/bin/bash +gpu=0 +dataset=cifar100 +arch=resnet50 +batch_size=512 +model_path=checkpoints/z6ngefw7_0.0078125_1024_256_cifar100_model.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/scripts-linear-resnet50/imagenet_sup.sh b/scripts-linear-resnet50/imagenet_sup.sh new file mode 100644 index 0000000000000000000000000000000000000000..b4f5a45ff423eb5ed7917563e799112d34873a0c --- /dev/null +++ b/scripts-linear-resnet50/imagenet_sup.sh @@ -0,0 +1,11 @@ +#!/bin/bash +path_to_imagenet_data=datasets/imagenet1k/ +path_to_model=checkpoints/13awtq23_0.0051_8192_1024_imagenet_0.1_resnet50.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python evaluate_imagenet.py ${path_to_imagenet_data} ${path_to_model} --lr-classifier 0.3^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/scripts-linear-resnet50/stl10.sh b/scripts-linear-resnet50/stl10.sh new file mode 100644 index 0000000000000000000000000000000000000000..5e4b73cf9ef17a88b0815c03e9fb645269822001 --- /dev/null +++ b/scripts-linear-resnet50/stl10.sh @@ -0,0 +1,14 @@ +#!/bin/bash +gpu=0 +dataset=stl10 +arch=resnet50 +batch_size=512 +model_path=checkpoints/pbknx38b_0.0078125_1024_256_stl10_model.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/scripts-linear-resnet50/tinyimagenet.sh 
b/scripts-linear-resnet50/tinyimagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..b05f986baae3984c23173a9bb910f2f289385fbf --- /dev/null +++ b/scripts-linear-resnet50/tinyimagenet.sh @@ -0,0 +1,14 @@ +#!/bin/bash +gpu=0 +dataset=tiny_imagenet +arch=resnet50 +batch_size=512 +model_path=checkpoints/kxlkigsv_0.0009765_1024_256_tiny_imagenet_model.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python linear.py --dataset ${dataset} --model_path ${model_path} --arch ${arch}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/scripts-pretrain-resnet18/cifar10.sh b/scripts-pretrain-resnet18/cifar10.sh new file mode 100644 index 0000000000000000000000000000000000000000..8b62fe4a45191e83106e6e6ed1b94a27bc63f670 --- /dev/null +++ b/scripts-pretrain-resnet18/cifar10.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# default: https://wandb.ai/cha-yas/Barlow-Twins-MixUp-cifar10-resnet18/runs/4wdhbpcf/overview?workspace=user-wgcban +gpu=0 +dataset=cifar10 +arch=resnet18 +feature_dim=1024 +is_mixup=true # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine # step, cosine +mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625 +lmbda=0.0078125 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/scripts-pretrain-resnet18/cifar100.sh b/scripts-pretrain-resnet18/cifar100.sh new file mode 100644 index 0000000000000000000000000000000000000000..35f81e0629a4c96869adfb20744dee877e440e71 --- /dev/null +++ b/scripts-pretrain-resnet18/cifar100.sh @@ -0,0 +1,20 @@ +#!/bin/bash +gpu=0 +dataset=cifar100 +arch=resnet18 +feature_dim=2048 +is_mixup=true # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine #"step", "cosine" # step, cosine +mixup_loss_scale=4.0 # scale w.r.t. 
lambda: 0.0078125 * 5 = 0.0390625
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet18/stl10.sh b/scripts-pretrain-resnet18/stl10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..02e21d439cd2b28277942f59d2a0289d55a678c1
--- /dev/null
+++ b/scripts-pretrain-resnet18/stl10.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=stl10
+arch=resnet18
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=2.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625
+lmbda=0.0078125
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet18/tinyimagenet.sh b/scripts-pretrain-resnet18/tinyimagenet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d279973f9696d626c4f6737507f87ec9c2d702e0
--- /dev/null
+++ b/scripts-pretrain-resnet18/tinyimagenet.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=tiny_imagenet
+arch=resnet18
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=2000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t. lambda
+lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc)
+
+timestamp=$(date +"%Y%m%d%H%M%S")
+session_name="python_session_$timestamp"
+echo ${session_name}
+screen -dmS "$session_name"
+screen -S "$session_name" -X stuff "conda activate ssl-aug^M"
+screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M"
+screen -S "$session_name" -X detach
\ No newline at end of file
diff --git a/scripts-pretrain-resnet50/cifar10.sh b/scripts-pretrain-resnet50/cifar10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2019fe16d6cf20d0e8d7b0b7d78134cb9c663de6
--- /dev/null
+++ b/scripts-pretrain-resnet50/cifar10.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+gpu=0
+dataset=cifar10
+arch=resnet50
+feature_dim=1024
+is_mixup=true # true, false
+batch_size=256
+epochs=1000
+lr=0.01
+lr_shed=cosine # step, cosine
+mixup_loss_scale=4.0 # scale w.r.t.
lambda: 0.0078125 * 5 = 0.0390625 +lmbda=0.0078125 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/scripts-pretrain-resnet50/cifar100.sh b/scripts-pretrain-resnet50/cifar100.sh new file mode 100644 index 0000000000000000000000000000000000000000..2d06e1191dc5214047c9693eb052d2ee58f2357a --- /dev/null +++ b/scripts-pretrain-resnet50/cifar100.sh @@ -0,0 +1,20 @@ +#!/bin/bash +gpu=0 +dataset=cifar100 +arch=resnet50 +feature_dim=1024 +is_mixup=true # true, false +batch_size=256 +epochs=1000 +lr=0.01 +lr_shed=cosine # step, cosine +mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625 +lmbda=0.0078125 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/scripts-pretrain-resnet50/imagenet.sh b/scripts-pretrain-resnet50/imagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..742ba06e7fdf4d9cb15516ee3ced55126aeb5177 --- /dev/null +++ b/scripts-pretrain-resnet50/imagenet.sh @@ -0,0 +1,15 @@ +#!/bin/bash +is_mixup=true +batch_size=1024 #128/gpu works +lr_w=0.2 #0.2 +lr_b=0.0048 #0.0048 +lambda_mixup=1.0 + + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main_imagenet.py /data/wbandar1/datasets/imagenet1k/ --is_mixup ${is_mixup} --batch-size ${batch_size} --learning-rate-weights ${lr_w} --learning-rate-biases ${lr_b} --lambda_mixup ${lambda_mixup}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/scripts-pretrain-resnet50/stl10.sh b/scripts-pretrain-resnet50/stl10.sh new file mode 100644 index 0000000000000000000000000000000000000000..7a9da099e851737cfb7cb3ff45b0f9f97fb7acad --- /dev/null +++ b/scripts-pretrain-resnet50/stl10.sh @@ -0,0 +1,20 @@ +#!/bin/bash +gpu=0 +dataset=stl10 +arch=resnet50 +feature_dim=4096 +is_mixup=true # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine # step, cosine +mixup_loss_scale=2.0 # scale w.r.t. 
lambda: 0.0078125 * 5 = 0.0390625 +lmbda=0.0078125 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/scripts-pretrain-resnet50/tinyimagenet.sh b/scripts-pretrain-resnet50/tinyimagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..f9abde5f52d5db3c7f66fb0c9a3e0e1c96a08f11 --- /dev/null +++ b/scripts-pretrain-resnet50/tinyimagenet.sh @@ -0,0 +1,20 @@ +#!/bin/bash +gpu=0 +dataset=tiny_imagenet +arch=resnet50 +feature_dim=4096 +is_mixup=false # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine # step, cosine +mixup_loss_scale=4.0 # scale w.r.t. lambda +lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc) + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/scripts-transfer-resnet18/cifar10-to-x.sh b/scripts-transfer-resnet18/cifar10-to-x.sh new file mode 100644 index 0000000000000000000000000000000000000000..eff60ca6732af6aa9661adaf94cf7b78f77f349a --- /dev/null +++ b/scripts-transfer-resnet18/cifar10-to-x.sh @@ -0,0 +1,28 @@ +#!/bin/bash +gpu=0 +dataset=cifar10 +arch=resnet18 +batch_size=128 +wandb_group='best-mbt' +model_path=checkpoints/4wdhbpcf_0.0078125_1024_256_cifar10_model.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +transfer_dataset='dtd' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='mnist' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='fashionmnist' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='cu_birds' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" 
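+# the stuffed commands run sequentially inside the screen session, so each
+# transfer evaluation starts only after the previous one finishes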
+transfer_dataset='vgg_flower' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='traffic_sign' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='aircraft' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/scripts-transfer-resnet18/cifar100-to-x.sh b/scripts-transfer-resnet18/cifar100-to-x.sh new file mode 100644 index 0000000000000000000000000000000000000000..8822ec7c62e11718772c9ea57708d2b735990b01 --- /dev/null +++ b/scripts-transfer-resnet18/cifar100-to-x.sh @@ -0,0 +1,28 @@ +#!/bin/bash +gpu=0 +dataset=cifar100 +arch=resnet18 +batch_size=128 +wandb_group='mbt' +model_path=checkpoints/76kk7scz_0.0078125_1024_256_cifar100_model.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +transfer_dataset='dtd' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='mnist' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='fashionmnist' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='cu_birds' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='vgg_flower' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='traffic_sign' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='aircraft' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +screen -S "$session_name" 
-X detach \ No newline at end of file diff --git a/scripts-transfer-resnet18/stl10-to-x-bt.sh b/scripts-transfer-resnet18/stl10-to-x-bt.sh new file mode 100644 index 0000000000000000000000000000000000000000..15e90293bd5f205f1f6bd22817d1d817e7052674 --- /dev/null +++ b/scripts-transfer-resnet18/stl10-to-x-bt.sh @@ -0,0 +1,28 @@ +#!/bin/bash +gpu=0 +dataset=stl10 +arch=resnet18 +batch_size=128 +wandb_group='mbt' +model_path=checkpoints/i7det4xq_0.0078125_1024_256_stl10_model.pth + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +transfer_dataset='dtd' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='mnist' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='fashionmnist' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='cu_birds' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='vgg_flower' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='traffic_sign' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +transfer_dataset='aircraft' +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python evaluate_transfer.py --dataset ${dataset} --transfer_dataset ${transfer_dataset} --model_path ${model_path} --arch ${arch} --screen ${session_name} --wandb_group ${wandb_group}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..9ac4c3ae340199ff54cef3755d8636b25126bb54 --- /dev/null +++ b/setup.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +mkdir ssl-aug +mkdir Barlow-Twins-HSIC +mkdir /data/wbandar1/projects/ssl-aug-artifacts/results +git clone https://github.com/wgcban/ssl-aug.git Barlow-Twins-HSIC + +cd Barlow-Twins-HSIC +conda env create -f environment.yml +conda activate ssl-aug + + diff --git a/ssl-sota/README.md b/ssl-sota/README.md new file mode 100644 index 0000000000000000000000000000000000000000..da2ef91c8428dc6c8bf241e73dec8834319225b6 --- /dev/null +++ b/ssl-sota/README.md @@ -0,0 +1,87 @@ +# Self-Supervised Representation Learning + +Official repository of the paper **Whitening for Self-Supervised Representation Learning** + +ICML 2021 | 
[arXiv:2007.06346](https://arxiv.org/abs/2007.06346) + +It includes 3 types of losses: +- W-MSE [arXiv](https://arxiv.org/abs/2007.06346) +- Contrastive [SimCLR arXiv](https://arxiv.org/abs/2002.05709) +- BYOL [arXiv](https://arxiv.org/abs/2006.07733) + +And 5 datasets: +- CIFAR-10 and CIFAR-100 +- STL-10 +- Tiny ImageNet +- ImageNet-100 +Checkpoints are stored in `data` every 100 epochs during training. + +The implementation is optimized for a single GPU, although multiple are also supported. It includes fast evaluation: we pre-compute embeddings for the entire dataset and then train a classifier on top. Evaluating the ResNet-18 encoder takes about one minute. + +## Installation + +The implementation is based on PyTorch. Logging works on [wandb.ai](https://wandb.ai/). See `docker/Dockerfile`. + +#### ImageNet-100 +To get this dataset, take the original ImageNet and keep only [this subset of classes](https://github.com/HobbitLong/CMC/blob/master/imagenet100.txt). We do not use augmentations during testing, and loading big images with resizing on the fly is slow, so we preprocess the classifier train and test images. We recommend [mogrify](https://imagemagick.org/script/mogrify.php) for this. First, resize to 256 (just like `torchvision.transforms.Resize(256)`), then crop to 224 (like `torchvision.transforms.CenterCrop(224)`). Finally, put the original images in `train`, and the resized ones in `clf` and `test`. + +## Usage + +The default settings are sensible; to see all options: +``` +python -m train --help +python -m test --help +``` + +To reproduce the results from [Table 1](https://arxiv.org/abs/2007.06346): +#### W-MSE 4 +``` +python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --num_samples 4 --bs 256 --emb 64 --w_size 128 +python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --num_samples 4 --bs 256 --emb 64 --w_size 128 +python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256 +python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256 +``` + +#### W-MSE 2 +``` +python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --w_size 128 +python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --w_size 128 +python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --w_size 256 --w_iter 4 +python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --w_size 256 --w_iter 4 +``` + +#### Contrastive +``` +python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --method contrastive --arch resnet50 +python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --method contrastive --arch resnet50 +python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --method contrastive --arch resnet50 +python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --method contrastive --arch resnet50 +``` + +#### BYOL +``` +python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --method byol +python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --method byol +python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --method byol +python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --method byol +``` + +#### ImageNet-100 +``` +python -m train --dataset imagenet --epoch 240 --lr 2e-3 --emb 128 --w_size 256 --crop_s0 0.08 --cj0 0.8 --cj1 0.8 --cj2 0.8 --cj3 0.2 --gs_p 0.2 +python -m train --dataset imagenet --epoch 240 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256 --crop_s0
0.08 --cj0 0.8 --cj1 0.8 --cj2 0.8 --cj3 0.2 --gs_p 0.2 +``` + +Use `--no_norm` to disable normalization (for Euclidean distance). + +## Citation +``` +@inproceedings{ermolov2021whitening, + title={Whitening for self-supervised representation learning}, + author={Ermolov, Aleksandr and Siarohin, Aliaksandr and Sangineto, Enver and Sebe, Nicu}, + booktitle={International Conference on Machine Learning}, + pages={3015--3024}, + year={2021}, + organization={PMLR} +} +``` diff --git a/ssl-sota/cfg.py b/ssl-sota/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..1a920cde219f5003f26e06e3c0f3e0375aaf284d --- /dev/null +++ b/ssl-sota/cfg.py @@ -0,0 +1,152 @@ +from functools import partial +import argparse +from torchvision import models +import multiprocessing +from datasets import DS_LIST +from methods import METHOD_LIST + + +def get_cfg(): + """ generates configuration from user input in console """ + parser = argparse.ArgumentParser(description="") + parser.add_argument( + "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type", + ) + parser.add_argument( + "--wandb", + type=str, + default="ssl-sota", + help="name of the project for logging at https://wandb.ai", + ) + parser.add_argument( + "--byol_tau", type=float, default=0.99, help="starting tau for byol loss" + ) + parser.add_argument( + "--num_samples", + type=int, + default=2, + help="number of samples (d) generated from each image", + ) + + addf = partial(parser.add_argument, type=float) + addf("--cj0", default=0.4, help="color jitter brightness") + addf("--cj1", default=0.4, help="color jitter contrast") + addf("--cj2", default=0.4, help="color jitter saturation") + addf("--cj3", default=0.1, help="color jitter hue") + addf("--cj_p", default=0.8, help="color jitter probability") + addf("--gs_p", default=0.1, help="grayscale probability") + addf("--crop_s0", default=0.2, help="crop size from") + addf("--crop_s1", default=1.0, help="crop size to") + addf("--crop_r0", default=0.75, help="crop ratio from") + addf("--crop_r1", default=(4 / 3), help="crop ratio to") + addf("--hf_p", default=0.5, help="horizontal flip probability") + + parser.add_argument( + "--no_lr_warmup", + dest="lr_warmup", + action="store_false", + help="do not use learning rate warmup", + ) + parser.add_argument( + "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head" + ) + parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier") + parser.add_argument("--fname", type=str, help="load model from file") + parser.add_argument( + "--lr_step", + type=str, + choices=["cos", "step", "none"], + default="step", + help="learning rate schedule type", + ) + parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") + parser.add_argument( + "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)" + ) + parser.add_argument( + "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)" + ) + parser.add_argument("--T0", type=int, help="period (for --lr_step cos)") + parser.add_argument( + "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)" + ) + parser.add_argument( + "--w_eps", type=float, default=1e-4, help="eps for stability for whitening" + ) + parser.add_argument( + "--head_layers", type=int, default=2, help="number of FC layers in head" + ) + parser.add_argument( + "--head_size", type=int, default=1024, help="size of FC layers in head" + ) + + parser.add_argument( + "--w_size", type=int, default=128, help="size of 
sub-batch for W-MSE loss" + ) + parser.add_argument( + "--w_iter", + type=int, + default=1, + help="iterations for whitening matrix estimation", + ) + + parser.add_argument( + "--no_norm", dest="norm", action="store_false", help="don't normalize latents", + ) + parser.add_argument( + "--tau", type=float, default=0.5, help="contrastive loss temperature" + ) + + parser.add_argument("--epoch", type=int, default=200, help="total epoch number") + parser.add_argument( + "--eval_every_drop", + type=int, + default=5, + help="how often to evaluate after learning rate drop", + ) + parser.add_argument( + "--eval_every", type=int, default=20, help="how often to evaluate" + ) + parser.add_argument("--emb", type=int, default=64, help="embedding size") + parser.add_argument( + "--bs", type=int, default=384, help="number of original images in batch N", + ) + parser.add_argument( + "--drop", + type=int, + nargs="*", + default=[50, 25], + help="milestones for learning rate decay (0 = last epoch)", + ) + parser.add_argument( + "--drop_gamma", + type=float, + default=0.2, + help="multiplicative factor of learning rate decay", + ) + parser.add_argument( + "--arch", + type=str, + choices=[x for x in dir(models) if "resn" in x], + default="resnet18", + help="encoder architecture", + ) + parser.add_argument("--dataset", type=str, choices=DS_LIST, default="cifar10") + parser.add_argument( + "--num_workers", + type=int, + default=0, + help="dataset workers number", + ) + parser.add_argument( + "--clf", + type=str, + default="sgd", + choices=["sgd", "knn", "lbfgs"], + help="classifier for test.py", + ) + parser.add_argument( + "--eval_head", action="store_true", help="eval head output instead of model", + ) + parser.add_argument("--imagenet_path", type=str, default="~/IN100/") + return parser.parse_args() diff --git a/ssl-sota/datasets/__init__.py b/ssl-sota/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a72e00b827c4cbba69c3f69f6d7d8ed513bfb660 --- /dev/null +++ b/ssl-sota/datasets/__init__.py @@ -0,0 +1,22 @@ +from .cifar10 import CIFAR10 +from .cifar100 import CIFAR100 +from .stl10 import STL10 +from .tiny_in import TinyImageNet +from .imagenet import ImageNet + + +DS_LIST = ["cifar10", "cifar100", "stl10", "tinyimagenet", "imagenet"] + + +def get_ds(name): + assert name in DS_LIST + if name == "cifar10": + return CIFAR10 + elif name == "cifar100": + return CIFAR100 + elif name == "stl10": + return STL10 + elif name == "tinyimagenet": + return TinyImageNet + elif name == "imagenet": + return ImageNet diff --git a/ssl-sota/datasets/base.py b/ssl-sota/datasets/base.py new file mode 100644 index 0000000000000000000000000000000000000000..35a34efb94d7b3d4146b279b88e1bf57cbf9bbea --- /dev/null +++ b/ssl-sota/datasets/base.py @@ -0,0 +1,67 @@ +from abc import ABCMeta, abstractmethod +from functools import lru_cache +from torch.utils.data import DataLoader + + +class BaseDataset(metaclass=ABCMeta): + """ + base class for datasets, it includes 3 types: + - for self-supervised training, + - for classifier training for evaluation, + - for testing + """ + + def __init__( + self, bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000, + ): + self.aug_cfg = aug_cfg + self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test + self.num_workers = num_workers + + @abstractmethod + def ds_train(self): + raise NotImplementedError + + @abstractmethod + def ds_clf(self): + raise NotImplementedError + + @abstractmethod + def ds_test(self): + raise NotImplementedError + + 
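+    # Note on the three loaders below: the @property/@lru_cache() combination
+    # builds each DataLoader lazily, on first access, and then reuses the cached instance.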
@property + @lru_cache() + def train(self): + return DataLoader( + dataset=self.ds_train(), + batch_size=self.bs_train, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + drop_last=True, + ) + + @property + @lru_cache() + def clf(self): + return DataLoader( + dataset=self.ds_clf(), + batch_size=self.bs_clf, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + drop_last=True, + ) + + @property + @lru_cache() + def test(self): + return DataLoader( + dataset=self.ds_test(), + batch_size=self.bs_test, + shuffle=False, + num_workers=self.num_workers, + pin_memory=True, + drop_last=False, + ) diff --git a/ssl-sota/datasets/cifar10.py b/ssl-sota/datasets/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..ef8aafae9087b611381467a16132825896a8b77f --- /dev/null +++ b/ssl-sota/datasets/cifar10.py @@ -0,0 +1,26 @@ +from torchvision.datasets import CIFAR10 as C10 +import torchvision.transforms as T +from .transforms import MultiSample, aug_transform +from .base import BaseDataset + + +def base_transform(): + return T.Compose( + [T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] + ) + + +class CIFAR10(BaseDataset): + def ds_train(self): + t = MultiSample( + aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples + ) + return C10(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t) + + def ds_clf(self): + t = base_transform() + return C10(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t) + + def ds_test(self): + t = base_transform() + return C10(root="/mnt/store/wbandar1/datasets/", train=False, download=True, transform=t) diff --git a/ssl-sota/datasets/cifar100.py b/ssl-sota/datasets/cifar100.py new file mode 100644 index 0000000000000000000000000000000000000000..67e4173e705ac4b9443768713a9025b108ccc155 --- /dev/null +++ b/ssl-sota/datasets/cifar100.py @@ -0,0 +1,26 @@ +from torchvision.datasets import CIFAR100 as C100 +import torchvision.transforms as T +from .transforms import MultiSample, aug_transform +from .base import BaseDataset + + +def base_transform(): + return T.Compose( + [T.ToTensor(), T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))] + ) + + +class CIFAR100(BaseDataset): + def ds_train(self): + t = MultiSample( + aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples + ) + return C100(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t,) + + def ds_clf(self): + t = base_transform() + return C100(root="/mnt/store/wbandar1/datasets/", train=True, download=True, transform=t) + + def ds_test(self): + t = base_transform() + return C100(root="/mnt/store/wbandar1/datasets/", train=False, download=True, transform=t) diff --git a/ssl-sota/datasets/imagenet.py b/ssl-sota/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..5677db1d4f69c9eb0a3f12dd7ff45a4982061e49 --- /dev/null +++ b/ssl-sota/datasets/imagenet.py @@ -0,0 +1,41 @@ +import random +from torchvision.datasets import ImageFolder +import torchvision.transforms as T +from PIL import ImageFilter +from .transforms import MultiSample, aug_transform +from .base import BaseDataset + + +class RandomBlur: + def __init__(self, r0, r1): + self.r0, self.r1 = r0, r1 + + def __call__(self, image): + r = random.uniform(self.r0, self.r1) + return image.filter(ImageFilter.GaussianBlur(radius=r)) + + +def base_transform(): + return T.Compose( + [T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), 
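+        # channel-wise means above, channel-wise standard deviations below: the standard ImageNet statistics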
(0.229, 0.224, 0.225))] + ) + + +class ImageNet(BaseDataset): + def ds_train(self): + aug_with_blur = aug_transform( + 224, + base_transform, + self.aug_cfg, + extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)], + ) + t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples) + return ImageFolder(root=self.aug_cfg.imagenet_path + "train", transform=t) + + def ds_clf(self): + t = base_transform() + return ImageFolder(root=self.aug_cfg.imagenet_path + "clf", transform=t) + + def ds_test(self): + t = base_transform() + return ImageFolder(root=self.aug_cfg.imagenet_path + "test", transform=t) diff --git a/ssl-sota/datasets/stl10.py b/ssl-sota/datasets/stl10.py new file mode 100644 index 0000000000000000000000000000000000000000..33f2c50f5cd242bb58652fbf67f15e79a9423aba --- /dev/null +++ b/ssl-sota/datasets/stl10.py @@ -0,0 +1,32 @@ +from torchvision.datasets import STL10 as S10 +import torchvision.transforms as T +from .transforms import MultiSample, aug_transform +from .base import BaseDataset + + +def base_transform(): + return T.Compose( + [T.ToTensor(), T.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))] + ) + + +def test_transform(): + return T.Compose( + [T.Resize(70, interpolation=3), T.CenterCrop(64), base_transform()] + ) + + +class STL10(BaseDataset): + def ds_train(self): + t = MultiSample( + aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples + ) + return S10(root="/mnt/store/wbandar1/datasets/", split="train+unlabeled", download=True, transform=t) + + def ds_clf(self): + t = test_transform() + return S10(root="/mnt/store/wbandar1/datasets/", split="train", download=True, transform=t) + + def ds_test(self): + t = test_transform() + return S10(root="/mnt/store/wbandar1/datasets/", split="test", download=True, transform=t) diff --git a/ssl-sota/datasets/tiny_in.py b/ssl-sota/datasets/tiny_in.py new file mode 100644 index 0000000000000000000000000000000000000000..42bc9ae0b819c5ce9fbaffffb545abaf4253aad9 --- /dev/null +++ b/ssl-sota/datasets/tiny_in.py @@ -0,0 +1,26 @@ +from torchvision.datasets import ImageFolder +import torchvision.transforms as T +from .transforms import MultiSample, aug_transform +from .base import BaseDataset + + +def base_transform(): + return T.Compose( + [T.ToTensor(), T.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))] + ) + + +class TinyImageNet(BaseDataset): + def ds_train(self): + t = MultiSample( + aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples + ) + return ImageFolder(root="/mnt/store/wbandar1/datasets/tiny-imagenet-200/train", transform=t) + + def ds_clf(self): + t = base_transform() + return ImageFolder(root="/mnt/store/wbandar1/datasets/tiny-imagenet-200/train", transform=t) + + def ds_test(self): + t = base_transform() + return ImageFolder(root="/mnt/store/wbandar1/datasets/tiny-imagenet-200/val", transform=t) diff --git a/ssl-sota/datasets/transforms.py b/ssl-sota/datasets/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..743deb7024276f8510b6ee7db9e2898c933a43e6 --- /dev/null +++ b/ssl-sota/datasets/transforms.py @@ -0,0 +1,33 @@ +import torchvision.transforms as T + + +def aug_transform(crop, base_transform, cfg, extra_t=[]): + """ augmentation transform generated from config """ + return T.Compose( + [ + T.RandomApply( + [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)], p=cfg.cj_p + ), + T.RandomGrayscale(p=cfg.gs_p), + T.RandomResizedCrop( + crop, + scale=(cfg.crop_s0, cfg.crop_s1), + ratio=(cfg.crop_r0, cfg.crop_r1), + interpolation=3, + ), + 
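+            # interpolation=3 above corresponds to PIL.Image.BICUBIC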
T.RandomHorizontalFlip(p=cfg.hf_p), + *extra_t, + base_transform(), + ] + ) + + +class MultiSample: + """ generates n samples with augmentation """ + + def __init__(self, transform, n=2): + self.transform = transform + self.num = n + + def __call__(self, x): + return tuple(self.transform(x) for _ in range(self.num)) diff --git a/ssl-sota/docker/Dockerfile b/ssl-sota/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..32c18303107f1ac4c563576df4a428f853b7230d --- /dev/null +++ b/ssl-sota/docker/Dockerfile @@ -0,0 +1,6 @@ +FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime +RUN pip install sklearn opencv-python +RUN pip install matplotlib +RUN pip install wandb +RUN pip install ipdb +ENTRYPOINT wandb login $WANDB_KEY && /bin/bash diff --git a/ssl-sota/eval/get_data.py b/ssl-sota/eval/get_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1d4e6e29957a41f963b4fa98de7bafeb23542a --- /dev/null +++ b/ssl-sota/eval/get_data.py @@ -0,0 +1,17 @@ +import torch + + +def get_data(model, loader, output_size, device): + """ encodes whole dataset into embeddings """ + xs = torch.empty( + len(loader), loader.batch_size, output_size, dtype=torch.float32, device=device + ) + ys = torch.empty(len(loader), loader.batch_size, dtype=torch.long, device=device) + with torch.no_grad(): + for i, (x, y) in enumerate(loader): + x = x.cuda() + xs[i] = model(x).to(device) + ys[i] = y.to(device) + xs = xs.view(-1, output_size) + ys = ys.view(-1) + return xs, ys diff --git a/ssl-sota/eval/knn.py b/ssl-sota/eval/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..b26709abf74b146eb7538a297ea0ea68c56b3663 --- /dev/null +++ b/ssl-sota/eval/knn.py @@ -0,0 +1,16 @@ +import torch + + +def eval_knn(x_train, y_train, x_test, y_test, k=200): + """ k-nearest neighbors classifier accuracy """ + d = torch.cdist(x_test, x_train) + topk = torch.topk(d, k=k, dim=1, largest=False) + labels = y_train[topk.indices] + pred = torch.empty_like(y_test) + for i in range(len(labels)): + x = labels[i].unique(return_counts=True) + pred[i] = x[0][x[1].argmax()] + + acc = (pred == y_test).float().mean().cpu().item() + del d, topk, labels, pred + return acc diff --git a/ssl-sota/eval/lbfgs.py b/ssl-sota/eval/lbfgs.py new file mode 100644 index 0000000000000000000000000000000000000000..45b4d7eb949f4c90550b3b09ad50aac0ee932b7b --- /dev/null +++ b/ssl-sota/eval/lbfgs.py @@ -0,0 +1,12 @@ +import torch +from sklearn.linear_model import LogisticRegression + + +def eval_lbfgs(x_train, y_train, x_test, y_test): + """ linear classifier accuracy (lbfgs method) """ + clf = LogisticRegression( + random_state=1337, solver="lbfgs", max_iter=1000, n_jobs=-1 + ) + clf.fit(x_train, y_train) + pred = clf.predict(x_test) + return (torch.tensor(pred) == y_test).float().mean() diff --git a/ssl-sota/eval/sgd.py b/ssl-sota/eval/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..908405288f768142f11526a42ee4b8fef4ef7f24 --- /dev/null +++ b/ssl-sota/eval/sgd.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.optim as optim + + +def eval_sgd(x_train, y_train, x_test, y_test, topk=[1, 5], epoch=500): + """ linear classifier accuracy (sgd) """ + lr_start, lr_end = 1e-2, 1e-6 + gamma = (lr_end / lr_start) ** (1 / epoch) + output_size = x_train.shape[1] + num_class = y_train.max().item() + 1 + clf = nn.Linear(output_size, num_class) + clf.cuda() + clf.train() + optimizer = optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6) + scheduler = 
optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) + criterion = nn.CrossEntropyLoss() + + for ep in range(epoch): + perm = torch.randperm(len(x_train)).view(-1, 1000) + for idx in perm: + optimizer.zero_grad() + criterion(clf(x_train[idx]), y_train[idx]).backward() + optimizer.step() + scheduler.step() + + clf.eval() + with torch.no_grad(): + y_pred = clf(x_test) + pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices + acc = { + t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item() + for t in topk + } + del clf + return acc diff --git a/ssl-sota/methods/__init__.py b/ssl-sota/methods/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7781711381539a4e7e29ae504b6073eb05434f9d --- /dev/null +++ b/ssl-sota/methods/__init__.py @@ -0,0 +1,16 @@ +from .contrastive import Contrastive +from .w_mse import WMSE +from .byol import BYOL + + +METHOD_LIST = ["contrastive", "w_mse", "byol"] + + +def get_method(name): + assert name in METHOD_LIST + if name == "contrastive": + return Contrastive + elif name == "w_mse": + return WMSE + elif name == "byol": + return BYOL diff --git a/ssl-sota/methods/base.py b/ssl-sota/methods/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad54a5ea7b985789b2c5fd09580919c0feba121 --- /dev/null +++ b/ssl-sota/methods/base.py @@ -0,0 +1,61 @@ +import torch.nn as nn +from model import get_model, get_head +from eval.sgd import eval_sgd +from eval.knn import eval_knn +from eval.get_data import get_data +import torch +import tqdm + +class BaseMethod(nn.Module): + """ + Base class for self-supervised loss implementation. + It includes encoder and head for training, evaluation function. + """ + + def __init__(self, cfg): + super().__init__() + self.model, self.out_size = get_model(cfg.arch, cfg.dataset) + self.head = get_head(self.out_size, cfg) + self.knn = cfg.knn + self.num_pairs = cfg.num_samples * (cfg.num_samples - 1) // 2 + self.eval_head = cfg.eval_head + self.emb_size = cfg.emb + + def forward(self, samples): + raise NotImplementedError + + def get_acc(self, ds_clf, ds_test): + self.eval() + if self.eval_head: + model = lambda x: self.head(self.model(x)) + out_size = self.emb_size + else: + model, out_size = self.model, self.out_size + # torch.cuda.empty_cache() + x_train, y_train = get_data(model, ds_clf, out_size, "cuda") + x_test, y_test = get_data(model, ds_test, out_size, "cuda") + + acc_knn = eval_knn(x_train, y_train, x_test, y_test, self.knn) + acc_linear = eval_sgd(x_train, y_train, x_test, y_test) + del x_train, y_train, x_test, y_test + self.train() + return acc_knn, acc_linear + + def get_acc_knn(self, ds_clf, ds_test): + self.eval() + if self.eval_head: + model = lambda x: self.head(self.model(x)) + out_size = self.emb_size + else: + model, out_size = self.model, self.out_size + # torch.cuda.empty_cache() + x_train, y_train = get_data(model, ds_clf, out_size, "cuda") + x_test, y_test = get_data(model, ds_test, out_size, "cuda") + + acc_knn = eval_knn(x_train, y_train, x_test, y_test, self.knn) + del x_train, y_train, x_test, y_test + self.train() + return acc_knn + + def step(self, progress): + pass diff --git a/ssl-sota/methods/byol.py b/ssl-sota/methods/byol.py new file mode 100644 index 0000000000000000000000000000000000000000..de8c15ece6c69ab120e431e6e8039e094562abb5 --- /dev/null +++ b/ssl-sota/methods/byol.py @@ -0,0 +1,53 @@ +from itertools import chain +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from model 
import get_model, get_head +from .base import BaseMethod +from .norm_mse import norm_mse_loss + + +class BYOL(BaseMethod): + """ implements BYOL loss https://arxiv.org/abs/2006.07733 """ + + def __init__(self, cfg): + """ init additional target and predictor networks """ + super().__init__(cfg) + self.pred = nn.Sequential( + nn.Linear(cfg.emb, cfg.head_size), + nn.BatchNorm1d(cfg.head_size), + nn.ReLU(), + nn.Linear(cfg.head_size, cfg.emb), + ) + self.model_t, _ = get_model(cfg.arch, cfg.dataset) + self.head_t = get_head(self.out_size, cfg) + for param in chain(self.model_t.parameters(), self.head_t.parameters()): + param.requires_grad = False + self.update_target(0) + self.byol_tau = cfg.byol_tau + self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss + + def update_target(self, tau): + """ copy parameters from main network to target """ + for t, s in zip(self.model_t.parameters(), self.model.parameters()): + t.data.copy_(t.data * tau + s.data * (1.0 - tau)) + for t, s in zip(self.head_t.parameters(), self.head.parameters()): + t.data.copy_(t.data * tau + s.data * (1.0 - tau)) + + def forward(self, samples): + z = [self.pred(self.head(self.model(x))) for x in samples] + with torch.no_grad(): + zt = [self.head_t(self.model_t(x)) for x in samples] + + loss = 0 + for i in range(len(samples) - 1): + for j in range(i + 1, len(samples)): + loss += self.loss_f(z[i], zt[j]) + self.loss_f(z[j], zt[i]) + loss /= self.num_pairs + return loss + + def step(self, progress): + """ update target network with cosine increasing schedule """ + tau = 1 - (1 - self.byol_tau) * (math.cos(math.pi * progress) + 1) / 2 + self.update_target(tau) diff --git a/ssl-sota/methods/contrastive.py b/ssl-sota/methods/contrastive.py new file mode 100644 index 0000000000000000000000000000000000000000..043afe829d06377e955ceb16f29040f100d3f3db --- /dev/null +++ b/ssl-sota/methods/contrastive.py @@ -0,0 +1,46 @@ +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +from .base import BaseMethod + + +def contrastive_loss(x0, x1, tau, norm): + # https://github.com/google-research/simclr/blob/master/objective.py + bsize = x0.shape[0] + target = torch.arange(bsize).cuda() + eye_mask = torch.eye(bsize).cuda() * 1e9 + if norm: + x0 = F.normalize(x0, p=2, dim=1) + x1 = F.normalize(x1, p=2, dim=1) + logits00 = x0 @ x0.t() / tau - eye_mask + logits11 = x1 @ x1.t() / tau - eye_mask + logits01 = x0 @ x1.t() / tau + logits10 = x1 @ x0.t() / tau + return ( + F.cross_entropy(torch.cat([logits01, logits00], dim=1), target) + + F.cross_entropy(torch.cat([logits10, logits11], dim=1), target) + ) / 2 + + +class Contrastive(BaseMethod): + """ implements contrastive loss https://arxiv.org/abs/2002.05709 """ + + def __init__(self, cfg): + """ init additional BN used after head """ + super().__init__(cfg) + self.bn_last = nn.BatchNorm1d(cfg.emb) + self.loss_f = partial(contrastive_loss, tau=cfg.tau, norm=cfg.norm) + + def forward(self, samples): + bs = len(samples[0]) + h = [self.model(x.cuda(non_blocking=True)) for x in samples] + h = self.bn_last(self.head(torch.cat(h))) + loss = 0 + for i in range(len(samples) - 1): + for j in range(i + 1, len(samples)): + x0 = h[i * bs : (i + 1) * bs] + x1 = h[j * bs : (j + 1) * bs] + loss += self.loss_f(x0, x1) + loss /= self.num_pairs + return loss diff --git a/ssl-sota/methods/norm_mse.py b/ssl-sota/methods/norm_mse.py new file mode 100644 index 0000000000000000000000000000000000000000..fa303b1f0b2d6fbef440664bf89a373bafe65f60 --- /dev/null +++ 
b/ssl-sota/methods/norm_mse.py @@ -0,0 +1,7 @@ +import torch.nn.functional as F + + +def norm_mse_loss(x0, x1): + x0 = F.normalize(x0) + x1 = F.normalize(x1) + return 2 - 2 * (x0 * x1).sum(dim=-1).mean() diff --git a/ssl-sota/methods/w_mse.py b/ssl-sota/methods/w_mse.py new file mode 100644 index 0000000000000000000000000000000000000000..a7312766e66b1e8f1bc7e1c59fe71948b7d37bdb --- /dev/null +++ b/ssl-sota/methods/w_mse.py @@ -0,0 +1,36 @@ +import torch +import torch.nn.functional as F +from .whitening import Whitening2d +from .base import BaseMethod +from .norm_mse import norm_mse_loss + + +class WMSE(BaseMethod): + """ implements W-MSE loss """ + + def __init__(self, cfg): + """ init whitening transform """ + super().__init__(cfg) + self.whitening = Whitening2d(cfg.emb, eps=cfg.w_eps, track_running_stats=False) + self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss + self.w_iter = cfg.w_iter + self.w_size = cfg.bs if cfg.w_size is None else cfg.w_size + + def forward(self, samples): + bs = len(samples[0]) + h = [self.model(x.cuda(non_blocking=True)) for x in samples] + h = self.head(torch.cat(h)) + loss = 0 + for _ in range(self.w_iter): + z = torch.empty_like(h) + perm = torch.randperm(bs).view(-1, self.w_size) + for idx in perm: + for i in range(len(samples)): + z[idx + i * bs] = self.whitening(h[idx + i * bs]) + for i in range(len(samples) - 1): + for j in range(i + 1, len(samples)): + x0 = z[i * bs : (i + 1) * bs] + x1 = z[j * bs : (j + 1) * bs] + loss += self.loss_f(x0, x1) + loss /= self.w_iter * self.num_pairs + return loss diff --git a/ssl-sota/methods/whitening.py b/ssl-sota/methods/whitening.py new file mode 100644 index 0000000000000000000000000000000000000000..20e965310af8ccf3708bf2c78e3fe4ab2e14b994 --- /dev/null +++ b/ssl-sota/methods/whitening.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +from torch.nn.functional import conv2d + + +class Whitening2d(nn.Module): + def __init__(self, num_features, momentum=0.01, track_running_stats=True, eps=0): + super(Whitening2d, self).__init__() + self.num_features = num_features + self.momentum = momentum + self.track_running_stats = track_running_stats + self.eps = eps + + if self.track_running_stats: + self.register_buffer( + "running_mean", torch.zeros([1, self.num_features, 1, 1]) + ) + self.register_buffer("running_variance", torch.eye(self.num_features)) + + def forward(self, x): + x = x.unsqueeze(2).unsqueeze(3) + m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1) + if not self.training and self.track_running_stats: # for inference + m = self.running_mean + xn = x - m + + T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1) + f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1) + + eye = torch.eye(self.num_features).type(f_cov.type()) + + if not self.training and self.track_running_stats: # for inference + f_cov = self.running_variance + + f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye + + inv_sqrt = torch.linalg.solve_triangular( + torch.linalg.cholesky(f_cov_shrinked), + eye, + upper=False + ) + + inv_sqrt = inv_sqrt.contiguous().view( + self.num_features, self.num_features, 1, 1 + ) + + decorrelated = conv2d(xn, inv_sqrt) + + if self.training and self.track_running_stats: + self.running_mean = torch.add( + self.momentum * m.detach(), + (1 - self.momentum) * self.running_mean, + out=self.running_mean, + ) + self.running_variance = torch.add( + self.momentum * f_cov.detach(), + (1 - self.momentum) * self.running_variance, + out=self.running_variance, + ) + + return 
decorrelated.squeeze(2).squeeze(2) + + def extra_repr(self): + return "features={}, eps={}, momentum={}".format( + self.num_features, self.eps, self.momentum + ) diff --git a/ssl-sota/model.py b/ssl-sota/model.py new file mode 100644 index 0000000000000000000000000000000000000000..184e70dac2a325b7f65e2e77f42ca02b752a9e01 --- /dev/null +++ b/ssl-sota/model.py @@ -0,0 +1,29 @@ +import torch.nn as nn +from torchvision import models + + +def get_head(out_size, cfg): + """ creates projection head g() from config """ + x = [] + in_size = out_size + for _ in range(cfg.head_layers - 1): + x.append(nn.Linear(in_size, cfg.head_size)) + if cfg.add_bn: + x.append(nn.BatchNorm1d(cfg.head_size)) + x.append(nn.ReLU()) + in_size = cfg.head_size + x.append(nn.Linear(in_size, cfg.emb)) + return nn.Sequential(*x) + + +def get_model(arch, dataset): + """ creates encoder E() by name and modifies it for dataset """ + model = getattr(models, arch)(pretrained=False) + if dataset != "imagenet": + model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + if dataset == "cifar10" or dataset == "cifar100": + model.maxpool = nn.Identity() + out_size = model.fc.in_features + model.fc = nn.Identity() + + return nn.DataParallel(model), out_size diff --git a/ssl-sota/scripts-pretrain-resnet50/cifar10.sh b/ssl-sota/scripts-pretrain-resnet50/cifar10.sh new file mode 100644 index 0000000000000000000000000000000000000000..2f228d2fe76c49ddc349a1c61489d7bbdb1e5d69 --- /dev/null +++ b/ssl-sota/scripts-pretrain-resnet50/cifar10.sh @@ -0,0 +1,23 @@ +#!/bin/bash +gpu=7 +dataset=cifar10 +arch=resnet18 +feature_dim=1024 +is_mixup=true # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine #"step", "cosine" # step, cosine +mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625 + +lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc) +lmbda=0.0078125 +echo ${lmbda} + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain-resnet50/cifar100.sh b/ssl-sota/scripts-pretrain-resnet50/cifar100.sh new file mode 100644 index 0000000000000000000000000000000000000000..f4d167c8dc87ecde663692d63bbd9fdbda324c82 --- /dev/null +++ b/ssl-sota/scripts-pretrain-resnet50/cifar100.sh @@ -0,0 +1,23 @@ +#!/bin/bash +gpu=9 +dataset=cifar100 +arch=resnet50 +feature_dim=1024 +is_mixup=true # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine #"step", "cosine" # step, cosine +mixup_loss_scale=4.0 # scale w.r.t. 
lambda: 0.0078125 * 5 = 0.0390625 + +lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc) #0.0078125 +lmbda=0.0078125 +echo ${lmbda} + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain-resnet50/imagenet.sh b/ssl-sota/scripts-pretrain-resnet50/imagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf95107598978e96047aea309944ae36c7c99542 --- /dev/null +++ b/ssl-sota/scripts-pretrain-resnet50/imagenet.sh @@ -0,0 +1,15 @@ +#!/bin/bash +is_mixup=true +batch_size=1024 #128/gpu works +lr_w=0.2 #0.2 +lr_b=0.0048 #0.0048 +lambda_mixup=0.004 + + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main_imagenet.py /mnt/store/wbandar1/imagenet1k/ --is_mixup ${is_mixup} --batch-size ${batch_size} --learning-rate-weights ${lr_w} --learning-rate-biases ${lr_b} --lambda_mixup ${lambda_mixup}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain-resnet50/stl10.sh b/ssl-sota/scripts-pretrain-resnet50/stl10.sh new file mode 100644 index 0000000000000000000000000000000000000000..9840f87429d3bdb3de8884c9af98cdf6236efb15 --- /dev/null +++ b/ssl-sota/scripts-pretrain-resnet50/stl10.sh @@ -0,0 +1,23 @@ +#!/bin/bash +gpu=2 +dataset=stl10 +arch=resnet18 +feature_dim=1024 +is_mixup=true # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine #"step", "cosine" # step, cosine +mixup_loss_scale=4.0 # scale w.r.t. lambda: 0.0078125 * 5 = 0.0390625 + +lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc) +# lmbda=0.0078125 # found that this works better +echo ${lmbda} + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain-resnet50/tinyimagenet.sh b/ssl-sota/scripts-pretrain-resnet50/tinyimagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..08b27d3b7cfbbc5f9a6862ed3c6329c02aab614d --- /dev/null +++ b/ssl-sota/scripts-pretrain-resnet50/tinyimagenet.sh @@ -0,0 +1,23 @@ +#!/bin/bash +gpu=1 +dataset=tiny_imagenet +arch=resnet18 +feature_dim=2048 +is_mixup=true # true, false +batch_size=256 +epochs=2000 +lr=0.01 +lr_shed=cosine #"step", "cosine" # step, cosine +mixup_loss_scale=4 # scale w.r.t.
lambda: 0.0078125 * 5 = 0.0390625 + +lmbda=$(echo "scale=7; 1 / ${feature_dim}" | bc) +# lmbda=0.0078125 # found out that the fraction works fine for tiny_imagenet +echo ${lmbda} + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python main.py --lmbda ${lmbda} --corr_zero --batch_size ${batch_size} --feature_dim ${feature_dim} --dataset ${dataset} --is_mixup ${is_mixup} --mixup_loss_scale ${mixup_loss_scale} --epochs ${epochs} --arch ${arch} --gpu ${gpu} --lr_shed ${lr_shed} --lr ${lr}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/byol/cifar10.sh b/ssl-sota/scripts-pretrain/byol/cifar10.sh new file mode 100644 index 0000000000000000000000000000000000000000..4dec6ce237f5729d06d7efa0d6d283244070e457 --- /dev/null +++ b/ssl-sota/scripts-pretrain/byol/cifar10.sh @@ -0,0 +1,14 @@ +#!/bin/bash +gpu=9 +dataset=cifar10 +method=byol +model=resnet50 +eval_every=5 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/byol/cifar100.sh b/ssl-sota/scripts-pretrain/byol/cifar100.sh new file mode 100644 index 0000000000000000000000000000000000000000..6b98ada9476e0328d59d3d84d9f93c76a0ca9a92 --- /dev/null +++ b/ssl-sota/scripts-pretrain/byol/cifar100.sh @@ -0,0 +1,14 @@ +#!/bin/bash +gpu=9 +dataset=cifar100 +method=byol +model=resnet50 +eval_every=5 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/byol/stl10.sh b/ssl-sota/scripts-pretrain/byol/stl10.sh new file mode 100644 index 0000000000000000000000000000000000000000..df7e1f163a8a4c932bf74d8d70902ad6c099e23a --- /dev/null +++ b/ssl-sota/scripts-pretrain/byol/stl10.sh @@ -0,0 +1,14 @@ +#!/bin/bash +gpu=9 +dataset=stl10 +method=byol +model=resnet50 +eval_every=5 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/byol/tinyimagenet.sh b/ssl-sota/scripts-pretrain/byol/tinyimagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..e2de18464f2c6bc560f594a833808bf0158d81fd --- /dev/null +++ b/ssl-sota/scripts-pretrain/byol/tinyimagenet.sh @@ -0,0 +1,15 @@ +#!/bin/bash
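+# BYOL pre-training on Tiny ImageNet; bs is lowered to 320 from the default 384 (see cfg.py),
+# and the dataset name must match DS_LIST in datasets/__init__.py ("tinyimagenet").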
+gpu=9 +dataset=tinyimagenet +method=byol +model=resnet50 +eval_every=5 +batch_size=320 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model} --eval_every ${eval_every} --bs ${batch_size}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/contrastive/cifar10.sh b/ssl-sota/scripts-pretrain/contrastive/cifar10.sh new file mode 100644 index 0000000000000000000000000000000000000000..80b74ab494ab8795e30c54f461d0a31587f331d1 --- /dev/null +++ b/ssl-sota/scripts-pretrain/contrastive/cifar10.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=cifar10 +method=contrastive +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/contrastive/cifar100.sh b/ssl-sota/scripts-pretrain/contrastive/cifar100.sh new file mode 100644 index 0000000000000000000000000000000000000000..1666824e99ccbeaf6f41ae454ccf5e3611829561 --- /dev/null +++ b/ssl-sota/scripts-pretrain/contrastive/cifar100.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=cifar100 +method=contrastive +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/contrastive/stl10.sh b/ssl-sota/scripts-pretrain/contrastive/stl10.sh new file mode 100644 index 0000000000000000000000000000000000000000..3dab325f3d73fdca698469250b5a07f2d160295c --- /dev/null +++ b/ssl-sota/scripts-pretrain/contrastive/stl10.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=stl10 +method=contrastive +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M" +screen -S "$session_name" -X detach \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/contrastive/tinyimagenet.sh b/ssl-sota/scripts-pretrain/contrastive/tinyimagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..492e1d8a0011d98a5bc19ea24895965b118e2792 --- /dev/null +++ b/ssl-sota/scripts-pretrain/contrastive/tinyimagenet.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=tinyimagenet +method=contrastive +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S
"$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --emb 64 --method ${method} --arch ${model}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/w-mse/cifar10.sh b/ssl-sota/scripts-pretrain/w-mse/cifar10.sh new file mode 100644 index 0000000000000000000000000000000000000000..1569e03efa58933dd4890765d8f392dda64c4a0d --- /dev/null +++ b/ssl-sota/scripts-pretrain/w-mse/cifar10.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=cifar10 +method=w_mse +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 3e-3 --bs 256 --emb 64 --w_size 128 --w_eps 10e-3 --method ${method} --arch ${model}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/w-mse/cifar100.sh b/ssl-sota/scripts-pretrain/w-mse/cifar100.sh new file mode 100644 index 0000000000000000000000000000000000000000..871230238774dbeb3674f9d1a2be102cc67f4fb3 --- /dev/null +++ b/ssl-sota/scripts-pretrain/w-mse/cifar100.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=cifar100 +method=w_mse +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --lr 3e-3 --bs 256 --emb 64 --w_size 128 --w_eps 10e-3 --method ${method} --arch ${model}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/w-mse/stl10.sh b/ssl-sota/scripts-pretrain/w-mse/stl10.sh new file mode 100644 index 0000000000000000000000000000000000000000..5e72a0bd82cd38fcf336cbf0de5217db2f1e353d --- /dev/null +++ b/ssl-sota/scripts-pretrain/w-mse/stl10.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=stl10 +method=w_mse +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 2e-3 --bs 256 --emb 128 --w_size 256 --w_eps 10e-3 --method ${method} --arch ${model}^M" +screen -S "$session_name" -X detachs \ No newline at end of file diff --git a/ssl-sota/scripts-pretrain/w-mse/tinyimagenet.sh b/ssl-sota/scripts-pretrain/w-mse/tinyimagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..5240f85622738d3caa9cc6ed32e8a10d7e2727d4 --- /dev/null +++ b/ssl-sota/scripts-pretrain/w-mse/tinyimagenet.sh @@ -0,0 +1,13 @@ +#!/bin/bash +gpu=9 +dataset=tinyimagenet +method=w_mse +model=resnet50 + +timestamp=$(date +"%Y%m%d%H%M%S") +session_name="python_session_$timestamp" +echo ${session_name} +screen -dmS "$session_name" +screen -S "$session_name" -X stuff "conda activate ssl-aug^M" +screen -S "$session_name" -X stuff "CUDA_VISIBLE_DEVICES=${gpu} python -m train --dataset ${dataset} --epoch 1000 --lr 2e-3 --bs 256 --emb 128 --w_size 256 --w_eps 10e-3 --method ${method} --arch ${model}^M" +screen -S 
"$session_name" -X detachs \ No newline at end of file diff --git a/ssl-sota/test.py b/ssl-sota/test.py new file mode 100644 index 0000000000000000000000000000000000000000..a9330a5cc95f64aeb60d89a3b5d77e898d4228f5 --- /dev/null +++ b/ssl-sota/test.py @@ -0,0 +1,39 @@ +import torch +from datasets import get_ds +from cfg import get_cfg +from methods import get_method + +from eval.sgd import eval_sgd +from eval.knn import eval_knn +from eval.lbfgs import eval_lbfgs +from eval.get_data import get_data + + +if __name__ == "__main__": + cfg = get_cfg() + + model_full = get_method(cfg.method)(cfg) + model_full.cuda().eval() + if cfg.fname is None: + print("evaluating random model") + else: + model_full.load_state_dict(torch.load(cfg.fname)) + + ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers) + device = "cpu" if cfg.clf == "lbfgs" else "cuda" + if cfg.eval_head: + model = lambda x: model_full.head(model_full.model(x)) + out_size = cfg.emb + else: + model = model_full.model + out_size = model_full.out_size + x_train, y_train = get_data(model, ds.clf, out_size, device) + x_test, y_test = get_data(model, ds.test, out_size, device) + + if cfg.clf == "sgd": + acc = eval_sgd(x_train, y_train, x_test, y_test) + if cfg.clf == "knn": + acc = eval_knn(x_train, y_train, x_test, y_test) + elif cfg.clf == "lbfgs": + acc = eval_lbfgs(x_train, y_train, x_test, y_test) + print(acc) diff --git a/ssl-sota/tf2/README.md b/ssl-sota/tf2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5a326b3a20289bc051b79fefd79b8762b86b02ab --- /dev/null +++ b/ssl-sota/tf2/README.md @@ -0,0 +1,8 @@ +`w_mse_loss()` from `whitening.py` is W-MSE loss implementation for TensorFlow 2, +it can be used with other popular implementations, e.g. [SimCLRv2](https://github.com/google-research/simclr/tree/master/tf2). 
+ + +The method uses the same global flags mechanism as SimCLRv2: +- `FLAGS.num_samples` - number of samples (d) generated from each image +- `FLAGS.train_batch_size` +- `FLAGS.proj_out_dim` diff --git a/ssl-sota/tf2/whitening.py b/ssl-sota/tf2/whitening.py new file mode 100644 index 0000000000000000000000000000000000000000..e14650a7de184732326e2cd1df678905ef32ec64 --- /dev/null +++ b/ssl-sota/tf2/whitening.py @@ -0,0 +1,45 @@ +import tensorflow.compat.v2 as tf +from absl import flags + +FLAGS = flags.FLAGS + + +class Whitening1D(tf.keras.layers.Layer): + def __init__(self, eps=0, **kwargs): + super(Whitening1D, self).__init__(**kwargs) + self.eps = eps + + def call(self, x): + bs, c = x.shape + x_t = tf.transpose(x, (1, 0)) + m = tf.reduce_mean(x_t, axis=1, keepdims=True) + f = x_t - m + ff_apr = tf.matmul(f, f, transpose_b=True) / (tf.cast(bs, tf.float32) - 1.0) + ff_apr_shrinked = (1 - self.eps) * ff_apr + tf.eye(c) * self.eps + sqrt = tf.linalg.cholesky(ff_apr_shrinked) + inv_sqrt = tf.linalg.triangular_solve(sqrt, tf.eye(c)) + f_hat = tf.matmul(inv_sqrt, f) + decorrelated = tf.transpose(f_hat, (1, 0)) + return decorrelated + + +def w_mse_loss(x): + """ input x shape = (batch size * num_samples, proj_out_dim) """ + + w = Whitening1D() + num_samples = FLAGS.num_samples + num_slice = num_samples * FLAGS.train_batch_size // (2 * FLAGS.proj_out_dim) + x_split = tf.split(x, num_slice, 0) + for i in range(num_slice): + x_split[i] = w(x_split[i]) + x = tf.concat(x_split, 0) + x = tf.math.l2_normalize(x, -1) + + x_split = tf.split(x, num_samples, 0) + loss = 0 + for i in range(num_samples - 1): + for j in range(i + 1, num_samples): + v = x_split[i] * x_split[j] + loss += 2 - 2 * tf.reduce_mean(tf.reduce_sum(v, -1)) + loss /= num_samples * (num_samples - 1) // 2 + return loss diff --git a/ssl-sota/train.py b/ssl-sota/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f86af7eeb12ae9e6de85b3b7e0853bf4ab858a50 --- /dev/null +++ b/ssl-sota/train.py @@ -0,0 +1,93 @@ +from tqdm import trange, tqdm +import numpy as np +import wandb +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts +import torch.backends.cudnn as cudnn +import os + +from cfg import get_cfg +from datasets import get_ds +from methods import get_method + + +def get_scheduler(optimizer, cfg): + if cfg.lr_step == "cos": + return CosineAnnealingWarmRestarts( + optimizer, + T_0=cfg.epoch if cfg.T0 is None else cfg.T0, + T_mult=cfg.Tmult, + eta_min=cfg.eta_min, + ) + elif cfg.lr_step == "step": + m = [cfg.epoch - a for a in cfg.drop] + return MultiStepLR(optimizer, milestones=m, gamma=cfg.drop_gamma) + else: + return None + + +if __name__ == "__main__": + cfg = get_cfg() + wandb.init(project=f"ssl-sota-{cfg.method}-{cfg.dataset}", config=cfg, dir='/mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs/') + run_id = wandb.run.id + # if not os.path.exists('../results'): + # os.mkdir('../results') + run_id_dir = os.path.join('/mnt/store/wbandar1/projects/ssl-aug-artifacts/', run_id) + if not os.path.exists(run_id_dir): + print('Creating directory {}'.format(run_id_dir)) + os.mkdir(run_id_dir) + + ds = get_ds(cfg.dataset)(cfg.bs, cfg, cfg.num_workers) + model = get_method(cfg.method)(cfg) + model.cuda().train() + if cfg.fname is not None: + model.load_state_dict(torch.load(cfg.fname)) + + optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.adam_l2) + scheduler = get_scheduler(optimizer, cfg) + + eval_every = cfg.eval_every + lr_warmup = 0
if cfg.lr_warmup else 500 + cudnn.benchmark = True + + for ep in trange(cfg.epoch, position=0): + loss_ep = [] + iters = len(ds.train) + for n_iter, (samples, _) in enumerate(tqdm(ds.train, position=1)): + if lr_warmup < 500: + lr_scale = (lr_warmup + 1) / 500 + for pg in optimizer.param_groups: + pg["lr"] = cfg.lr * lr_scale + lr_warmup += 1 + + optimizer.zero_grad() + loss = model(samples) + loss.backward() + optimizer.step() + loss_ep.append(loss.item()) + model.step(ep / cfg.epoch) + if cfg.lr_step == "cos" and lr_warmup >= 500: + scheduler.step(ep + n_iter / iters) + + if cfg.lr_step == "step": + scheduler.step() + + if len(cfg.drop) and ep == (cfg.epoch - cfg.drop[0]): + eval_every = cfg.eval_every_drop + + if (ep + 1) % eval_every == 0: + # acc_knn, acc = model.get_acc(ds.clf, ds.test) + # wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn}, commit=False) + acc_knn = model.get_acc_knn(ds.clf, ds.test) + wandb.log({"acc_knn": acc_knn}, commit=False) + + if (ep + 1) % 100 == 0: + fname = f"/mnt/store/wbandar1/projects/ssl-aug-artifacts/{run_id}/{cfg.method}_{cfg.dataset}_{ep}.pt" + torch.save(model.state_dict(), fname) + wandb.log({"loss": np.mean(loss_ep), "ep": ep}) + + acc_knn, acc = model.get_acc(ds.clf, ds.test) + print('Final linear-acc: {}, knn-acc: {}'.format(acc[1], acc_knn)) + wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn}, commit=False) + wandb.finish() \ No newline at end of file diff --git a/transfer_datasets/README.md b/transfer_datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e7ba755b6947e6c02ec5ccfd929f69cb41deaf83 --- /dev/null +++ b/transfer_datasets/README.md @@ -0,0 +1,19 @@ +## Datasets +Download the transfer learning datasets from the links below, + +* **ImageNet-1k**: https://image-net.org + +* **DTD**: https://www.robots.ox.ac.uk/~vgg/data/dtd/ + +* **CUBirds**: http://www.vision.caltech.edu/datasets/cub_200_2011/ (no longer working: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) + +* **VGG Flower**: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/ + +* **Traffic Signs**: https://benchmark.ini.rub.de/gtsdb_dataset.html + +* **Aircraft**: https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/ + +and put them under their respective paths, e.g., '../data/DTD'. + +## Training and linear evaluation +All training and linear evaluation commands are provided in `main.sh` in the folder corresponding to the model.
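+
+For reference, each per-dataset command issued by the transfer-evaluation scripts (e.g. `scripts-transfer-resnet18/cifar100-to-x.sh`) reduces to a single call of the following form; the checkpoint path and flag values below mirror that script, which additionally passes a `--screen` flag naming its screen session:
+
+```
+CUDA_VISIBLE_DEVICES=0 python evaluate_transfer.py \
+    --dataset cifar100 \
+    --transfer_dataset dtd \
+    --model_path checkpoints/76kk7scz_0.0078125_1024_256_cifar100_model.pth \
+    --arch resnet18 \
+    --wandb_group mbt
+```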
diff --git a/transfer_datasets/README.md b/transfer_datasets/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e7ba755b6947e6c02ec5ccfd929f69cb41deaf83
--- /dev/null
+++ b/transfer_datasets/README.md
@@ -0,0 +1,19 @@
+## Datasets
+Download the transfer-learning datasets from the links below,
+
+* **ImageNet-1k**: https://image-net.org
+
+* **DTD**: https://www.robots.ox.ac.uk/~vgg/data/dtd/
+
+* **CUBirds**: http://www.vision.caltech.edu/datasets/cub_200_2011/ (no longer working: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
+
+* **VGG Flower**: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
+
+* **Traffic Signs**: https://benchmark.ini.rub.de/gtsdb_dataset.html
+
+* **Aircraft**: https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/
+
+and put them under their respective paths, e.g., '../data/DTD'.
+
+## Training and linear evaluation
+All training and linear evaluation commands are provided in `main.sh` in the folder corresponding to the model.
\ No newline at end of file
diff --git a/transfer_datasets/__init__.py b/transfer_datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f07629070b6c5a60702ac83f4aa7a7e75bd9b434
--- /dev/null
+++ b/transfer_datasets/__init__.py
@@ -0,0 +1,16 @@
+from .aircraft import Aircraft
+from .cu_birds import CUBirds
+from .dtd import DTD
+from .fashionmnist import FashionMNIST
+from .mnist import MNIST
+from .traffic_sign import TrafficSign
+from .vgg_flower import VGGFlower
+
+TRANSFER_DATASET = {
+    'aircraft': Aircraft,
+    'cu_birds': CUBirds,
+    'dtd': DTD,
+    'fashionmnist': FashionMNIST,
+    'mnist': MNIST,
+    'traffic_sign': TrafficSign,
+    'vgg_flower': VGGFlower}
\ No newline at end of file
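A minimal usage sketch for the `TRANSFER_DATASET` registry above; the root path, transform, and batch size are assumptions for illustration, not values taken from the repo's scripts:

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from transfer_datasets import TRANSFER_DATASET

eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
# Every dataset class shares the (root, train, image_transforms) signature.
train_set = TRANSFER_DATASET['dtd'](root='data/DTD', train=True,
                                    image_transforms=eval_transform)
loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
images, labels = next(iter(loader))  # (64, 3, 224, 224), (64,)
```

Indexing the registry by name keeps a linear-evaluation script dataset-agnostic: the dataset becomes a command-line string rather than an import.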
diff --git a/transfer_datasets/aircraft.py b/transfer_datasets/aircraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8dee6fa617fd6b73d983c25fb1aaa555636dd82
--- /dev/null
+++ b/transfer_datasets/aircraft.py
@@ -0,0 +1,74 @@
+import os
+import numpy as np
+from PIL import Image
+from os.path import join
+from collections import defaultdict
+import torch.utils.data as data
+
+DATA_ROOTS = 'data/Aircraft'
+
+# url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
+# wget http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz
+# python
+# from torchvision.datasets.utils import extract_archive
+# extract_archive("fgvc-aircraft-2013b.tar.gz")
+# Download and preprocess: https://github.com/lvyilin/pytorch-fgvc-dataset/blob/master/aircraft.py
+
+# class_types = ('variant', 'family', 'manufacturer')
+# splits = ('train', 'val', 'trainval', 'test')
+# img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')
+
+class Aircraft(data.Dataset):
+    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+        super().__init__()
+        self.root = root
+        self.train = train
+        self.image_transforms = image_transforms
+        paths, bboxes, labels = self.load_images()
+        self.paths = paths
+        self.bboxes = bboxes
+        self.labels = labels
+
+    def load_images(self):
+        # map image names to variant labels for the chosen split
+        split = 'trainval' if self.train else 'test'
+        variant_path = os.path.join(self.root, 'data', 'images_variant_%s.txt' % split)
+        with open(variant_path, 'r') as f:
+            names_to_variants = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+        names_to_variants = dict(names_to_variants)
+        variants_to_names = defaultdict(list)
+        for name, variant in names_to_variants.items():
+            variants_to_names[variant].append(name)
+        variants = sorted(list(set(variants_to_names.keys())))
+
+        names_to_bboxes = self.get_bounding_boxes()
+        split_files, split_labels, split_bboxes = [], [], []
+        for variant_id, variant in enumerate(variants):
+            class_files = [join(self.root, 'data', 'images', '%s.jpg' % filename)
+                           for filename in sorted(variants_to_names[variant])]
+            bboxes = [names_to_bboxes[name] for name in sorted(variants_to_names[variant])]
+            labels = [variant_id] * len(class_files)
+            split_files += class_files
+            split_labels += labels
+            split_bboxes += bboxes
+        return split_files, split_bboxes, split_labels
+
+    def get_bounding_boxes(self):
+        bboxes_path = os.path.join(self.root, 'data', 'images_box.txt')
+        with open(bboxes_path, 'r') as f:
+            names_to_bboxes = [line.split('\n')[0].split(' ') for line in f.readlines()]
+        names_to_bboxes = dict((name, list(map(int, (xmin, ymin, xmax, ymax))))
+                               for name, xmin, ymin, xmax, ymax in names_to_bboxes)
+        return names_to_bboxes
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        bbox = tuple(self.bboxes[index])
+        label = self.labels[index]
+
+        image = Image.open(path).convert(mode='RGB')
+        image = image.crop(bbox)
+
+        if self.image_transforms:
+            image = self.image_transforms(image)
+        return image, label
\ No newline at end of file
diff --git a/transfer_datasets/cu_birds.py b/transfer_datasets/cu_birds.py
new file mode 100644
index 0000000000000000000000000000000000000000..353ae453dee70051ef832d6ec51f67c2e99b7353
--- /dev/null
+++ b/transfer_datasets/cu_birds.py
@@ -0,0 +1,61 @@
+import os
+import numpy as np
+from PIL import Image
+import torch.utils.data as data
+
+DATA_ROOTS = 'data/CUBirds'
+
+# wget https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
+# tar -xvzf CUB_200_2011.tgz
+
+class CUBirds(data.Dataset):
+    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+        super().__init__()
+        self.root = root
+        self.train = train
+        self.image_transforms = image_transforms
+        paths, labels = self.load_images()
+        self.paths, self.labels = paths, labels
+
+    def load_images(self):
+        # load image index to path information
+        image_info_path = os.path.join(self.root, 'images.txt')
+        with open(image_info_path, 'r') as f:
+            image_info = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+        image_info = dict(image_info)
+
+        # load image to label information
+        label_info_path = os.path.join(self.root, 'image_class_labels.txt')
+        with open(label_info_path, 'r') as f:
+            label_info = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+        label_info = dict(label_info)
+
+        # load train test split
+        train_test_info_path = os.path.join(self.root, 'train_test_split.txt')
+        with open(train_test_info_path, 'r') as f:
+            train_test_info = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]
+        train_test_info = dict(train_test_info)
+
+        all_paths, all_labels = [], []
+        for index, image_path in image_info.items():
+            label = label_info[index]
+            split = int(train_test_info[index])
+            if self.train:
+                if split == 1:
+                    all_paths.append(image_path)
+                    all_labels.append(label)
+            else:
+                if split == 0:
+                    all_paths.append(image_path)
+                    all_labels.append(label)
+        return all_paths, all_labels
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = os.path.join(self.root, 'images', self.paths[index])
+        label = int(self.labels[index]) - 1
+        image = Image.open(path).convert(mode='RGB')
+        if self.image_transforms:
+            image = self.image_transforms(image)
+        return image, label
\ No newline at end of file
diff --git a/transfer_datasets/dtd.py b/transfer_datasets/dtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c52439d1bc14643be43ce56490d840856bf76c
--- /dev/null
+++ b/transfer_datasets/dtd.py
@@ -0,0 +1,69 @@
+import os
+import copy
+import numpy as np
+from PIL import Image
+from os.path import join
+from itertools import chain
+from collections import defaultdict
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms
+
+DATA_ROOTS = 'data/DTD'
+
+# wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz
+# tar -xvzf dtd-r1.0.1.tar.gz
+
+class DTD(data.Dataset):
+    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+        super().__init__()
+        self.root = root
+        self.train = train
+        self.image_transforms = image_transforms
+        paths, labels = self.load_images()
+        self.paths, self.labels = paths, labels
+
+    def load_images(self):
+        if self.train:
+            # the train split is the union of the official train1 and val1 lists
+            train_info_path = os.path.join(self.root, 'labels', 'train1.txt')
+            with open(train_info_path, 'r') as f:
+                train_info = [line.split('\n')[0] for line in f.readlines()]
+
+            val_info_path = os.path.join(self.root, 'labels', 'val1.txt')
+            with open(val_info_path, 'r') as f:
+                val_info = [line.split('\n')[0] for line in f.readlines()]
+            split_info = train_info + val_info
+
+        else:
+            test_info_path = os.path.join(self.root, 'labels', 'test1.txt')
+            with open(test_info_path, 'r') as f:
+                split_info = [line.split('\n')[0] for line in f.readlines()]
+
+        # pull out categories from paths
+        categories = []
+        for row in split_info:
+            image_path = row
+            category = image_path.split('/')[0]
+            categories.append(category)
+        categories = sorted(list(set(categories)))
+
+        all_paths, all_labels = [], []
+        for row in split_info:
+            image_path = row
+            category = image_path.split('/')[0]
+            label = categories.index(category)
+            all_paths.append(join(self.root, 'images', image_path))
+            all_labels.append(label)
+        return all_paths, all_labels
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        label = self.labels[index]
+        image = Image.open(path).convert(mode='RGB')
+        if self.image_transforms:
+            image = self.image_transforms(image)
+        return image, label
\ No newline at end of file
diff --git a/transfer_datasets/fashionmnist.py b/transfer_datasets/fashionmnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1455ef43dbdbbb388240d92f7a121bd06f7d75
--- /dev/null
+++ b/transfer_datasets/fashionmnist.py
@@ -0,0 +1,28 @@
+import os
+import copy
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms, datasets
+
+DATA_ROOTS = 'data'
+
+class FashionMNIST(data.Dataset):
+    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+        super().__init__()
+        if not os.path.isdir(root):
+            os.makedirs(root)
+        self.image_transforms = image_transforms
+        self.dataset = datasets.mnist.FashionMNIST(root, train=train, download=True)
+
+    def __getitem__(self, index):
+        img, target = self.dataset.data[index], int(self.dataset.targets[index])
+        img = Image.fromarray(img.numpy(), mode='L').convert('RGB')
+        if self.image_transforms is not None:
+            img = self.image_transforms(img)
+        return img, target
+
+    def __len__(self):
+        return len(self.dataset)
\ No newline at end of file
diff --git a/transfer_datasets/mnist.py b/transfer_datasets/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab88a5ef8c3d82d7115feb4d4f74402cb7287d2
--- /dev/null
+++ b/transfer_datasets/mnist.py
@@ -0,0 +1,28 @@
+import os
+import copy
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms, datasets
+
+DATA_ROOTS = 'data'
+
+class MNIST(data.Dataset):
+    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+        super().__init__()
+        if not os.path.isdir(root):
+            os.makedirs(root)
+        self.image_transforms = image_transforms
+        self.dataset = datasets.mnist.MNIST(root, train=train, download=True)
+
+    def __getitem__(self, index):
+        img, target = self.dataset.data[index], int(self.dataset.targets[index])
+        img = Image.fromarray(img.numpy(), mode='L').convert('RGB')
+        if self.image_transforms is not None:
+            img = self.image_transforms(img)
+        return img, target
+
+    def __len__(self):
+        return len(self.dataset)
\ No newline at end of file
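Both MNIST wrappers above replicate the single-channel digits into three channels so that encoders expecting RGB input can consume them unchanged; the conversion in isolation (the array is a stand-in for a real digit):

```python
import numpy as np
from PIL import Image

arr = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)  # fake digit
rgb = Image.fromarray(arr, mode='L').convert('RGB')             # replicate channel x3
print(rgb.mode, rgb.size)  # RGB (28, 28)
```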
diff --git a/transfer_datasets/traffic_sign.py b/transfer_datasets/traffic_sign.py
new file mode 100644
index 0000000000000000000000000000000000000000..708e989a8e2abe3f01c46ec273e2dba0c461734f
--- /dev/null
+++ b/transfer_datasets/traffic_sign.py
@@ -0,0 +1,65 @@
+import os
+import copy
+import json
+import operator
+import numpy as np
+from PIL import Image
+from glob import glob
+from os.path import join
+from itertools import chain
+from scipy.io import loadmat
+from collections import defaultdict
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms
+
+DATA_ROOTS = 'data/TrafficSign'
+
+# wget https://sid.erda.dk/public/archives/ff17dc924eba88d5d01a807357d6614c/FullIJCNN2013.zip
+# unzip FullIJCNN2013.zip
+
+class TrafficSign(data.Dataset):
+    NUM_CLASSES = 43
+
+    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+        super().__init__()
+        self.root = root
+        self.train = train
+        self.image_transforms = image_transforms
+        paths, labels = self.load_images()
+        self.paths, self.labels = paths, labels
+
+    def load_images(self):
+        split = 'Final_Training'  # only used by the commented-out archive layout below
+        rs = np.random.RandomState(42)
+        all_filepaths, all_labels = [], []
+        for class_i in range(self.NUM_CLASSES):
+            # class_dir_i = join(self.root, split, 'Images', '{:05d}'.format(class_i))
+            class_dir_i = join(self.root, '{:02d}'.format(class_i))
+            image_paths = glob(join(class_dir_i, "*.ppm"))
+            # train test splitting
+            image_paths = np.array(image_paths)
+            num = len(image_paths)
+            indexer = np.arange(num)
+            rs.shuffle(indexer)
+            image_paths = image_paths[indexer].tolist()
+            if self.train:
+                image_paths = image_paths[:int(0.8 * num)]
+            else:
+                image_paths = image_paths[int(0.8 * num):]
+            labels = [class_i] * len(image_paths)
+            all_filepaths.extend(image_paths)
+            all_labels.extend(labels)
+
+        return all_filepaths, all_labels
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        label = self.labels[index]
+        image = Image.open(path).convert(mode='RGB')
+        if self.image_transforms:
+            image = self.image_transforms(image)
+        return image, label
\ No newline at end of file
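`TrafficSign` has no held-out test list in this layout, so it (like `VGGFlower` below) carves out a per-class 80/20 split with a fixed-seed shuffle; because `train=True` and `train=False` both re-create `RandomState(42)` and visit classes in the same order, the two slices are disjoint and reproducible across runs. The pattern in isolation (file names are illustrative):

```python
import numpy as np

paths = ['img_%03d.ppm' % i for i in range(10)]  # stand-in file list for one class
rs = np.random.RandomState(42)                   # fixed seed => identical split every run
indexer = np.arange(len(paths))
rs.shuffle(indexer)
paths = [paths[i] for i in indexer]
cut = int(0.8 * len(paths))
train_paths, test_paths = paths[:cut], paths[cut:]  # disjoint by construction
```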
diff --git a/transfer_datasets/vgg_flower.py b/transfer_datasets/vgg_flower.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fe055a354ab094b943f0ac46494dd9b15d4267d
--- /dev/null
+++ b/transfer_datasets/vgg_flower.py
@@ -0,0 +1,73 @@
+import os
+import copy
+import json
+import operator
+import numpy as np
+from PIL import Image
+from os.path import join
+from itertools import chain
+from scipy.io import loadmat
+from collections import defaultdict
+
+import torch
+import torch.utils.data as data
+from torchvision import transforms
+
+DATA_ROOTS = 'data/VGGFlower'
+
+# wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
+# tar -xvzf 102flowers.tgz
+# rename the extracted folder to VGGFlower
+# cd VGGFlower
+# wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat
+
+
+class VGGFlower(data.Dataset):
+    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
+        super().__init__()
+        self.root = root
+        self.train = train
+        self.image_transforms = image_transforms
+        paths, labels = self.load_images()
+        self.paths, self.labels = paths, labels
+
+    def load_images(self):
+        rs = np.random.RandomState(42)
+        imagelabels_path = os.path.join(self.root, 'imagelabels.mat')
+        with open(imagelabels_path, 'rb') as f:
+            labels = loadmat(f)['labels'][0]
+
+        all_filepaths = defaultdict(list)
+        for i, label in enumerate(labels):
+            # all_filepaths[label].append(os.path.join(self.root, 'jpg', 'image_{:05d}.jpg'.format(i+1)))
+            all_filepaths[label].append(os.path.join(self.root, 'image_{:05d}.jpg'.format(i + 1)))
+        # train test split
+        split_filepaths, split_labels = [], []
+        for label, paths in all_filepaths.items():
+            num = len(paths)
+            paths = np.array(paths)
+            indexer = np.arange(num)
+            rs.shuffle(indexer)
+            paths = paths[indexer].tolist()
+
+            if self.train:
+                paths = paths[:int(0.8 * num)]
+            else:
+                paths = paths[int(0.8 * num):]
+
+            labels = [label] * len(paths)
+            split_filepaths.extend(paths)
+            split_labels.extend(labels)
+
+        return split_filepaths, split_labels
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        label = int(self.labels[index]) - 1  # labels in imagelabels.mat are 1-indexed
+        image = Image.open(path).convert(mode='RGB')
+        if self.image_transforms:
+            image = self.image_transforms(image)
+        return image, label
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32570c63b3c9fd39631390c8c4637303f44872c1
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,102 @@
+from PIL import Image
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+from augmentations.augmentations_cifar import aug_cifar
+from augmentations.augmentations_tiny import aug_tiny
+from augmentations.augmentations_stl import aug_stl
+
+# for cifar10 / cifar100 (32x32)
+class CifarPairTransform:
+    def __init__(self, train_transform=True, pair_transform=True):
+        if train_transform is True:
+            self.transform = transforms.Compose([
+                transforms.RandomResizedCrop(32),
+                transforms.RandomHorizontalFlip(p=0.5),
+                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
+                transforms.RandomGrayscale(p=0.2),
+                transforms.ToTensor(),
+                transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
+        else:
+            self.transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
+        self.pair_transform = pair_transform
+
+    def __call__(self, x):
+        if self.pair_transform is True:
+            y1 = self.transform(x)
+            y2 = self.transform(x)
+            return y1, y2
+        else:
+            return self.transform(x)
+
+# for tiny_imagenet (64x64)
+class TinyImageNetPairTransform:
+    def __init__(self, train_transform=True, pair_transform=True):
+        if train_transform is True:
+            self.transform = transforms.Compose([
+                transforms.RandomApply(
+                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
+                                            saturation=0.4, hue=0.1)],
+                    p=0.8
+                ),
+                transforms.RandomGrayscale(p=0.1),
+                transforms.RandomResizedCrop(
+                    64,
+                    scale=(0.2, 1.0),
+                    ratio=(0.75, (4 / 3)),
+                    interpolation=Image.BICUBIC,
+                ),
+                transforms.RandomHorizontalFlip(p=0.5),
+                transforms.ToTensor(),
+                transforms.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))
+            ])
+        else:
+            self.transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))
+            ])
+        self.pair_transform = pair_transform
+
+    def __call__(self, x):
+        if self.pair_transform is True:
+            y1 = self.transform(x)
+            y2 = self.transform(x)
+            return y1, y2
+        else:
+            return self.transform(x)
+
+# for stl10 (96x96 inputs; note that views are produced at 64x64)
+class StlPairTransform:
+    def __init__(self, train_transform=True, pair_transform=True):
+        if train_transform is True:
+            self.transform = transforms.Compose([
+                transforms.RandomApply(
+                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
+                                            saturation=0.4, hue=0.1)],
+                    p=0.8
+                ),
+                transforms.RandomGrayscale(p=0.1),
+                transforms.RandomResizedCrop(
+                    64,
+                    scale=(0.2, 1.0),
+                    ratio=(0.75, (4 / 3)),
+                    interpolation=Image.BICUBIC,
+                ),
+                transforms.RandomHorizontalFlip(p=0.5),
+                transforms.ToTensor(),
+                transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))
+            ])
+        else:
+            self.transform = transforms.Compose([
+                transforms.Resize(70, interpolation=Image.BICUBIC),
+                transforms.CenterCrop(64),
+                transforms.ToTensor(),
+                transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))
+            ])
+        self.pair_transform = pair_transform
+
+    def __call__(self, x):
+        if self.pair_transform is True:
+            y1 = self.transform(x)
+            y2 = self.transform(x)
+            return y1, y2
+        else:
+            return self.transform(x)
\ No newline at end of file
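Finally, a minimal sketch of how these pair transforms are typically consumed; the data root and batch size are assumptions, and it should be run from the repo root so `utils` and its `augmentations` imports resolve. `CifarPairTransform` returns two independently augmented views of each image, which the default collate function stacks into two batched tensors:

```python
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from utils import CifarPairTransform

train_data = CIFAR10(root='data', train=True, download=True,
                     transform=CifarPairTransform(train_transform=True))
loader = DataLoader(train_data, batch_size=256, shuffle=True, drop_last=True)
(y1, y2), targets = next(iter(loader))
print(y1.shape, y2.shape)  # two views: torch.Size([256, 3, 32, 32]) each
```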