from typing import List

import torch
from torch import Tensor
from torchmetrics import Metric

from .utils import *


# Motion prediction / in-betweening metric.
class PredMetrics(Metric):
    """Accumulate displacement errors between generated and reference joints.

    For every sequence in a batch a frame window is selected according to
    ``task`` and two per-sequence scores are added to the running states:

    * ``ADE`` - mean over the window of the per-frame L2 distance, taken over
      all flattened joint coordinates of a frame.
    * ``FDE`` - that same distance at the last frame of the window.

    ``compute`` divides the accumulated sums by the number of sequences seen.

    NOTE(review): ``APD`` is registered and reported, but nothing in this
    class ever accumulates into it, so it only reflects its initial state.
    """

    def __init__(self,
                 cfg,
                 njoints: int = 22,
                 jointstype: str = "mmm",
                 force_in_meter: bool = True,
                 align_root: bool = True,
                 dist_sync_on_step=True,
                 task: str = "pred",
                 **kwargs):
        """Register the metric states.

        Args:
            cfg: Configuration object. ``update`` reads
                ``cfg.ABLATION.predict_ratio`` (task ``"pred"``) or
                ``cfg.ABLATION.inbetween_ratio`` (task ``"inbetween"``).
            njoints: Accepted but not used by this class.
            jointstype: Stored on the instance; not read by this class.
            force_in_meter: Stored on the instance; not read by this class.
            align_root: Stored on the instance; not read by this class.
            dist_sync_on_step: Forwarded to ``Metric.__init__``.
            task: ``"pred"`` or ``"inbetween"``; selects the frame window that
                ``update`` scores.
            **kwargs: Accepted and ignored.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # The spelling of this runtime string is kept as-is on purpose:
        # code outside this file may depend on the exact value.
        self.name = 'Motion Prdiction'
        self.cfg = cfg
        self.jointstype = jointstype
        self.align_root = align_root
        self.task = task
        self.force_in_meter = force_in_meter

        # Total number of frames / number of sequences seen so far.
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        # Running sums; compute() turns them into per-sequence means.
        self.add_state("APD",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("ADE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("FDE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")

        self.MR_metrics = ["APD", "ADE", "FDE"]

        # All metric
        self.metrics = self.MR_metrics

    def compute(self, sanity_flag=False):
        """Return the per-sequence averages and reset the metric.

        Args:
            sanity_flag: Not read by this metric. It now has a default so that
                a plain ``compute()`` call works as well as the original
                ``compute(sanity_flag)`` form.

        Returns:
            dict with keys ``"APD"``, ``"ADE"`` and ``"FDE"``, each the
            accumulated state divided by the number of sequences seen.
        """
        count_seq = self.count_seq
        mr_metrics = {
            name: getattr(self, name) / count_seq
            for name in self.MR_metrics
        }

        # Reset
        self.reset()

        return mr_metrics

    def update(self, joints_rst: Tensor, joints_ref: Tensor,
               lengths: List[int]):
        """Accumulate ADE and FDE for one batch.

        Args:
            joints_rst: Generated joints, 4-D; indexed here as
                (batch, frame, joint, coordinate).
            joints_ref: Reference joints with the same shape as
                ``joints_rst``.
            lengths: Number of valid frames for each sequence in the batch.

        Raises:
            AssertionError: If the shapes differ or the inputs are not 4-D.
            IndexError: If the selected frame window of a sequence is empty,
                because the last-frame lookup then has nothing to index.
        """
        assert joints_rst.shape == joints_ref.shape, \
            "joints_rst and joints_ref must have the same shape"
        assert joints_rst.dim() == 4, \
            "expected joints of shape (bs, seq, njoints, 3)"
        # (bs, seq, njoint=22, 3)

        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # Merge the joint and coordinate axes: one vector per frame.
        rst = torch.flatten(joints_rst, start_dim=2)
        ref = torch.flatten(joints_ref, start_dim=2)

        for i, length in enumerate(lengths):
            if self.task == "pred":
                # Score from the start of the predicted part up to the last
                # valid frame; frames at index >= length are padding and must
                # not contribute to ADE or be picked as the final frame.
                start = int(length * self.cfg.ABLATION.predict_ratio)
                end = length
            elif self.task == "inbetween":
                # Score the middle part, leaving the same margin on each side.
                margin = int(length * self.cfg.ABLATION.inbetween_ratio)
                start = margin
                end = length - margin
            else:
                print(f"Task {self.task} not implemented.")
                # Best-effort fallback: score every valid frame of this
                # sequence. Subtracting the whole batch here would give the
                # per-sequence scores the wrong shape and corrupt the
                # accumulated states through broadcasting.
                start = 0
                end = length

            diff = rst[i, start:end] - ref[i, start:end]

            # (1, frames): per-frame L2 distance over all joint coordinates.
            dist = torch.linalg.norm(diff, dim=-1)[None]

            ade = dist.mean(dim=1)
            fde = dist[:, -1]
            self.ADE = self.ADE + ade
            self.FDE = self.FDE + fde