# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from collections import defaultdict
from typing import Optional

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.utils.visualizer import Visualizer


class Evaluator:
    """
    A class defining the CoTracker evaluator.
    """

    def __init__(self, exp_dir) -> None:
        # Visualization
        self.exp_dir = exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
        self.visualize_dir = os.path.join(exp_dir, "visualisations")

    def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
        """Compute per-sequence metrics for `sample` and accumulate a running
        mean over all sequences seen so far under the "avg" key of `metrics`."""
        if isinstance(pred_trajectory, tuple):
            pred_trajectory, pred_visibility = pred_trajectory
        else:
            pred_visibility = None

        if "tapvid" in dataset_name:
            B, T, N, D = sample.trajectory.shape
            traj = sample.trajectory.clone()
            thr = 0.9

            if pred_visibility is None:
                logging.warning("visibility is NONE")
                pred_visibility = torch.zeros_like(sample.visibility)

            # Binarize predicted visibility scores with a fixed threshold.
            if not pred_visibility.dtype == torch.bool:
                pred_visibility = pred_visibility > thr

            query_points = sample.query_points.clone().cpu().numpy()

            pred_visibility = pred_visibility[:, :, :N]
            pred_trajectory = pred_trajectory[:, :, :N]

            # TAP-Vid metrics expect (B, N, T, ...) layout and occlusion flags
            # rather than visibility flags.
            gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
            gt_occluded = (
                torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
            )
            pred_occluded = (
                torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
            )
            pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()

            out_metrics = compute_tapvid_metrics(
                query_points,
                gt_occluded,
                gt_tracks,
                pred_occluded,
                pred_tracks,
                query_mode="strided" if "strided" in dataset_name else "first",
            )

            metrics[sample.seq_name[0]] = out_metrics
            for metric_name in out_metrics.keys():
                if "avg" not in metrics:
                    metrics["avg"] = {}
                metrics["avg"][metric_name] = np.mean(
                    [v[metric_name] for k, v in metrics.items() if k != "avg"]
                )

            logging.info(f"Metrics: {out_metrics}")
            logging.info(f"avg: {metrics['avg']}")
            print("metrics", out_metrics)
            print("avg", metrics["avg"])
        elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
            B, T, N = sample.visibility.shape
            H, W = sample.video.shape[-2:]
            device = sample.video.device

            out_metrics = {}

            d_vis_sum = d_occ_sum = d_sum_all = 0.0
            thrs = [1, 2, 4, 8, 16]
            # Rescale pixel distances to a common 256x256 reference resolution
            # so that accuracy thresholds are comparable across videos.
            sx_ = (W - 1) / 255.0
            sy_ = (H - 1) / 255.0
            sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
            sc_pt = torch.from_numpy(sc_py).float().to(device)

            # Only evaluate frames after each point first becomes visible.
            __, first_visible_inds = torch.max(sample.visibility, dim=1)
            frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
            start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))

            for thr in thrs:
                # Fraction of points within `thr` pixels of the ground truth,
                # split by whether the point is visible or occluded.
                d_ = (
                    torch.norm(
                        pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
                        dim=-1,
                    )
                    < thr
                ).float()  # B,T,N
                d_occ = (
                    reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
                    * 100.0
                )
                d_occ_sum += d_occ
                out_metrics[f"accuracy_occ_{thr}"] = d_occ

                d_vis = (
                    reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
                )
                d_vis_sum += d_vis
                out_metrics[f"accuracy_vis_{thr}"] = d_vis

                d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
                d_sum_all += d_all
                out_metrics[f"accuracy_{thr}"] = d_all

            d_occ_avg = d_occ_sum / len(thrs)
            d_vis_avg = d_vis_sum / len(thrs)
            d_all_avg = d_sum_all / len(thrs)

            # Survival: fraction of the sequence a track "survives" before
            # first drifting more than `sur_thr` pixels while visible.
            sur_thr = 50
            dists = torch.norm(
                pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
                dim=-1,
            )  # B,S,N
            dist_ok = 1 - (dists > sur_thr).float() * sample.visibility  # B,S,N
            survival = torch.cumprod(dist_ok, dim=1)  # B,S,N
            out_metrics["survival"] = torch.mean(survival).item() * 100.0

            out_metrics["accuracy_occ"] = d_occ_avg
            out_metrics["accuracy_vis"] = d_vis_avg
            out_metrics["accuracy"] = d_all_avg

            metrics[sample.seq_name[0]] = out_metrics
            for metric_name in out_metrics.keys():
                if "avg" not in metrics:
                    metrics["avg"] = {}
                metrics["avg"][metric_name] = float(
                    np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
                )

            logging.info(f"Metrics: {out_metrics}")
            logging.info(f"avg: {metrics['avg']}")
            print("metrics", out_metrics)
            print("avg", metrics["avg"])
    @torch.no_grad()
    def evaluate_sequence(
        self,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        dataset_name: str,
        train_mode=False,
        visualize_every: int = 1,
        writer: Optional[SummaryWriter] = None,
        step: Optional[int] = 0,
    ):
        """Run `model` over every sequence in `test_dataloader`, visualize the
        predicted tracks, and return the accumulated per-sequence metrics."""
        metrics = {}

        vis = Visualizer(
            save_dir=self.exp_dir,
            fps=7,
        )

        for ind, sample in enumerate(tqdm(test_dataloader)):
            if isinstance(sample, tuple):
                sample, gotit = sample
                if not all(gotit):
                    print("batch is None")
                    continue
            if torch.cuda.is_available():
                dataclass_to_cuda_(sample)
                device = torch.device("cuda")
            else:
                device = torch.device("cpu")

            if (
                not train_mode
                and hasattr(model, "sequence_len")
                and (sample.visibility[:, : model.sequence_len].sum() == 0)
            ):
                print(f"skipping batch {ind}")
                continue

            if "tapvid" in dataset_name:
                # TAP-Vid queries come as (t, y, x); the model expects (t, x, y).
                queries = sample.query_points.clone().float()

                queries = torch.stack(
                    [
                        queries[:, :, 0],
                        queries[:, :, 2],
                        queries[:, :, 1],
                    ],
                    dim=2,
                ).to(device)
            else:
                # Query every ground-truth point from the first frame (t = 0).
                queries = torch.cat(
                    [
                        torch.zeros_like(sample.trajectory[:, 0, :, :1]),
                        sample.trajectory[:, 0],
                    ],
                    dim=2,
                ).to(device)

            pred_tracks = model(sample.video, queries)
            if "strided" in dataset_name:
                # Strided queries can start mid-sequence: also track the
                # time-reversed video and fill the frames before each query
                # point with the backward predictions.
                inv_video = sample.video.flip(1).clone()
                inv_queries = queries.clone()
                inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

                pred_trj, pred_vsb = pred_tracks
                inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)

                inv_pred_trj = inv_pred_trj.flip(1)
                inv_pred_vsb = inv_pred_vsb.flip(1)

                mask = pred_trj == 0

                pred_trj[mask] = inv_pred_trj[mask]
                pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]

                pred_tracks = pred_trj, pred_vsb

            if dataset_name == "badja" or dataset_name == "fastcapture":
                seq_name = sample.seq_name[0]
            else:
                seq_name = str(ind)
            if ind % visualize_every == 0:
                vis.visualize(
                    sample.video,
                    pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
                    filename=dataset_name + "_" + seq_name,
                    writer=writer,
                    step=step,
                )

            self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
        return metrics
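
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (an assumption, not part of the original module):
# one way the evaluator above might be driven. `build_model` and
# `build_dataloader` are illustrative placeholders, not real CoTracker APIs;
# a real entry point would construct the model and dataloader with the
# project's own helpers.
#
#     model = build_model()  # placeholder: maps (video, queries) to
#                            # (trajectories, visibilities)
#     test_dataloader = build_dataloader("tapvid_davis_first")  # placeholder
#     evaluator = Evaluator(exp_dir="./eval_outputs")
#     metrics = evaluator.evaluate_sequence(
#         model,
#         test_dataloader,
#         dataset_name="tapvid_davis_first",
#         visualize_every=5,
#     )
#     print(metrics["avg"])  # running averages over all evaluated sequences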