Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import os | |
import random | |
import torch | |
import signal | |
import socket | |
import sys | |
import json | |
import numpy as np | |
import argparse | |
import logging | |
from pathlib import Path | |
from tqdm import tqdm | |
import torch.optim as optim | |
from torch.utils.data import DataLoader | |
from torch.cuda.amp import GradScaler | |
from torch.utils.tensorboard import SummaryWriter | |
from pytorch_lightning.lite import LightningLite | |
from cotracker.models.evaluation_predictor import EvaluationPredictor | |
from cotracker.models.core.cotracker.cotracker import CoTracker2 | |
from cotracker.utils.visualizer import Visualizer | |
from cotracker.datasets.tap_vid_datasets import TapVidDataset | |
from cotracker.datasets.dr_dataset import DynamicReplicaDataset | |
from cotracker.evaluation.core.evaluator import Evaluator | |
from cotracker.datasets import kubric_movif_dataset | |
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_ | |
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss | |
# Signal handlers used when training on a SLURM cluster.
def sig_handler(signum, frame):
    """Handle SIGUSR1: ask SLURM to requeue the current job, then exit.

    The job id is read from the ``SLURM_JOB_ID`` environment variable. When the
    variable is missing (the script is not running under SLURM) the requeue is
    skipped instead of failing with ``KeyError`` inside the handler.

    Args:
        signum: number of the signal that was received.
        frame: stack frame that was interrupted (unused).

    Raises:
        SystemExit: always, with exit code -1.
    """
    # Local import keeps this block self-contained; the path is taken at most
    # once per process.
    import subprocess

    print("caught signal", signum)
    print(socket.gethostname(), "USR1 signal caught.")
    # do other stuff to cleanup here
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id is None:
        print("SLURM_JOB_ID is not set, cannot requeue job")
    else:
        print("requeuing job " + job_id)
        try:
            # Argument list instead of a shell string: no shell is involved,
            # so the job id is never interpreted as shell syntax.
            subprocess.run(["scontrol", "requeue", job_id], check=False)
        except OSError as err:
            # Best effort, as before: a missing `scontrol` must not stop the
            # process from exiting.
            print("failed to run scontrol:", err)
    sys.exit(-1)
def term_handler(signum, frame):
    """Handle SIGTERM by printing a notice and returning.

    Args:
        signum: number of the signal that was received (unused).
        frame: stack frame that was interrupted (unused).
    """
    notice = "bypassing sigterm"
    # Flush immediately so the notice is not held back in the output buffer.
    print(notice, flush=True)
