# model_soups/finetune.py
import argparse
import os
import time

import torch
from torchvision import transforms, datasets
from tqdm import tqdm

import clip

from utils import ModelWrapper, maybe_dictionarize_batch, cosine_lr
from zeroshot import zeroshot_classifier


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser('~/data'),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--model-location",
        type=str,
        default=os.path.expanduser('~/ssd/checkpoints/soups'),
        help="Where to save the model checkpoints.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--warmup-length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-5,
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--model",
        default='ViT-B/32',
        help='CLIP model to use -- you can try another, e.g. ViT-L/14.'
    )
    parser.add_argument(
        "--name",
        default='finetune_cp',
        help='Filename for the checkpoints.'
    )
    parser.add_argument(
        "--timm-aug", action="store_true", default=False,
    )
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        help='Checkpoint path from which to load the model.'
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    DEVICE = 'cuda'

    # Prompt template used to build the zero-shot classification head;
    # the class name is substituted for {x} (e.g. "a photo generated by AI.").
    template = [lambda x: f"a photo generated by {x}."]

    base_model, preprocess = clip.load(args.model, DEVICE, jit=False)

    train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                           transforms.RandomResizedCrop(224),
                                           transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor()])
    # Evaluation uses a deterministic resize + center-crop (no random augmentation).
    test_transforms = transforms.Compose([transforms.Resize(224),
                                          transforms.CenterCrop(224),
                                          transforms.ToTensor()])
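    # Note: the `preprocess` pipeline returned by clip.load() also normalizes images
    # with CLIP's published mean/std; the transforms above skip that step, so inputs
    # are scaled differently than during CLIP pre-training. A minimal sketch of
    # adding it back, if desired:
    #
    #   clip_normalize = transforms.Normalize(
    #       mean=(0.48145466, 0.4578275, 0.40821073),
    #       std=(0.26862954, 0.26130258, 0.27577711))
    #   train_transforms.transforms.append(clip_normalize)
    #   test_transforms.transforms.append(clip_normalize)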
    train_data = datasets.ImageFolder(args.data_location + '/train', transform=train_transforms)
    test_data = datasets.ImageFolder(args.data_location + '/test', transform=test_transforms)
    train_dset = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True)
    test_dset = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, num_workers=args.workers)
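
    # Zero-shot initialization: zeroshot_classifier (see zeroshot.py) embeds each
    # class name ('humans', 'AI') with the prompt template through CLIP's text
    # encoder and returns weights that seed the linear classification head.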
    clf = zeroshot_classifier(base_model, ['humans', 'AI'], template, DEVICE)
    NUM_CLASSES = 2
    feature_dim = base_model.visual.output_dim  # 512 for ViT-B/32

    # ModelWrapper puts a linear classification head (seeded with the zero-shot
    # weights) on top of CLIP's image encoder and, if given, restores weights
    # from --checkpoint_path.
    model = ModelWrapper(base_model, feature_dim, NUM_CLASSES, normalize=True,
                         initial_weights=clf, checkpoint_path=args.checkpoint_path)
    for p in model.parameters():
        p.data = p.data.float()  # CLIP loads in fp16 on CUDA; fine-tune in fp32
    model = model.cuda()

    # Split each batch across all visible GPUs.
    devices = list(range(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model, device_ids=devices)

    model_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(model_parameters, lr=args.lr, weight_decay=args.wd)

    num_batches = len(train_dset)
    # Linear warmup for --warmup-length steps, then cosine decay (see utils.cosine_lr).
    scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Save the initial (pre-fine-tuning) weights.
    model_path = os.path.join(args.model_location, f'{args.name}.pt')
    print('Saving model to', model_path)
    torch.save(model.module.state_dict(), model_path)

    last_accuracy = 0.0
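
    # Fine-tune for --epochs epochs; after each epoch, evaluate on the test split
    # and keep only the checkpoint with the best accuracy so far.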
    for epoch in range(args.epochs):
        # Train
        model.train()
        end = time.time()
        for i, batch in enumerate(train_dset):
            step = i + epoch * num_batches
            scheduler(step)
            optimizer.zero_grad()

            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
            data_time = time.time() - end

            logits = model(inputs)
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            if i % 20 == 0:
                percent_complete = 100.0 * i / len(train_dset)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(train_dset)}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
                )

        # Evaluate
        test_loader = test_dset
        model.eval()
        with torch.no_grad():
            print('*' * 80)
            print('Starting eval')
            correct, count = 0.0, 0.0
            pbar = tqdm(test_loader)
            for batch in pbar:
                batch = maybe_dictionarize_batch(batch)
                inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
                logits = model(inputs)
                loss = loss_fn(logits, labels)
                pred = logits.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()
                count += len(logits)
                pbar.set_description(
                    f"Val loss: {loss.item():.4f} Acc: {100 * correct / count:.2f}")
            top1 = correct / count
            print(f'Val acc at epoch {epoch}: {100 * top1:.2f}')

        # Keep only the best checkpoint seen so far.
        curr_acc = 100 * top1
        if curr_acc > last_accuracy:
            print(f'Current acc: {curr_acc}, Last acc: {last_accuracy}')
            last_accuracy = curr_acc
            model_path = os.path.join(args.model_location, f'{args.name}.pt')
            print('Saving model to', model_path)
            torch.save(model.module.state_dict(), model_path)
        else:
            print('Not saving the model')
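
# Example usage (hypothetical paths/values):
#   python finetune.py --data-location ~/data/ai_vs_human \
#       --model-location ~/ssd/checkpoints/soups --model ViT-B/32 \
#       --epochs 8 --batch-size 256 --name finetune_cp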