# Code for Peekaboo
# Author: Hasib Zunair
# Modified from https://github.com/valeoai/FOUND
"""Training code for Peekaboo"""

import os
import sys
import json
import argparse

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

from model import PeekabooModel
from evaluation.saliency import evaluate_saliency
from misc import batch_apply_bilateral_solver, set_seed, load_config, Logger
from datasets.datasets import build_dataset
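
# NOTE: model, evaluation, misc and datasets are repo-local modules, so this
# script is presumably meant to be run from the repository root.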


def get_argparser():
    parser = argparse.ArgumentParser(
        description="Training of Peekaboo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--exp-name", type=str, default=None, help="Experiment name.")
    parser.add_argument(
        "--log-dir", type=str, default="outputs", help="Logging and output directory."
    )
    parser.add_argument(
        "--dataset-dir",
        type=str,
        required=True,
        help="Root directory of the training and evaluation datasets.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/peekaboo_DUTS-TR.yaml",
        help="Path to the config file.",
    )
    parser.add_argument(
        "--save-model-freq", type=int, default=250, help="Frequency of model saving."
    )
    parser.add_argument(
        "--visualization-freq",
        type=int,
        default=10,
        help="Frequency of prediction visualization in tensorboard.",
    )
    args = parser.parse_args()
    return args
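

# Example invocation (the script name and dataset root below are placeholders,
# not documented by this file):
#   python train.py --exp-name peekaboo --dataset-dir /path/to/datasets \
#       --config configs/peekaboo_DUTS-TR.yaml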


def train_model(
    model,
    config,
    dataset,
    dataset_dir,
    visualize_freq=10,
    save_model_freq=500,
    tensorboard_log_dir=None,
):
    # Logging setup. NOTE: the training loop below writes to `writer`
    # unconditionally, so tensorboard_log_dir is expected to be set.
    print(f"Data will be saved in {tensorboard_log_dir}")
    save_dir = tensorboard_log_dir
    if tensorboard_log_dir is not None:
        if not os.path.exists(tensorboard_log_dir):
            os.makedirs(tensorboard_log_dir)
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(tensorboard_log_dir)

    # Train only the decoder; the optimizer below receives decoder parameters only
    sigmoid = nn.Sigmoid()
    model.decoder.train()
    model.decoder.to("cuda")

    ################################################################################
    #                                                                              #
    #                     Setup loss, optimizer and scheduler                      #
    #                                                                              #
    ################################################################################
    criterion = nn.BCEWithLogitsLoss()
    criterion_mse = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.decoder.parameters(), lr=config.training["lr0"])
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.training["step_lr_size"],
        gamma=config.training["step_lr_gamma"],
    )
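    # BCEWithLogitsLoss combines a sigmoid with binary cross-entropy, so raw
    # decoder logits are passed to `criterion` directly; the explicit sigmoid
    # above is used only to binarize predictions.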

    ################################################################################
    #                                                                              #
    #                                   Dataset                                    #
    #                                                                              #
    ################################################################################
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=config.training["batch_size"], shuffle=True, num_workers=2
    )

    ################################################################################
    #                                                                              #
    #                                Training loop                                 #
    #                                                                              #
    ################################################################################
    n_iter = 0
    for epoch in range(config.training["nb_epochs"]):
        running_loss = 0.0
        tbar = tqdm(enumerate(trainloader, 0), leave=None)
        for i, data in tbar:
            # Get the inputs: normalized and unnormalized versions of the image
            # and of its masked counterpart
            inputs, masked_inputs, _, input_nonorm, masked_input_nonorm, _, _ = data

            # Move inputs and masked inputs to GPU
            inputs = inputs.to("cuda")
            masked_inputs = masked_inputs.to("cuda")

            # Zero the parameter gradients
            optimizer.zero_grad()

            ################################################################################
            #                                                                              #
            #                            Unsupervised Segmenter                            #
            #                                                                              #
            ################################################################################
            # Get predictions
            preds = model(inputs)
            # Binarization
            preds_mask = (sigmoid(preds.detach()) > 0.5).float()
            # Apply bilateral solver
            preds_mask_bs, _ = batch_apply_bilateral_solver(data, preds_mask.detach())
            # Flatten
            flat_preds = preds.permute(0, 2, 3, 1).reshape(-1, 1)

            #### Compute unsupervised segmenter loss ####
            alpha = 1.5
            preds_bs_loss = alpha * criterion(
                flat_preds, preds_mask_bs.reshape(-1).float()[:, None]
            )
            writer.add_scalar("Loss/L_seg", preds_bs_loss, n_iter)
            loss = preds_bs_loss
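            # The bilateral-solver-refined binary mask acts as a pseudo-label
            # for the raw logits, i.e. the segmenter is trained by self-supervision.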

            ################################################################################
            #                                                                              #
            #                        Masked Feature Predictor (MFP)                        #
            #                                                                              #
            ################################################################################
            # Get predictions
            preds_mfp = model(masked_inputs)
            # Binarization
            preds_mask_mfp = (sigmoid(preds_mfp.detach()) > 0.5).float()
            # Apply bilateral solver
            preds_mask_mfp_bs, _ = batch_apply_bilateral_solver(
                data, preds_mask_mfp.detach()
            )
            # Flatten
            flat_preds_mfp = preds_mfp.permute(0, 2, 3, 1).reshape(-1, 1)

            #### Compute masked feature predictor loss ####
            beta = 1.0
            preds_bs_cb_loss = beta * criterion(
                flat_preds_mfp, preds_mask_mfp_bs.reshape(-1).float()[:, None]
            )
            writer.add_scalar("Loss/L_mfp", preds_bs_cb_loss, n_iter)
            loss += preds_bs_cb_loss
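            # Same self-training objective as above, but computed on the masked
            # image, so the decoder learns to predict the mask from partial evidence.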

            ################################################################################
            #                                                                              #
            #                       Predictor Consistency Loss (PCL)                       #
            #                                                                              #
            ################################################################################
            gamma = 1.0
            task_sim_loss = gamma * criterion_mse(
                preds_mask_bs.reshape(-1).float()[:, None],
                preds_mask_mfp_bs.reshape(-1).float()[:, None],
            )
            writer.add_scalar("Loss/L_pcl", task_sim_loss, n_iter)
            loss += task_sim_loss
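            # PCL is an MSE between the refined masks of the image and of its
            # masked version, encouraging both branches to agree on one mask.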

            #### Compute loss between soft masks and their binarized versions ####
            self_loss = criterion(flat_preds, preds_mask.reshape(-1).float()[:, None])
            self_loss = self_loss * 4.0
            loss += self_loss
            writer.add_scalar("Loss/L_regularization", self_loss, n_iter)
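            # Total objective, with the weights hardcoded above:
            #   loss = 1.5 * L_seg + 1.0 * L_mfp + 1.0 * L_pcl + 4.0 * L_regularization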

            ################################################################################
            #                                                                              #
            #                      Update weights and scheduler step                       #
            #                                                                              #
            ################################################################################
            loss.backward()
            optimizer.step()
            writer.add_scalar("Loss/total_loss", loss, n_iter)
            writer.add_scalar("params/lr", optimizer.param_groups[0]["lr"], n_iter)
            scheduler.step()

            ################################################################################
            #                                                                              #
            #                     Visualize predictions and show stats                     #
            #                                                                              #
            ################################################################################
            # Visualize predictions in tensorboard
            if n_iter % visualize_freq == 0:
                # Images and predictions
                grid = torchvision.utils.make_grid(input_nonorm[:5])
                writer.add_image("training/images", grid, n_iter)
                p_grid = torchvision.utils.make_grid(preds_mask[:5])
                writer.add_image("training/preds", p_grid, n_iter)
                # Masked images and predictions
                m_grid = torchvision.utils.make_grid(masked_input_nonorm[:5])
                writer.add_image("training/masked_images", m_grid, n_iter)
                mp_grid = torchvision.utils.make_grid(preds_mask_mfp[:5])
                writer.add_image("training/masked_preds", mp_grid, n_iter)

            # Statistics
            running_loss += loss.item()
            tbar.set_description(
                f"{dataset.name} | train | iter {n_iter} | loss: {running_loss / (i + 1):.3f}"
            )

            ################################################################################
            #                                                                              #
            #                           Save model and evaluate                            #
            #                                                                              #
            ################################################################################
            # Save model
            if n_iter % save_model_freq == 0 and n_iter > 0:
                model.decoder_save_weights(save_dir, n_iter)

            # Evaluation
            if n_iter % config.evaluation["freq"] == 0 and n_iter > 0:
                for dataset_eval_name in config.evaluation["datasets"]:
                    val_dataset = build_dataset(
                        root_dir=dataset_dir,
                        dataset_name=dataset_eval_name,
                        for_eval=True,
                        dataset_set=None,
                    )
                    evaluate_saliency(
                        val_dataset, model=model, n_iter=n_iter, writer=writer
                    )

            # Stop once the maximum number of iterations is reached
            if n_iter == config.training["max_iter"]:
                model.decoder_save_weights(save_dir, n_iter)
                print("\n----\nTraining done.")
                writer.close()
                return model
            n_iter += 1

        print(f"##### Epoch {epoch} done, n_iter is {n_iter} #####")

    # Save the final model once all epochs are done
    model.decoder_save_weights(save_dir, n_iter)
    print("\n----\nTraining done.")
    writer.close()
    return model


def main():
    ########## Get arguments ##########
    args = get_argparser()

    ########## Setup ##########
    # Load config yaml file
    config, config_ = load_config(args.config)

    # Experiment name
    exp_name = "{}-{}{}".format(
        config.training["dataset"], config.model["arch"], config.model["patch_size"]
    )
    if args.exp_name is not None:
        exp_name = f"{args.exp_name}-{exp_name}"

    # Logging and output directory
    output_dir = os.path.join(args.log_dir, exp_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Save config
    with open(f"{output_dir}/config.json", "w") as f:
        json.dump(args.__dict__, f)
    print(f"Config saved in {output_dir}/config.json.")

    # Save output of terminal in log file
    sys.stdout = Logger(os.path.join(output_dir, "log_train.txt"))
    print("=========================\nConfigs:\n=========================")
    for argument in str(args).split(", "):
        print(argument)
    print(
        "Hyperparameters from config file: "
        + ", ".join(f"{k}={v}" for k, v in config_.items())
    )
    print("=========================")

    ########## Reproducibility ##########
    set_seed(config.training["seed"])

    ########## Build training set ##########
    dataset = build_dataset(
        root_dir=args.dataset_dir,
        dataset_name=config.training["dataset"],
        dataset_set=config.training["dataset_set"],
        config=config,
        for_eval=False,
    )
    dataset_set = config.training["dataset_set"]
    str_set = dataset_set if dataset_set is not None else ""
    print(f"\nBuilding dataset {dataset.name}{str_set} with {len(dataset)} samples")

    ########## Define Peekaboo ##########
    model = PeekabooModel(
        vit_model=config.model["pre_training"],
        vit_arch=config.model["arch"],
        vit_patch_size=config.model["patch_size"],
        enc_type_feats=config.peekaboo["feats"],
    )

    ########## Training and evaluation ##########
    print(f"\nStarted training on {dataset.name} [tensorboard dir: {output_dir}]")
    model = train_model(
        model=model,
        config=config,
        dataset=dataset,
        dataset_dir=args.dataset_dir,
        tensorboard_log_dir=output_dir,
        visualize_freq=args.visualization_freq,
        save_model_freq=args.save_model_freq,
    )
    print(f"\nTraining done, Peekaboo model saved in {output_dir}.")


if __name__ == "__main__":
    main()