def fetch_optimizer(args, model):
    """Build the AdamW optimizer and its one-cycle learning-rate schedule.

    Args:
        args: namespace providing ``lr`` (also used as the schedule's peak
            learning rate), ``wdecay`` and ``num_steps``.
        model: module whose ``parameters()`` are optimized.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    adamw = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=1e-8,
    )
    # The schedule is 100 steps longer than the nominal training length.
    schedule_len = args.num_steps + 100
    one_cycle = optim.lr_scheduler.OneCycleLR(
        adamw,
        max_lr=args.lr,
        total_steps=schedule_len,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )
    return adamw, one_cycle
def forward_batch(batch, model, args):
    """Run one training forward pass and compute track and visibility losses.

    A query ``(frame index, x, y)`` is built for every trajectory, the model is
    called with ``is_train=True``, and its training outputs are scored against
    the ground truth cut into overlapping windows.

    Args:
        batch: sample exposing ``video``, ``trajectory``, ``visibility`` and
            ``valid`` tensors. ``video`` is unpacked as (B, T, C, H, W) with
            C == 3 and ``trajectory`` as (B, T, N, D).
        model: callable invoked as
            ``model(video=..., queries=..., iters=..., is_train=True)`` and
            returning ``(predictions, visibility, train_data)``.
        args: namespace providing ``train_iters``, ``sliding_window_len`` and
            ``sequence_len``.

    Returns:
        dict with a "flow" and a "visibility" entry; each holds a "loss" and
        the detached index-0 element of the corresponding model output under
        "predictions".
    """
    video = batch.video
    gt_tracks = batch.trajectory
    gt_vis = batch.visibility
    gt_valid = batch.valid

    B, T, C, H, W = video.shape
    assert C == 3
    B, T, N, D = gt_tracks.shape
    device = video.device

    # Per trajectory, the frame index at which the visibility value is maximal.
    query_frames = torch.max(gt_vis, dim=1).indices

    # We want to make sure that during training the model sees visible points
    # that it does not need to track just yet: they are visible but queried
    # from a later frame. The first N // 4 trajectories are therefore queried
    # from a randomly drawn frame in which they are marked visible.
    num_random = N // 4
    for b in range(B):
        picks = []
        for n in range(N):
            visible_frames = torch.nonzero(gt_vis[b, :, n])
            choice = torch.randint(len(visible_frames), size=(1,))
            picks.append(visible_frames[choice])
        random_frames = torch.cat(picks, dim=1)
        query_frames[b] = torch.cat(
            [random_frames[:, :num_random], query_frames[b : b + 1, num_random:]],
            dim=1,
        )

    # Sanity check: visibility must be 1 at every selected query frame.
    frame_ids = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
    assert torch.allclose(
        gt_vis[frame_ids == query_frames[:, None, :]],
        torch.ones(1, device=device),
    )

    # Read each trajectory's coordinates at its own query frame: gather along
    # the time axis, then keep the diagonal (trajectory n at query frame n).
    gather_index = query_frames[:, :, None, None].repeat(1, 1, N, D)
    gathered = torch.gather(gt_tracks, 1, gather_index)
    query_xy = torch.diagonal(gathered, dim1=1, dim2=2).permute(0, 2, 1)
    queries = torch.cat([query_frames[:, :, None], query_xy[:, :, :2]], dim=2)

    predictions, visibility, train_data = model(
        video=video, queries=queries, iters=args.train_iters, is_train=True
    )
    coord_predictions, vis_predictions, valid_mask = train_data

    # Cut the ground truth into windows of `sliding_window_len` frames whose
    # start advances by half a window.
    window = args.sliding_window_len
    hop = window // 2
    vis_gts, traj_gts, valids_gts = [], [], []
    for start in range(0, args.sequence_len - hop, hop):
        stop = start + window
        vis_gts.append(gt_vis[:, start:stop])
        traj_gts.append(gt_tracks[:, start:stop])
        valids_gts.append(gt_valid[:, start:stop] * valid_mask[:, start:stop])

    seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
    vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)

    return {
        "flow": {
            "predictions": predictions[0].detach(),
            "loss": seq_loss.mean(),
        },
        "visibility": {
            # The visibility loss is weighted by a factor of 10.
            "loss": vis_loss.mean() * 10.0,
            "predictions": visibility[0].detach(),
        },
    }
def run_test_eval(evaluator, model, dataloaders, writer, step):
    """Evaluate the model on every given dataset and log the averaged metrics.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` afterwards.

    Args:
        evaluator: object exposing ``evaluate_sequence(...)``.
        model: wrapped model; the underlying network is read from
            ``model.module.module``.
        dataloaders: iterable of ``(dataset_name, dataloader)`` pairs.
        writer: summary writer exposing ``add_scalars``; it is also forwarded
            to the evaluator.
        step: global step under which the scalars are logged.
    """
    model.eval()
    for ds_name, dataloader in dataloaders:
        # Per-dataset visualization cadence and `grid_size` for the predictor.
        if ds_name == "dynamic_replica":
            visualize_every, grid_size = 8, 0
        elif "tapvid" in ds_name:
            visualize_every, grid_size = 5, 5
        else:
            visualize_every, grid_size = 1, 5

        predictor = EvaluationPredictor(
            model.module.module,
            grid_size=grid_size,
            local_grid_size=0,
            single_point=False,
            n_iters=6,
        )
        if torch.cuda.is_available():
            predictor.model = predictor.model.cuda()

        raw_metrics = evaluator.evaluate_sequence(
            model=predictor,
            test_dataloader=dataloader,
            dataset_name=ds_name,
            train_mode=True,
            writer=writer,
            step=step,
            visualize_every=visualize_every,
        )

        # Flatten the "avg" section into dataset-prefixed scalar names; any
        # other dataset name logs the evaluator output unchanged.
        scalars = raw_metrics
        if ds_name in ("dynamic_replica", "kubric"):
            scalars = {
                f"{ds_name}_avg_{name}": value
                for name, value in raw_metrics["avg"].items()
            }
        elif "tapvid" in ds_name:
            averages = raw_metrics["avg"]
            scalars = {
                f"{ds_name}_avg_OA": averages["occlusion_accuracy"],
                f"{ds_name}_avg_delta": averages["average_pts_within_thresh"],
                f"{ds_name}_avg_Jaccard": averages["average_jaccard"],
            }

        writer.add_scalars(f"Eval_{ds_name}", scalars, step)
class Logger:
    """Accumulate training metrics and write them to a TensorBoard writer.

    Metrics pushed with :meth:`push` are summed per ``"<metric>_<task>"`` key
    and flushed as averages whenever ``total_steps % SUM_FREQ == SUM_FREQ - 1``.
    Averages are taken over the number of values actually accumulated for each
    key, so the first (shorter) window and keys that are not pushed on every
    step are averaged correctly.
    """

    # Number of pushes between two flushes of the running averages.
    SUM_FREQ = 100

    def __init__(self, model, scheduler, log_dir=None):
        """Create the logger and its summary writer.

        Args:
            model: model being trained (stored, not used by the logger itself).
            scheduler: learning-rate scheduler (stored, not used by the logger
                itself).
            log_dir: directory for the summary writer. When ``None`` the
                directory ``<args.ckpt_path>/runs`` is used, where ``args`` is
                the module-level namespace parsed by the script entry point.
        """
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.running_count = {}
        self.log_dir = log_dir
        self.writer = None
        self._ensure_writer()

    def _ensure_writer(self):
        """Return the summary writer, creating it first if there is none."""
        if self.writer is None:
            log_dir = self.log_dir
            if log_dir is None:
                # Backward-compatible default: relies on the global `args`.
                log_dir = os.path.join(args.ckpt_path, "runs")
            self.writer = SummaryWriter(log_dir=log_dir)
        return self.writer

    def _print_training_status(self):
        """Log the current running averages and reset the accumulators."""
        averages = {
            # max(..., 1) guards against a key whose count was already reset.
            k: self.running_loss[k] / max(self.running_count.get(k, 0), 1)
            for k in self.running_loss
        }
        metrics_data = [averages[k] for k in sorted(averages)]
        training_str = "[{:6d}] ".format(self.total_steps + 1)
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")

        writer = self._ensure_writer()
        for k in self.running_loss:
            writer.add_scalar(k, averages[k], self.total_steps)
            self.running_loss[k] = 0.0
            self.running_count[k] = 0

    def push(self, metrics, task):
        """Accumulate one set of metrics for ``task`` and flush when due.

        Args:
            metrics: mapping from metric name to a numeric value.
            task: suffix appended to every metric name to form its key.
        """
        self.total_steps += 1

        for key in metrics:
            task_key = str(key) + "_" + task
            self.running_loss[task_key] = self.running_loss.get(task_key, 0.0) + metrics[key]
            self.running_count[task_key] = self.running_count.get(task_key, 0) + 1

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}
            self.running_count = {}

    def write_dict(self, results):
        """Write every ``key: value`` pair of ``results`` at the current step."""
        writer = self._ensure_writer()
        for key in results:
            writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        """Close the underlying summary writer."""
        self.writer.close()
class Lite(LightningLite):
    """LightningLite runner that trains the tracker and evaluates it periodically."""

    def run(self, args):
        """Train the model configured by ``args``.

        Steps, in order: seed all RNGs, build the evaluation loaders (rank 0
        only), build the model and the training loader, resume from the newest
        checkpoint found in ``args.ckpt_path`` (or from ``args.restore_ckpt``),
        then iterate over the training loader until ``total_steps`` exceeds
        ``args.num_steps``. Checkpoints, evaluation and TensorBoard logging
        happen on rank 0 only.

        Args:
            args: namespace produced by the argument parser of this script.

        Raises:
            ValueError: if ``args.model_name`` is not ``"cotracker"``.
        """

        def seed_everything(seed: int):
            """Seed Python, NumPy and torch RNGs and make cuDNN deterministic."""
            random.seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        seed_everything(0)

        def seed_worker(worker_id):
            """Derive the NumPy / Python seeds of a dataloader worker from torch's seed."""
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(0)

        # Evaluation objects exist on rank 0 only; every later use is guarded
        # by the same rank check.
        if self.global_rank == 0:
            eval_dataloaders = []
            if "dynamic_replica" in args.eval_datasets:
                eval_dataset = DynamicReplicaDataset(
                    sample_len=60, only_first_n_samples=1, rgbd_input=False
                )
                eval_dataloader_dr = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))

            if "tapvid_davis_first" in args.eval_datasets:
                data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
                eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
                eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))

            evaluator = Evaluator(args.ckpt_path)

            visualizer = Visualizer(
                save_dir=args.ckpt_path,
                pad_value=80,
                fps=1,
                show_first_frame=0,
                tracks_leave_trace=0,
            )

        if args.model_name == "cotracker":
            model = CoTracker2(
                stride=args.model_stride,
                window_len=args.sliding_window_len,
                add_space_attn=not args.remove_space_attn,
                num_virtual_tracks=args.num_virtual_tracks,
                model_resolution=args.crop_size,
            )
        else:
            raise ValueError(f"Model {args.model_name} doesn't exist")

        # Record the full run configuration next to the checkpoints.
        with open(args.ckpt_path + "/meta.json", "w") as file:
            json.dump(vars(args), file, sort_keys=True, indent=4)

        model.cuda()

        train_dataset = kubric_movif_dataset.KubricMovifDataset(
            data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
            crop_size=args.crop_size,
            seq_len=args.sequence_len,
            traj_per_sample=args.traj_per_sample,
            sample_vis_1st_frame=args.sample_vis_1st_frame,
            use_augs=not args.dont_use_augs,
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            worker_init_fn=seed_worker,
            generator=g,
            pin_memory=True,
            collate_fn=collate_fn_train,
            drop_last=True,
        )

        train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
        print("LEN TRAIN LOADER", len(train_loader))
        optimizer, scheduler = fetch_optimizer(args, model)

        total_steps = 0
        if self.global_rank == 0:
            logger = Logger(model, scheduler)

        # Intermediate checkpoints already present in the checkpoint directory.
        # os.listdir returns bare entry names, so each one has to be joined
        # with the directory before it can be tested with isdir.
        folder_ckpts = [
            f
            for f in os.listdir(args.ckpt_path)
            if not os.path.isdir(os.path.join(args.ckpt_path, f))
            and f.endswith(".pth")
            and "final" not in f
        ]
        if len(folder_ckpts) > 0:
            # Resume: the lexicographically last name is taken as the newest
            # checkpoint (the step count in the name is zero-padded).
            ckpt_path = sorted(folder_ckpts)[-1]
            ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
            logging.info(f"Loading checkpoint {ckpt_path}")
            if "model" in ckpt:
                model.load_state_dict(ckpt["model"])
            else:
                model.load_state_dict(ckpt)
            if "optimizer" in ckpt:
                logging.info("Load optimizer")
                optimizer.load_state_dict(ckpt["optimizer"])
            if "scheduler" in ckpt:
                logging.info("Load scheduler")
                scheduler.load_state_dict(ckpt["scheduler"])
            if "total_steps" in ckpt:
                total_steps = ckpt["total_steps"]
                logging.info(f"Load total_steps {total_steps}")

        elif args.restore_ckpt is not None:
            # Fresh run initialized from explicitly given weights.
            assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(".pt")
            logging.info("Loading checkpoint...")

            state_dict = self.load(args.restore_ckpt)
            if "model" in state_dict:
                state_dict = state_dict["model"]

            # Drop the "module." wrapper prefixes when the first key has one.
            if list(state_dict.keys())[0].startswith("module."):
                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict, strict=True)

            logging.info(f"Done loading checkpoint")

        model, optimizer = self.setup(model, optimizer, move_to_device=False)
        # model.cuda()
        model.train()

        save_freq = args.save_freq
        scaler = GradScaler(enabled=args.mixed_precision)

        should_keep_training = True
        epoch = -1

        while should_keep_training:
            epoch += 1
            for i_batch, batch in enumerate(tqdm(train_loader)):
                batch, gotit = batch
                # Skip the whole batch when any of its samples failed to load.
                if not all(gotit):
                    print("batch is None")
                    continue
                dataclass_to_cuda_(batch)

                optimizer.zero_grad()

                assert model.training

                output = forward_batch(batch, model, args)

                # Total loss is the sum of every per-task loss.
                loss = 0
                for v in output.values():
                    if "loss" in v:
                        loss += v["loss"]

                if self.global_rank == 0:
                    for k, v in output.items():
                        if "loss" in v:
                            logger.writer.add_scalar(
                                f"live_{k}_loss", v["loss"].item(), total_steps
                            )
                        if "metrics" in v:
                            logger.push(v["metrics"], k)

                    # Visualize ground-truth and predicted tracks every
                    # `save_freq` steps.
                    if total_steps % save_freq == save_freq - 1:
                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=batch.trajectory.clone(),
                            filename="train_gt_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )
                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=output["flow"]["predictions"][None],
                            filename="train_pred_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                    if len(output) > 1:
                        logger.writer.add_scalar(f"live_total_loss", loss.item(), total_steps)
                    logger.writer.add_scalar(
                        f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
                    )
                self.barrier()

                self.backward(scaler.scale(loss))

                # Unscale before clipping so the threshold applies to the true
                # gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)

                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                total_steps += 1

                if self.global_rank == 0:
                    # End of the epoch, or right after the very first step when
                    # validation at start was requested.
                    if (i_batch >= len(train_loader) - 1) or (
                        total_steps == 1 and args.validate_at_start
                    ):
                        if (epoch + 1) % args.save_every_n_epoch == 0:
                            # Zero-pad the step count to six digits.
                            ckpt_iter = str(total_steps).zfill(6)
                            save_path = Path(
                                f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
                            )

                            save_dict = {
                                "model": model.module.module.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "scheduler": scheduler.state_dict(),
                                "total_steps": total_steps,
                            }

                            logging.info(f"Saving file {save_path}")
                            self.save(save_dict, save_path)

                        if (epoch + 1) % args.evaluate_every_n_epoch == 0 or (
                            args.validate_at_start and epoch == 0
                        ):
                            run_test_eval(
                                evaluator,
                                model,
                                eval_dataloaders,
                                logger.writer,
                                total_steps,
                            )
                            # run_test_eval leaves the model in eval mode.
                            model.train()
                            torch.cuda.empty_cache()

                self.barrier()
                if total_steps > args.num_steps:
                    should_keep_training = False
                    break

        if self.global_rank == 0:
            print("FINISHED TRAINING")

            PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
            torch.save(model.module.module.state_dict(), PATH)
            run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
            logger.close()
def build_arg_parser():
    """Build the command-line parser of the training script.

    ``--ckpt_path`` and ``--dataset_root`` are required: both are always
    dereferenced during training, so omitting them is reported by argparse
    instead of failing later with a ``TypeError``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="cotracker", help="model name")
    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
    parser.add_argument("--ckpt_path", required=True, help="path to save checkpoints")
    parser.add_argument(
        "--batch_size", type=int, default=4, help="batch size used during training."
    )
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")
    parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
    parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
    parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
    parser.add_argument(
        "--num_steps", type=int, default=200000, help="length of training schedule."
    )
    parser.add_argument(
        "--evaluate_every_n_epoch",
        type=int,
        default=1,
        help="evaluate during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--save_every_n_epoch",
        type=int,
        default=1,
        help="save checkpoints during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--validate_at_start",
        action="store_true",
        help="whether to run evaluation before training starts",
    )
    parser.add_argument(
        "--save_freq",
        type=int,
        default=100,
        help="frequency of trajectory visualization during training",
    )
    parser.add_argument(
        "--traj_per_sample",
        type=int,
        default=768,
        help="the number of trajectories to sample for training",
    )
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help="path to all the datasets (train and eval)",
    )
    parser.add_argument(
        "--train_iters",
        type=int,
        default=4,
        help="number of model iterations in each training forward pass.",
    )
    parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        default=["tapvid_davis_first"],
        help="what datasets to use for evaluation",
    )
    parser.add_argument(
        "--remove_space_attn",
        action="store_true",
        help="remove space attention from CoTracker",
    )
    parser.add_argument(
        "--num_virtual_tracks",
        type=int,
        default=None,
        help="number of virtual tracks passed to the CoTracker model",
    )
    parser.add_argument(
        "--dont_use_augs",
        action="store_true",
        help="don't apply augmentations during training",
    )
    parser.add_argument(
        "--sample_vis_1st_frame",
        action="store_true",
        help="only sample trajectories with points visible on the first frame",
    )
    parser.add_argument(
        "--sliding_window_len",
        type=int,
        default=8,
        help="length of the CoTracker sliding window",
    )
    parser.add_argument(
        "--model_stride",
        type=int,
        default=8,
        help="stride of the CoTracker feature network",
    )
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=[384, 512],
        help="crop videos to this resolution during training",
    )
    parser.add_argument(
        "--eval_max_seq_len",
        type=int,
        default=1000,
        help="maximum length of evaluation videos",
    )
    return parser


if __name__ == "__main__":
    # SIGUSR1 requeues the job, SIGTERM is reported and otherwise ignored.
    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)

    # `args` must stay a module-level name: Logger reads it as a global.
    args = build_arg_parser().parse_args()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
    from pytorch_lightning.strategies import DDPStrategy

    Lite(
        strategy=DDPStrategy(find_unused_parameters=False),
        devices="auto",
        accelerator="gpu",
        precision=32,
        num_nodes=args.num_nodes,
    ).run(args)