import os
from typing import List

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric

from mGPT.config import instantiate_from_config

from .utils import *  # provides calculate_multimodality_np


class MMMetrics(Metric):
    full_state_update = True

    def __init__(self,
                 cfg,
                 dataname='humanml3d',
                 mm_num_times=10,
                 dist_sync_on_step=True,
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "MultiModality scores"
        self.cfg = cfg
        self.dataname = dataname
        self.mm_num_times = mm_num_times

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        self.metrics = ["MultiModality"]
        self.add_state("MultiModality",
                       default=torch.tensor(0.),
                       dist_reduce_fx="sum")

        # cached batches of generated-motion embeddings
        self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx=None)

        # T2M evaluator
        self._get_t2m_evaluator(cfg)

    def _get_t2m_evaluator(self, cfg):
        """Load the T2M text encoder and motion encoders used for evaluation."""
        # init modules
        self.t2m_textencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_motionencoder)

        # load pretrained weights
        dataname = "kit" if self.dataname == "kit" else "t2m"
        t2m_checkpoint = torch.load(os.path.join(
            cfg.METRIC.TM2T.t2m_path, dataname,
            "text_mot_match/model/finest.tar"),
                                    map_location="cpu")
        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # freeze params
        self.t2m_textencoder.eval()
        self.t2m_moveencoder.eval()
        self.t2m_motionencoder.eval()
        for p in self.t2m_textencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_moveencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_motionencoder.parameters():
            p.requires_grad = False

    def compute(self, sanity_flag):
        count = self.count.item()
        count_seq = self.count_seq.item()

        # init metrics
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # skip the actual computation during the sanity-check stage
        if sanity_flag:
            return metrics

        # concatenate all cached embeddings
        all_mm_motions = torch.cat(self.mm_motion_embeddings,
                                   dim=0).cpu().numpy()
        metrics['MultiModality'] = calculate_multimodality_np(
            all_mm_motions, self.mm_num_times)

        # reset
        self.reset()
        return {**metrics}

    def update(
        self,
        feats_rst: Tensor,
        lengths_rst: List[int],
    ):
        self.count += sum(lengths_rst)
        self.count_seq += len(lengths_rst)

        # sort by length (descending) for the motion encoder, then restore
        # the original batch order below
        align_idx = np.argsort(lengths_rst)[::-1].copy()
        feats_rst = feats_rst[align_idx]
        lengths_rst = np.array(lengths_rst)[align_idx]

        recmotion_embeddings = self.get_motion_embeddings(
            feats_rst, lengths_rst)
        cache = [0] * len(lengths_rst)
        for i in range(len(lengths_rst)):
            cache[align_idx[i]] = recmotion_embeddings[i:i + 1]

        # store all mm motion embeddings: [1, batch_size, embedding_dim]
        mm_motion_embeddings = torch.cat(cache, dim=0).unsqueeze(0)
        self.mm_motion_embeddings.append(mm_motion_embeddings)

    def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
        m_lens = torch.tensor(lengths)
        m_lens = torch.div(m_lens,
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")
        # drop the last 4 (foot-contact) dims before the movement encoder
        mov = self.t2m_moveencoder(feats[..., :-4]).detach()
        emb = self.t2m_motionencoder(mov, m_lens)

        # [bs, nlatent*ndim] <= [bs, nlatent, ndim]
        return torch.flatten(emb, start_dim=1).detach()
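

# ---------------------------------------------------------------------------
# For reference, a minimal sketch of what `calculate_multimodality_np`
# (imported from .utils) is commonly implemented as in the Guo et al.
# text-to-motion evaluation suite: sample `mm_num_times` index pairs across
# the replications generated for each text and average their Euclidean
# distances. This is an illustrative assumption, not the authoritative
# implementation, which lives in .utils.
def _multimodality_np_sketch(activations: np.ndarray,
                             mm_num_times: int) -> float:
    # activations: [num_texts, num_replications, embedding_dim]; sampling
    # without replacement requires num_replications >= mm_num_times
    assert activations.ndim == 3
    assert activations.shape[1] >= mm_num_times
    num_reps = activations.shape[1]
    first = np.random.choice(num_reps, mm_num_times, replace=False)
    second = np.random.choice(num_reps, mm_num_times, replace=False)
    # per-text, per-pair L2 distances, averaged into one scalar score
    dist = np.linalg.norm(activations[:, first] - activations[:, second],
                          axis=2)
    return float(dist.mean())


# Usage sketch (hypothetical driver; `cfg` and `generate_motions` stand in
# for the real mGPT config and sampling code):
#
#   mm_metrics = MMMetrics(cfg, mm_num_times=10)
#   for texts in mm_eval_texts:
#       # one update per batch of motions generated for the same texts
#       feats_rst, lengths_rst = generate_motions(texts)
#       mm_metrics.update(feats_rst, lengths_rst)
#   results = mm_metrics.compute(sanity_flag=False)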