import argparse |
import os |
import random |
import shutil |
import time |
import warnings |
import torch |
import torch.nn as nn |
import torch.nn.parallel |
import torch.backends.cudnn as cudnn |
import torch.distributed as dist |
import torch.optim |
import torch.multiprocessing as mp |
import torch.utils.data |
import torch.utils.data.distributed |
import torchvision.transforms as transforms |
import torchvision.datasets as datasets |
import torchvision.models as models |
from ViT.ViT import vit_base_patch16_224 as vit |
from robustness_dataset import RobustnessDataset |
from objectnet_dataset import ObjectNetDataset |
# Sorted names of every public, lower-case callable exposed by
# ``torchvision.models``, followed by the locally provided "vit" entry.
model_names = sorted(
    entry_name
    for entry_name, entry in models.__dict__.items()
    if entry_name.islower()
    and not entry_name.startswith("__")
    and callable(entry)
)
model_names.append("vit")
# Command-line interface of the script.  Options are parsed in ``main`` and the
# resulting namespace is handed to ``main_worker`` / ``validate``.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=150, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
# The help text previously advertised 1e-4 although the default is 0.05.
parser.add_argument('--wd', '--weight-decay', default=0.05, type=float,
                    metavar='W', help='weight decay (default: 0.05)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to resume checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
# NOTE(review): the default URL carries no host/port; distributed runs
# presumably need an explicit --dist-url -- confirm.
parser.add_argument('--dist-url', default='tcp://', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
# Dataset selectors: they decide which dataset class/mode main_worker builds.
parser.add_argument("--isV2", default=False, action='store_true',
                    help='is dataset imagenet V2.')
parser.add_argument("--isSI", default=False, action='store_true',
                    help='is dataset SI-score.')
# The help text was copy-pasted from --isSI; it now names the right dataset.
parser.add_argument("--isObjectNet", default=False, action='store_true',
                    help='is dataset ObjectNet.')
def main():
    """Parse the command line and start one or several worker processes.

    Seeds the RNGs when ``--seed`` is given, resolves the world size from the
    ``WORLD_SIZE`` environment variable for ``env://`` rendezvous, and then
    either spawns one ``main_worker`` per visible GPU or calls it directly.
    """
    args = parser.parse_args()

    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.world_size == -1 and args.dist_url == "env://":
        args.world_size = int(os.environ["WORLD_SIZE"])

    spawn_per_gpu = args.multiprocessing_distributed
    args.distributed = spawn_per_gpu or args.world_size > 1
    ngpus_per_node = torch.cuda.device_count()

    if not spawn_per_gpu:
        # Single process: run the worker in-place on the requested GPU (if any).
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # One process per GPU, so the total world size scales with the GPU count.
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
def main_worker(gpu, ngpus_per_node, args):
    """Build the ViT model, place it on the available device(s) and evaluate it.

    Args:
        gpu: GPU index for this process, or ``None`` when no specific GPU was
            requested.  Stored into ``args.gpu``.
        ngpus_per_node: Number of GPUs on this node; used to derive the global
            rank and to split batch size / workers between processes.
        args: Parsed command-line namespace.  ``args.gpu``, ``args.rank``,
            ``args.batch_size``, ``args.workers`` and ``args.start_epoch`` may
            be overwritten here.

    The function only evaluates: when ``--evaluate`` is not given it returns
    after building the model and the data loader.
    """
    global best_acc1
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Convert the node rank into a global rank among all processes.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    cuda_available = torch.cuda.is_available()

    print("=> creating model")
    if args.checkpoint:
        model = vit()
        # Load onto the CPU first so the checkpoint can be read regardless of
        # the device it was saved from; load_state_dict copies the values into
        # the model's own parameters.
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model = vit(pretrained=True)
    if cuda_available:
        # Only touch CUDA when it exists; an unconditional .cuda() made the
        # CPU fallback below unreachable.
        model = model.cuda()
    print("done")

    if not cuda_available:
        print('using CPU, this will be slow')
    elif args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # One process per GPU: split the per-node batch size and the
            # data-loading workers between the processes of this node.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        print("start")
        model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if not cuda_available:
                loc = 'cpu'
            elif args.gpu is not None:
                loc = 'cuda:{}'.format(args.gpu)
            else:
                loc = None  # same as torch.load's default behaviour
            checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be a tensor saved from another device or a plain
            # number; only tensors can (and need to) be moved.
            if (cuda_available and args.gpu is not None
                    and isinstance(best_acc1, torch.Tensor)):
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.isObjectNet:
        val_dataset = ObjectNetDataset(args.data)
    else:
        val_dataset = RobustnessDataset(args.data, isV2=args.isV2, isSI=args.isSI)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, args)
        return
def validate(val_loader, model, args):
    """Run ``model`` over ``val_loader`` and report top-1 / top-5 accuracy.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches; must
            support ``len()``.
        model: Module to evaluate; it is switched to eval mode.
        args: Namespace providing ``gpu`` (device index or ``None``) and
            ``print_freq`` (progress is printed every ``print_freq`` batches).

    Returns:
        The running average of the top-1 accuracy (``top1.avg``).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    # No loss is computed during evaluation, so no loss meter is displayed;
    # the former one was never updated and always printed a misleading zero.
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top5],
        prefix='Test: ')

    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            output = model(images)

            # Weight each batch by its size so the averages are per-sample.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # Wall-clock time spent on this batch (loading + forward pass).
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):
    """Serialize ``state`` to disk and optionally keep a copy as the best model.

    Args:
        state: Object to serialize with ``torch.save``.
        is_best: When true, the saved file is also copied to ``best_filename``.
        filename: Destination path of the checkpoint.
        best_filename: Destination path of the "best" copy.  Defaults to
            ``'model_best.pth.tar'`` in the current working directory, which
            is the path that used to be hard-coded.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    ``name`` labels the metric and ``fmt`` is a format-spec suffix (for
    example ``':6.2f'``) applied to both the current value and the average
    when the meter is converted to a string.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear every accumulated statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        When the total count is still zero (for example ``n=0`` on a fresh
        meter) the average is left untouched instead of dividing by zero.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:6.2f} ({avg:6.2f})' and fills it from the
        # instance attributes.
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Print one tab-separated progress line: a batch counter plus each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter."""
        columns = [self.prefix + self.batch_fmtstr.format(batch)]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for the batch counter."""
        # Pad the running index to the width of the total so columns line up.
        width = len(str(num_batches // 1))
        counter_fmt = '{{:{}d}}'.format(width)
        return '[{}/{}]'.format(counter_fmt, counter_fmt.format(num_batches))
def adjust_learning_rate(optimizer, epoch, args, decay_rate=0.85, decay_every=2):
    """Set a step-decayed learning rate on every parameter group.

    The rate is ``args.lr * decay_rate ** (epoch // decay_every)``; with the
    defaults the initial rate is multiplied by 0.85 every 2 epochs.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        epoch: Current epoch number.
        args: Namespace providing the initial learning rate as ``args.lr``.
        decay_rate: Multiplicative factor applied at each decay step.
        decay_every: Number of epochs between two decay steps; must be
            positive.

    Raises:
        ValueError: If ``decay_every`` is not positive.
    """
    if decay_every <= 0:
        raise ValueError("decay_every must be positive")
    lr = args.lr * (decay_rate ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k listed in ``topk``.

    ``output`` holds per-sample scores with classes along dimension 1 and
    ``target`` the matching class indices.  The result is a list with one
    single-element tensor per requested k, in the order given by ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the best ``largest_k`` classes, transposed so that row j
        # contains every sample's (j+1)-th best guess.
        top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        truth = target.view(1, -1).expand_as(top_indices)
        hits = top_indices.eq(truth)

        percent_per_hit = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()