Upload adversarial_training_clip_with_object_token.py
train/adversarial_training_clip_with_object_token.py
CHANGED
@@ -108,6 +108,8 @@ def main(args):
         assert str(args.start_step) in args.optimizer_state
         assert args.pretrained in ['', 'none']
         args.pretrained = args.optimizer_state.replace('_opt', '')
+        args.pretrained_proj_head = args.optimizer_state.replace('_opt', '_proj_head')
+
     model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)
 
     # Remove the Normalize transform by creating a new Compose object
@@ -128,6 +130,9 @@ def main(args):
     cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
     model_slots = DINOSAURpp(cfg_dict)
     proj_head = torch.nn.Linear(256, 1024)  # slot-num to slot-num
+    if args.optimizer_state != '':
+        proj_head.load_state_dict(torch.load(args.pretrained_proj_head))
+
 
 
     # get data
@@ -505,13 +510,13 @@ def train_one_epoch(
             wandb.log(log_data)
 
         # save 10 models over the course of training
-        if args.save_checkpoints and (step_total % (args.steps //
+        if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
             # save model and optimizer state_dict
             torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
             torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj_head.pt')
             torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
         # every 200 steps, save a fallback model, which gets overwritten
-        if step_total %
+        if step_total % 2000 == 0:
             torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
             torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj_head.pt')
             torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
@@ -523,7 +528,7 @@ def train_one_epoch(
         if step_total >= args.steps:
             break
 
-        torch.cuda.empty_cache()
+        # torch.cuda.empty_cache()
     return step_total
 
 
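For context on the resume logic introduced above: checkpoints are written as triplets (`step_{n}.pt` for the CLIP weights, `step_{n}_proj_head.pt` for the slot projection head, `step_{n}_opt.pt` for the optimizer state), and the new code derives the projection-head path from `args.optimizer_state` by swapping the `_opt` suffix. A minimal sketch of that naming convention, assuming the same `Linear(256, 1024)` head as in the script; the helper and the example path are illustrative and not part of this commit:

import torch

def checkpoint_paths(optimizer_state_path):
    # Mirror the suffix convention used in the diff:
    #   step_{n}_opt.pt        -> optimizer state
    #   step_{n}.pt            -> CLIP model weights
    #   step_{n}_proj_head.pt  -> slot projection head weights
    model_path = optimizer_state_path.replace('_opt', '')
    proj_head_path = optimizer_state_path.replace('_opt', '_proj_head')
    return model_path, proj_head_path

# Hypothetical resume from step 2000:
opt_path = 'outputs/checkpoints/step_2000_opt.pt'
model_path, proj_head_path = checkpoint_paths(opt_path)

proj_head = torch.nn.Linear(256, 1024)
proj_head.load_state_dict(torch.load(proj_head_path, map_location='cpu'))
# The CLIP weights (model_path) and optimizer state (opt_path) are restored
# analogously via load_clip_model(...) and optimizer.load_state_dict(...).

On the save cadence: with an illustrative `args.steps = 20000`, the new condition `step_total % (args.steps // 10) == 0` fires every 2000 steps, producing the ten numbered checkpoints the comment describes, while the fallback triplet is written on the fixed `step_total % 2000 == 0` schedule.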