![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)

<font size=4><b>Train Wide-ResNet, Shake-Shake and ShakeDrop models on the CIFAR-10
and CIFAR-100 datasets with AutoAugment.</b></font>

The CIFAR-10/CIFAR-100 data can be downloaded from:
https://www.cs.toronto.edu/~kriz/cifar.html. Use the Python version instead of the binary version.

The code replicates the results from Tables 1 and 2 on CIFAR-10/100 with the
following models: Wide-ResNet-28-10, Shake-Shake (26 2x32d), Shake-Shake (26
2x96d) and PyramidNet+ShakeDrop.

<b>Related papers:</b>

AutoAugment: Learning Augmentation Policies from Data

https://arxiv.org/abs/1805.09501

Wide Residual Networks

https://arxiv.org/abs/1605.07146

Shake-Shake regularization

https://arxiv.org/abs/1705.07485

ShakeDrop regularization

https://arxiv.org/abs/1802.02375

<b>Settings:</b>

CIFAR-10 Model         | Learning Rate | Weight Decay | Num. Epochs | Batch Size
---------------------- | ------------- | ------------ | ----------- | ----------
Wide-ResNet-28-10      | 0.1           | 5e-4         | 200         | 128
Shake-Shake (26 2x32d) | 0.01          | 1e-3         | 1800        | 128
Shake-Shake (26 2x96d) | 0.01          | 1e-3         | 1800        | 128
PyramidNet + ShakeDrop | 0.05          | 5e-5         | 1800        | 64
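
The settings above can also be kept as a small lookup table, for instance to estimate how many training steps a run will take. This is a minimal sketch and not part of the repository: the names `CIFAR10_SETTINGS` and `total_steps` are hypothetical, and it assumes the full 50,000-image CIFAR-10 training set with any partial final batch dropped, which may differ from what `train_cifar.py` actually does.

```python
# Hypothetical helper, not part of the repository: mirrors the settings table.
CIFAR10_SETTINGS = {
    "Wide-ResNet-28-10": {
        "learning_rate": 0.1, "weight_decay": 5e-4, "num_epochs": 200, "batch_size": 128},
    "Shake-Shake (26 2x32d)": {
        "learning_rate": 0.01, "weight_decay": 1e-3, "num_epochs": 1800, "batch_size": 128},
    "Shake-Shake (26 2x96d)": {
        "learning_rate": 0.01, "weight_decay": 1e-3, "num_epochs": 1800, "batch_size": 128},
    "PyramidNet + ShakeDrop": {
        "learning_rate": 0.05, "weight_decay": 5e-5, "num_epochs": 1800, "batch_size": 64},
}


def total_steps(model, train_size=50000):
    """Estimate the number of training steps for a model.

    Assumption: one step consumes one full batch and a partial final
    batch is dropped; the real training loop may round differently.
    """
    settings = CIFAR10_SETTINGS[model]
    steps_per_epoch = train_size // settings["batch_size"]
    return steps_per_epoch * settings["num_epochs"]
```

Under these assumptions, `total_steps("Wide-ResNet-28-10")` gives 390 steps per epoch over 200 epochs. Pass a smaller `train_size` if part of the training set is held out for validation.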

<b>Prerequisite:</b>

1.  Install TensorFlow 1.x (TensorFlow 2 is not supported). Run the code with python2, not python3.

2.  Download the CIFAR-10/CIFAR-100 dataset (Python version).

```shell
curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
curl -o cifar-100-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
```
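
The downloaded archives have to be unpacked before training. A minimal sketch using only the standard library is shown below; `extract_archive` is a hypothetical helper, not part of the repository. The CIFAR-10 Python archive unpacks into a `cifar-10-batches-py` directory and the CIFAR-100 one into `cifar-100-python`, so check which layout `train_cifar.py` expects under `data_path` before pointing the flag at it.

```python
import os
import tarfile


def extract_archive(archive_path, data_path):
    """Unpack a .tar.gz archive into data_path.

    Returns the sorted names of the regular files found in the archive.
    Hypothetical helper for illustration only.
    """
    if not os.path.isdir(data_path):
        os.makedirs(data_path)
    with tarfile.open(archive_path, "r:gz") as archive:
        # Collect file names before extracting so the caller can see the layout.
        names = [member.name for member in archive.getmembers() if member.isfile()]
        archive.extractall(data_path)
    return sorted(names)
```

Call it with the path of a downloaded archive and the target directory, for example `/tmp/data`; the returned list shows where the batch files ended up.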

<b>How to run:</b>

```shell
# cd to your workspace.
# Specify the directory where the dataset is located using the data_path flag.
# Note: you can move samples from the training set into the eval set by changing train_size and validation_size.

# For example, to train the Wide-ResNet-28-10 model on a GPU.
python train_cifar.py --model_name=wrn \
                      --checkpoint_dir=/tmp/training \
                      --data_path=/tmp/data \
                      --dataset='cifar10' \
                      --use_cpu=0
```
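
When launching several runs, the same command can be assembled from Python. The sketch below is an illustration rather than part of the repository: `build_command` and `run_training` are hypothetical names, and only the flags shown in the command above are used. The `wrn` and `cifar10` values come from that command; other model or dataset values are not confirmed here and should be checked against `train_cifar.py`.

```python
import subprocess


def build_command(model_name, dataset, data_path, checkpoint_dir,
                  use_cpu=0, python_bin="python"):
    """Build the train_cifar.py argument list from the flags shown above."""
    return [
        python_bin, "train_cifar.py",
        "--model_name={}".format(model_name),
        "--checkpoint_dir={}".format(checkpoint_dir),
        "--data_path={}".format(data_path),
        "--dataset={}".format(dataset),
        "--use_cpu={}".format(use_cpu),
    ]


def run_training(**kwargs):
    """Run one training job and raise CalledProcessError if it fails."""
    subprocess.check_call(build_command(**kwargs))
```

For example, `run_training(model_name="wrn", dataset="cifar10", data_path="/tmp/data", checkpoint_dir="/tmp/training")` reproduces the command above when run from the workspace directory.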

## Contact for Issues

*   Barret Zoph, @barretzoph <barretzoph@google.com>
*   Ekin Dogus Cubuk, <cubuk@google.com>