# robustvlm-object-centric / train / training_clip_get_features.py
# Header text captured from the hosting page: "xmutly's picture", "Upload 294 files", commit e1aaaac (verified).
import sys
from train.datasets import COCOFlickrDataset, ImageNetDataset
from CLIP_eval.eval_utils import load_clip_model
sys.path.append("open_flamingo")
import os
import shutil
import time
import string
import random
import numpy as np
import open_clip
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from training.scheduler import cosine_lr
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
from train.pgd_train import pgd
from train.apgd_train import apgd_train as apgd
import wandb
from train.utils import init_wandb, AverageMeter
from train.sam_data import SamData
from open_flamingo.eval.models.utils import unwrap_model
from train.utils import str2bool
from slots.DINOSAUR import DINOSAURpp
import matplotlib.pyplot as plt
from einops import rearrange, repeat
from tqdm import tqdm
import argparse
def _build_parser():
    """Create the command-line parser for this script.

    Options are registered in a fixed order. Several of them (loss, attack and
    checkpoint options) do not appear to be used by the feature-extraction path
    in this file; they are still parsed and logged. ``--eps`` and
    ``--stepsize_adv`` are on the 0-255 pixel scale and are divided by 255 in the
    ``__main__`` block.
    """
    p = argparse.ArgumentParser()
    add = p.add_argument

    # model / data
    add('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
    add('--pretrained', type=str, default='openai')
    add('--dataset', type=str, default='imagenet')
    add('--template', type=str, default='std')
    add('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
    add('--output_normalize', type=str2bool, default=False, help='Whether the embedding is normalized')

    # schedule / optimisation
    add('--start_step', type=int, default=0, help='Start step for training')
    add('--optimizer_state', type=str, default='', help='Optimizer state file path')
    add('--steps', type=int, default=20000, help='Number of training steps')
    add('--warmup', type=int, default=14000, help='Warmup steps')
    add('--batch_size', type=int, default=256)
    add('--loss', type=str, default='l2', help='ce, l2')
    add('--loss_clean', type=str, default='none', help='ce, l2')
    add('--clean_weight', type=float, default=0., help='Weight for clean loss')
    add('--trades', type=str2bool, default=False, help='Use TRADES')
    add('--opt', type=str, default='adamw', help='Optimizer type; sgd, adamw')
    add('--momentum_sgd', type=float, default=0.9, help='Momentum for SGD optimizer')
    add('--lr', type=float, default=1e-5, help='Learning rate')
    add('--wd', type=float, default=1e-4, help='Weight decay')

    # adversarial attack (eps / stepsize on the 0-255 scale)
    add('--attack', type=str, default='apgd', help='Adversarial attack type')
    add('--inner_loss', type=str, default='l2', help='Inner loss function for adversarial training')
    add('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
    add('--eps', type=float, default=4, help='Epsilon for adversarial perturbation')
    add('--iterations_adv', type=int, default=10, help='Iterations for adversarial attack')
    add('--stepsize_adv', type=float, default=1., help='Step size for adversarial attack (no effect for apgd)')

    # logging / output / devices
    add('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
    add('--experiment_name', type=str, default='')
    add('--overwrite', type=str2bool, default=False, help='Overwrite existing directory')
    add('--log_freq', type=int, default=1, help='Logging frequency')
    add('--eval_freq', type=int, default=50, help='Evaluation frequency')
    add('--output_dir', type=str, default='', help='Output directory')
    add('--save_checkpoints', type=str2bool, default=True, help='Save 10 training checkpoints')
    add('--devices', type=str, default='', help='Device IDs for CUDA')
    return p


parser = _build_parser()
def _build_train_dataset(args, transform):
    """Return the training dataset selected by ``args.dataset``.

    Supported values: 'imagenet', 'segment_anything', 'coco'.

    Raises:
        ValueError: if the COCO data cannot be found or ``args.dataset`` is unknown.
    """
    if args.dataset == 'imagenet':
        return ImageNetDataset(root=args.imagenet_root + '/train', transform=transform)
    if args.dataset == 'segment_anything':
        dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=transform)
        print(len(dataset))
        return dataset
    if args.dataset == 'coco':
        # Probe the known cluster locations in order.
        if os.path.exists('/mnt/datasets/coco'):
            image_dir_path = '/mnt/datasets/coco/train2017'
            annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
        elif os.path.exists('/mnt/lustre'):
            image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
            annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
        else:
            raise ValueError('COCO dataset not found')
        return COCOFlickrDataset(
            image_dir_path=image_dir_path,
            annotations_path=annotations_path,
            transform=transform,
        )
    # Previously an unknown dataset fell through and crashed later with an unbound local.
    raise ValueError(f'Unknown dataset: {args.dataset}')


def _imagenet_text_embeddings(model, args, device):
    """Return unit-norm CLIP text embeddings of the 1000 ImageNet class prompts.

    The result has shape (embed_dim, 1000), with one column per class. The
    shape is asserted for the supported model names. ``model`` is moved to
    ``device`` as a side effect.

    Raises:
        ValueError: for an unknown ``args.template`` or ``args.clip_model_name``.
    """
    if args.template == 'std':
        template = 'This is a photo of a {}'
    elif args.template == 'blurry':
        template = 'This is a blurry photo of a {}'
    else:
        raise ValueError(f'Unknown template: {args.template}')
    print(f'template: {template}')
    texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
    text_tokens = open_clip.tokenize(texts)
    model.to(device)
    with torch.no_grad():
        chunks = []
        # Encode in two halves to avoid running out of memory. This calls the raw CLIP
        # model (not the ClipVisionModel wrapper), so normalize=True always applies.
        for el in (text_tokens[:500], text_tokens[500:]):
            chunks.append(model.encode_text(el.to(device), normalize=True).detach().cpu())
        embeddings = torch.cat(chunks).T.to(device)
    assert torch.allclose(F.normalize(embeddings, dim=0), embeddings)
    if args.clip_model_name == 'ViT-B-32':
        assert embeddings.shape == (512, 1000), embeddings.shape
    elif args.clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
        assert embeddings.shape == (768, 1000), embeddings.shape
    else:
        raise ValueError(f'Unknown model: {args.clip_model_name}')
    return embeddings


def _build_optimizer(args, params):
    """Create the optimizer named by ``args.opt`` and optionally restore its state.

    Raises:
        ValueError: if ``args.opt`` is neither 'adamw' nor 'sgd'.
    """
    if args.opt == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum_sgd,
            weight_decay=args.wd,
        )
    else:
        # The option is stored as ``args.opt``. ``args.optimizer`` does not exist and
        # would raise AttributeError instead of this ValueError.
        raise ValueError(f'Optimizer {args.opt} not supported.')
    if args.optimizer_state != '':
        optimizer.load_state_dict(torch.load(args.optimizer_state))
    return optimizer


def main(args):
    """Set up CLIP and the slot-attention model, then run one feature-extraction pass.

    The setup mirrors a fine-tuning script (wandb, output dir, optimizer,
    scheduler). The only pass performed is ``train_one_epoch_slots``.

    Args:
        args: parsed command-line namespace. ``finetuned_model_name`` and
            ``output_dir`` must already be set; ``total_epochs`` is written back.

    Raises:
        ValueError: for an unknown dataset, template, model name or optimizer.
    """
    # setup wandb
    if args.wandb:
        init_wandb(
            project_name='clip-finetune',
            model_name=args.finetuned_model_name,
            config=vars(args)
        )
    else:
        wandb.init(mode='disabled')

    # print args
    print(f"Arguments:\n{'-' * 20}")
    for arg, value in vars(args).items():
        print(f"{arg}: {value}")
    print(f"{'-' * 20}")

    # setup dirs; exist_ok=False so an existing run dir is never silently reused
    if args.overwrite:
        shutil.rmtree(args.output_dir, ignore_errors=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=False)
    with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
        f.write(str(args))

    main_device = 0
    # Count devices locally instead of reading the module-level ``num_gpus``,
    # which only exists when the file is run as a script.
    n_gpus = torch.cuda.device_count()

    # get models
    from open_clip.model import CLIPVisionCfg
    # NOTE(review): ClipVisionModel.forward expects the visual tower to return
    # (pooled, tokens). This sets a class attribute on the config for that purpose.
    # Whether it takes effect depends on the installed open_clip; confirm. The
    # original also hinted at passing output_tokens=True to the factory instead.
    CLIPVisionCfg.output_tokens = True
    model_orig, _, image_processor = open_clip.create_model_and_transforms(
        args.clip_model_name, pretrained=args.pretrained
    )
    # Drop the final (Normalize) transform. ClipVisionModel applies it itself, so
    # the loader yields un-normalized images.
    preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
    normalize = image_processor.transforms[-1]
    del image_processor
    print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')

    # get slot-attention model
    cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
    model_slots = DINOSAURpp(cfg_dict)

    # get data
    dataset = _build_train_dataset(args, preprocessor_without_normalize)
    dataset_eval = ImageNetDataset(
        root=args.imagenet_root + '/val',
        transform=preprocessor_without_normalize,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)

    # text label embeddings of all ImageNet classes
    embedding_text_labels_norm = _imagenet_text_embeddings(model_orig, args, main_device)

    model_orig.cpu()
    model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
    if n_gpus > 1:
        model_orig = torch.nn.DataParallel(model_orig)
    model_orig.cuda()
    if n_gpus > 1:
        model_slots = torch.nn.DataParallel(model_slots)
    model_slots.cuda()

    # optimizer over the slot model's parameters, plus LR schedule
    optimizer = _build_optimizer(args, unwrap_model(model_slots).parameters())
    scheduler = cosine_lr(optimizer, args.lr, args.warmup, args.steps)

    total_epochs = args.steps / len(dataloader)
    print(f'train for {total_epochs} epochs')
    args.total_epochs = total_epochs

    epoch = 0
    train_one_epoch_slots(
        args.start_step,
        model_slots=model_slots,
        model_orig=model_orig,
        dataloader=dataloader,
        dataloader_eval=dataloader_eval,
        optimizer=optimizer,
        scheduler=scheduler,
        embedding_text_labels_norm=embedding_text_labels_norm,
        normalize=normalize,
        args=args,
        epoch=epoch
    )
    print(f'Epoch {epoch} done.')
class ClipVisionModel(torch.nn.Module):
    """Wrap a vision tower so it accepts un-normalized images.

    Args:
        model: callable that maps a normalized image batch to a pair
            ``(embedding, patches)``.
        args: stored as ``self.args``; not read by ``forward``.
        normalize: callable applied to the input before ``model``.
    """

    def __init__(self, model, args, normalize):
        super().__init__()
        self.model = model
        self.args = args
        self.normalize = normalize

    def forward(self, vision, output_normalize):
        """Return ``(embedding, patches)``.

        ``embedding`` is L2-normalized along its last dimension when
        ``output_normalize`` is truthy. ``patches`` is passed through unchanged.
        """
        pooled, tokens = self.model(self.normalize(vision))
        return (F.normalize(pooled, dim=-1) if output_normalize else pooled), tokens
class ComputeLossWrapper:
    """Callable that binds the fixed arguments of ``compute_loss``.

    Calling the instance with ``(embedding, targets)`` forwards to
    ``compute_loss`` with the stored reference embedding, text-label matrix,
    loss name, logit scale and reduction.
    """

    def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
                 logit_scale=100.):
        self.embedding_orig = embedding_orig
        self.embedding_text_labels_norm = embedding_text_labels_norm
        self.reduction = reduction
        self.loss_str = loss
        self.logit_scale = logit_scale

    def __call__(self, embedding, targets):
        bound_kwargs = dict(
            loss_str=self.loss_str,
            embedding_orig=self.embedding_orig,
            logit_scale=self.logit_scale,
            embedding_text_labels_norm=self.embedding_text_labels_norm,
            reduction=self.reduction,
        )
        return compute_loss(embedding=embedding, targets=targets, **bound_kwargs)
def train_one_epoch_slots(
    step_total, model_slots, model_orig, dataloader, optimizer, scheduler, normalize,
    embedding_text_labels_norm, args, epoch, dataloader_eval=None,
    features_root='/home/tly/RobustVLM/datasets/imagenet_features',
):
    """Extract CLIP patch features for every sample in ``dataloader`` and save them.

    Despite its name, this performs no optimization. ``optimizer``,
    ``scheduler``, ``normalize``, ``embedding_text_labels_norm``, ``epoch`` and
    ``dataloader_eval`` are accepted for signature compatibility but unused.

    For each batch:
      - the frozen ``model_orig`` returns ``(embedding, patches)``;
      - ``ln_pre`` of the wrapped visual tower is applied to ``patches``;
      - each sample is written to ``<features_root>/<label>/class{label}_batch{i}_sample{j}.npy``;
      - the same array is also written with ``np.savez_compressed`` under key ``x``.

    Args:
        step_total: returned unchanged.
        model_slots: only switched to eval mode here.
        model_orig: a ClipVisionModel, optionally wrapped in ``torch.nn.DataParallel``.
        dataloader: yields ``(images, targets)``. ``targets`` must be a tensor of
            class labels, which are used in the output paths.
        args: namespace; only ``args.output_normalize`` is read.
        features_root: directory under which per-label subdirectories are created.

    Returns:
        ``step_total``, unchanged.

    Raises:
        ValueError: if a batch's targets are not a tensor (e.g. COCO captions).
    """
    model_orig.eval()
    model_slots.eval()
    # Reach ln_pre whether or not model_orig is wrapped in DataParallel. This no
    # longer depends on the module-level ``num_gpus``, which only exists under
    # ``__main__``.
    vision_wrapper = model_orig.module if isinstance(model_orig, torch.nn.DataParallel) else model_orig
    ln_pre = vision_wrapper.model.ln_pre

    for i, (data, targets) in enumerate(tqdm(dataloader)):
        if not isinstance(targets, torch.Tensor):
            raise ValueError(
                'train_one_epoch_slots needs tensor class targets to name the output files; '
                f'got {type(targets).__name__}'
            )
        data = data.cuda()
        with torch.no_grad():
            _, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize)
            patches_orig = ln_pre(patches_orig)
        # One device->host copy per batch instead of one per sample.
        patches_np = patches_orig.detach().cpu().numpy()
        labels_np = targets.detach().cpu().numpy()
        for j, (label, sample) in enumerate(zip(labels_np, patches_np)):
            class_dir = os.path.join(features_root, str(label))
            os.makedirs(class_dir, exist_ok=True)
            file_path = os.path.join(class_dir, f'class{label}_batch{i}_sample{j}.npy')
            np.save(file_path, sample)
            # Also save the array in compressed form; numpy appends '.npz' to the name.
            # NOTE(review): every sample is thus written twice (.npy and .npy.npz). Keep
            # only the format downstream readers actually load.
            np.savez_compressed(file_path, x=sample)
        torch.cuda.empty_cache()
    return step_total
