import argparse
import os
import time

import clip
import torch
from torchvision import datasets, transforms
from tqdm import tqdm

from utils import ModelWrapper, maybe_dictionarize_batch, cosine_lr
from zeroshot import zeroshot_classifier


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser('~/data'),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--model-location",
        type=str,
        default=os.path.expanduser('~/ssd/checkpoints/soups'),
        help="Where to save the model checkpoints.",
    )
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--warmup-length", type=int, default=500)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--wd", type=float, default=0.1)
    parser.add_argument(
        "--model",
        default='ViT-B/32',
        help='Model to use -- you can try another like ViT-L/14.',
    )
    parser.add_argument(
        "--name",
        default='finetune_cp',
        help='Filename for the checkpoints.',
    )
    parser.add_argument(
        "--timm-aug",
        action="store_true",
        default=False,
        help='Use timm-style augmentation (parsed but not used in this script).',
    )
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        help='Checkpoint path from which to load the model.',
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    DEVICE = 'cuda'

    # Prompt template for the zero-shot head: combined with the class names
    # below, it yields "a photo generated by humans." and
    # "a photo generated by AI.".
    template = [lambda x: f"a photo generated by {x}."]

    # CLIP's own `preprocess` is not used; custom transforms are defined below.
    base_model, preprocess = clip.load(args.model, DEVICE, jit=False)

    # Random augmentation for training; deterministic resize/center-crop for
    # evaluation. (The original applied random rotation and random crops at
    # test time as well, which makes the reported accuracy noisy.)
    train_transforms = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    train_data = datasets.ImageFolder(os.path.join(args.data_location, 'train'),
                                      transform=train_transforms)
    test_data = datasets.ImageFolder(os.path.join(args.data_location, 'test'),
                                     transform=test_transforms)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size, num_workers=args.workers)

    # Initialize the classification head from the zero-shot text embeddings.
    clf = zeroshot_classifier(base_model, ['humans', 'AI'], template, DEVICE)
    NUM_CLASSES = 2
    feature_dim = base_model.visual.output_dim

    model = ModelWrapper(base_model, feature_dim, NUM_CLASSES, normalize=True,
                         initial_weights=clf, checkpoint_path=args.checkpoint_path)
    # CLIP loads in fp16 on CUDA; cast all parameters to fp32 for fine-tuning.
    for p in model.parameters():
        p.data = p.data.float()

    model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model, device_ids=devices)

    model_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(model_parameters, lr=args.lr, weight_decay=args.wd)

    num_batches = len(train_loader)
    # Cosine learning-rate schedule with linear warmup, stepped once per batch.
    scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Save the initial (zero-shot-initialized) model before training starts.
    model_path = os.path.join(args.model_location, f'{args.name}.pt')
    print('Saving model to', model_path)
    torch.save(model.module.state_dict(), model_path)

    last_accuracy = 0.0
    for epoch in range(args.epochs):
        # Train
        model.train()
        end = time.time()
        for i, batch in enumerate(train_loader):
            step = i + epoch * num_batches
            scheduler(step)
            optimizer.zero_grad()

            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
            data_time = time.time() - end

            logits = model(inputs)
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            if i % 20 == 0:
                percent_complete = 100.0 * i / num_batches
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{num_batches}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",
                    flush=True
                )

        # Evaluate
        model.eval()
        with torch.no_grad():
            print('*' * 80)
            print('Starting eval')
            correct, count = 0.0, 0.0
            pbar = tqdm(test_loader)
            for batch in pbar:
                batch = maybe_dictionarize_batch(batch)
                inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)

                logits = model(inputs)
                loss = loss_fn(logits, labels)

                pred = logits.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()
                count += len(logits)
                pbar.set_description(
                    f"Val loss: {loss.item():.4f} Acc: {100 * correct / count:.2f}")
            top1 = correct / count

        print(f'Val acc at epoch {epoch}: {100 * top1:.2f}')

        # Keep only the best checkpoint seen so far.
        curr_acc = 100 * top1
        if curr_acc > last_accuracy:
            print(f'Current acc: {curr_acc}, Last acc: {last_accuracy}')
            last_accuracy = curr_acc
            model_path = os.path.join(args.model_location, f'{args.name}.pt')
            print('Saving model to', model_path)
            torch.save(model.module.state_dict(), model_path)
        else:
            print('Not saving the model')
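
# Example invocation (a sketch; the script filename and the dataset layout are
# assumptions -- torchvision's ImageFolder expects one subfolder per class,
# e.g. <data-location>/train/humans/*.jpg and <data-location>/train/AI/*.jpg,
# mirrored under <data-location>/test):
#
#   python finetune.py \
#       --data-location ~/data/ai_detector \
#       --model-location ~/ssd/checkpoints/soups \
#       --model ViT-B/32 --batch-size 256 --epochs 8 --lr 2e-5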