@torch.no_grad()
def compute_acc(logits, targets):
    """Return top-1 accuracy, in percent, of ``logits`` (rows = samples) against ``targets``."""
    predicted = logits.max(dim=1).indices.detach()
    n_correct = predicted.eq(targets).sum()
    return (n_correct / targets.shape[0]).item() * 100
def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
                 embedding_text_labels_norm=None, reduction='mean'):
    """Dispatch to the loss named by ``loss_str``.

    Supported values:
      - ``'l2'``: squared L2 distance between ``embedding`` and ``embedding_orig``.
      - ``'ce'``: cross-entropy of ``embedding @ (logit_scale * embedding_text_labels_norm)``
        against ``targets``.

    Raises:
        ValueError: for any other ``loss_str``.
    """
    if loss_str == 'ce':
        logits = embedding @ (logit_scale * embedding_text_labels_norm)
        return ce(out=logits, targets=targets, reduction=reduction)
    if loss_str == 'l2':
        return l2(out=embedding, targets=embedding_orig, reduction=reduction)
    raise ValueError(f'loss {loss_str} not supported')
def l2(out, targets, reduction='none'):
    """Squared L2 distance per sample. It is not divided by the latent dimension.

    Args:
        out, targets: tensors of identical shape. Dim 0 indexes samples and the
            squared error is summed over dim 1; callers pass (batch_size, embedding_size).
        reduction: ``'mean'`` averages the per-sample distances to a scalar. Any
            other value returns the per-sample vector of shape (batch_size,).

    Raises:
        AssertionError: if the shapes differ or the batch has fewer than 2 samples.
    """
    assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
    assert out.shape[0] > 1
    # Element-wise squared error, summed over the embedding dimension.
    squared_error_batch = F.mse_loss(out, targets, reduction='none').sum(dim=1)
    if reduction == 'mean':
        return torch.mean(squared_error_batch)
    # The per-sample shape check only applies to the unreduced result. Run after
    # the 'mean' reduction it would reject every scalar loss.
    assert squared_error_batch.shape == (out.shape[0],), f'{squared_error_batch.shape} != {(out.shape[0],)}'
    return squared_error_batch
def ce(out, targets, reduction='mean'):
    """Cross-entropy between logits ``out`` and class-index ``targets``.

    Both must have the same leading (batch) size, which must exceed 1.
    ``reduction`` is passed straight to ``F.cross_entropy``.
    """
    batch_size = out.shape[0]
    assert batch_size == targets.shape[0], (out.shape, targets.shape)
    assert batch_size > 1
    return F.cross_entropy(out, targets, reduction=reduction)
if __name__ == '__main__':
    # Fix the torch and numpy RNGs. Python's `random` is not seeded here, so the
    # run-name suffix below differs between runs.
    torch.manual_seed(0)
    np.random.seed(0)

    args = parser.parse_args()
    # --eps / --stepsize_adv are given on the 0-255 pixel scale; rescale to [0, 1].
    args.eps /= 255
    args.stepsize_adv /= 255

    # Options parsed with str2bool must never remain the literal strings 'True'/'False'.
    stray_bool_strings = [v for v in args.__dict__.values() if isinstance(v, str) and v in ['True', 'False']]
    assert not stray_bool_strings, f'args contains a string that should be a bool: {args}'
    assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'

    if args.devices:
        # Restrict visible GPUs before the first torch.cuda query below.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
    # Module-level global, also read by main() and train_one_epoch_slots in this file.
    num_gpus = torch.cuda.device_count()
    print(f'Number of GPUs available: {num_gpus}' if num_gpus > 1 else 'No multiple GPUs available.')

    # Run name: model, weights, dataset, loss, dataset again, experiment name and a
    # random 5-character suffix. '/' is replaced so the name is a single path component.
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
    run_name = f'{args.clip_model_name}_{args.pretrained}_{args.dataset}_{args.loss}_{args.dataset}_{args.experiment_name}_{suffix}'
    args.finetuned_model_name = run_name.replace('/', '_')
    args.output_dir = os.path.join(args.output_dir, args.finetuned_model_name)

    main(args)