# Spaces: Running on T4
import os
from typing import List

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import pairwise_euclidean_distance

from .utils import *
from mGPT.config import instantiate_from_config
class TM2TMetrics(Metric):
    """Text/motion matching, FID and diversity metrics for text-to-motion.

    Motion embeddings (and, for text-conditioned LM runs, text embeddings) are
    produced by frozen pretrained T2M evaluator encoders and cached across
    ``update`` calls. ``compute`` turns the caches into scores and then resets
    the metric state.
    """

    def __init__(self,
                 cfg,
                 dataname='humanml3d',
                 top_k=3,
                 R_size=32,
                 diversity_times=300,
                 dist_sync_on_step=True,
                 **kwargs):
        """
        Args:
            cfg: project config; reads ``TRAIN.STAGE``, ``model.params.task``,
                ``METRIC.TM2T.*`` and ``DATASET.HUMANML3D.UNIT_LEN``.
            dataname: ``"kit"`` selects the KIT evaluator checkpoint; any other
                value selects the ``"t2m"`` checkpoint.
            top_k: largest k for which R-precision is reported.
            R_size: number of text/motion pairs per R-precision group.
            diversity_times: forwarded to ``calculate_diversity_np``; must be
                smaller than the number of cached sequences at compute time.
            dist_sync_on_step: forwarded to ``torchmetrics.Metric``.
            **kwargs: accepted for config compatibility and ignored.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.cfg = cfg
        self.dataname = dataname
        self.name = "matching, fid, and diversity scores"
        self.top_k = top_k
        self.R_size = R_size
        # Text metrics are only computed for LM-stage text-to-motion runs.
        self.text = 'lm' in cfg.TRAIN.STAGE and cfg.model.params.task == 't2m'
        self.diversity_times = diversity_times

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        # Names of every reported metric, in reporting order.
        self.metrics = []

        # Matching scores and R-precision (text-conditioned runs only)
        if self.text:
            self.add_state("Matching_score",
                           default=torch.tensor(0.0),
                           dist_reduce_fx="sum")
            self.add_state("gt_Matching_score",
                           default=torch.tensor(0.0),
                           dist_reduce_fx="sum")
            self.Matching_metrics = ["Matching_score", "gt_Matching_score"]
            for k in range(1, top_k + 1):
                self.add_state(
                    f"R_precision_top_{k}",
                    default=torch.tensor(0.0),
                    dist_reduce_fx="sum",
                )
                self.Matching_metrics.append(f"R_precision_top_{k}")
            for k in range(1, top_k + 1):
                self.add_state(
                    f"gt_R_precision_top_{k}",
                    default=torch.tensor(0.0),
                    dist_reduce_fx="sum",
                )
                self.Matching_metrics.append(f"gt_R_precision_top_{k}")
            self.metrics.extend(self.Matching_metrics)

        # Fid
        self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.metrics.append("FID")

        # Diversity
        self.add_state("Diversity",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.add_state("gt_Diversity",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.metrics.extend(["Diversity", "gt_Diversity"])

        # Cached per-sample embeddings, concatenated in ``compute``
        self.add_state("text_embeddings", default=[], dist_reduce_fx=None)
        self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None)
        self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None)

        # T2M Evaluator
        self._get_t2m_evaluator(cfg)

    def _get_t2m_evaluator(self, cfg):
        """Build the T2M text/movement/motion encoders, load weights, freeze them."""
        # init modules from config
        self.t2m_textencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_motionencoder)

        # Load pretrained weights: KIT has its own checkpoint, every other
        # dataset name uses the "t2m" one.
        dataname = "kit" if self.dataname == "kit" else "t2m"
        # NOTE: torch.load unpickles the file; t2m_path must point at trusted
        # checkpoints only.
        t2m_checkpoint = torch.load(os.path.join(
            cfg.METRIC.TM2T.t2m_path, dataname,
            "text_mot_match/model/finest.tar"),
                                    map_location="cpu")
        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # Freeze: eval mode and no gradients for every evaluator encoder.
        for module in (self.t2m_textencoder, self.t2m_moveencoder,
                       self.t2m_motionencoder):
            module.eval()
            for p in module.parameters():
                p.requires_grad = False

    def compute(self, sanity_flag):
        """Compute all metrics from the cached embeddings, then reset state.

        Args:
            sanity_flag: when truthy (sanity-check stage), return the current
                metric state tensors and skip all computation.

        Returns:
            dict mapping each name in ``self.metrics`` to its value.

        Raises:
            AssertionError: if too few sequences were cached for R-precision
                (text runs) or for the diversity computation.
        """
        count_seq = self.count_seq.item()

        # Init metrics dict with the raw state values
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # Jump in sanity check stage
        if sanity_flag:
            return metrics

        # Concatenate the cached batches and shuffle them with ONE permutation
        # so that row i of every cache still refers to the same sample.
        shuffle_idx = torch.randperm(count_seq)
        all_genmotions = torch.cat(self.recmotion_embeddings,
                                   dim=0).cpu()[shuffle_idx, :]
        all_gtmotions = torch.cat(self.gtmotion_embeddings,
                                  dim=0).cpu()[shuffle_idx, :]

        # Compute text related metrics
        if self.text:
            all_texts = torch.cat(self.text_embeddings,
                                  dim=0).cpu()[shuffle_idx, :]
            # At least one complete group of R_size pairs is required.
            assert count_seq >= self.R_size, (
                f"R-precision needs at least R_size={self.R_size} "
                f"sequences, got {count_seq}")
            # Only complete groups are scored; leftover sequences are skipped.
            R_count = count_seq // self.R_size * self.R_size

            # Generated motions vs. texts
            matching_sum, top_k_mat = self._calculate_r_precision(
                all_texts, all_genmotions, count_seq)
            self.Matching_score += matching_sum
            metrics["Matching_score"] = self.Matching_score / R_count
            for k in range(self.top_k):
                metrics[f"R_precision_top_{k + 1}"] = top_k_mat[k] / R_count

            # Ground-truth motions vs. texts
            gt_matching_sum, gt_top_k_mat = self._calculate_r_precision(
                all_texts, all_gtmotions, count_seq)
            self.gt_Matching_score += gt_matching_sum
            metrics["gt_Matching_score"] = self.gt_Matching_score / R_count
            for k in range(self.top_k):
                metrics[f"gt_R_precision_top_{k + 1}"] = (gt_top_k_mat[k] /
                                                          R_count)

        # tensor -> numpy for FID and diversity
        all_genmotions = all_genmotions.numpy()
        all_gtmotions = all_gtmotions.numpy()

        # Compute fid between generated and ground-truth embedding statistics
        mu, cov = calculate_activation_statistics_np(all_genmotions)
        gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions)
        metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov)

        # Compute diversity
        assert count_seq > self.diversity_times, (
            f"diversity needs more than diversity_times="
            f"{self.diversity_times} sequences, got {count_seq}")
        metrics["Diversity"] = calculate_diversity_np(all_genmotions,
                                                      self.diversity_times)
        metrics["gt_Diversity"] = calculate_diversity_np(
            all_gtmotions, self.diversity_times)

        # Clear cached embeddings and accumulated state for the next round.
        self.reset()
        return metrics

    def _calculate_r_precision(self, all_texts: Tensor, all_motions: Tensor,
                               count_seq: int):
        """Score texts against motions in consecutive groups of ``R_size``.

        Returns:
            ``(matching_sum, top_k_mat)``: the summed diagonal (matched-pair)
            distances over all complete groups, and a ``[top_k]`` tensor of
            top-k hit counts accumulated from ``calculate_top_k``.
        """
        group = self.R_size
        matching_sum = torch.tensor(0.0)
        top_k_mat = torch.zeros((self.top_k, ))
        for i in range(count_seq // group):
            start, end = i * group, (i + 1) * group
            # [R_size, emb_dim] each
            group_texts = all_texts[start:end]
            group_motions = all_motions[start:end]
            # [R_size, R_size]; NaN/inf entries are replaced by finite values
            dist_mat = euclidean_distance_matrix(group_texts,
                                                 group_motions).nan_to_num()
            # Diagonal = distance between each text and its own motion
            matching_sum += dist_mat.trace()
            argsmax = torch.argsort(dist_mat, dim=1)
            top_k_mat += calculate_top_k(argsmax,
                                         top_k=self.top_k).sum(axis=0)
        return matching_sum, top_k_mat

    def update(self,
               feats_ref: Tensor,
               feats_rst: Tensor,
               lengths_ref: List[int],
               lengths_rst: List[int],
               word_embs: Tensor = None,
               pos_ohot: Tensor = None,
               text_lengths: Tensor = None):
        """Encode one batch and append its embeddings to the caches.

        Args:
            feats_ref: reference (ground-truth) motion features, batch-first.
            feats_rst: generated motion features, batch-first.
            lengths_ref: valid frame count of each sample in ``feats_ref``.
            lengths_rst: valid frame count of each sample in ``feats_rst``.
            word_embs, pos_ohot, text_lengths: inputs to the T2M text encoder;
                only used (and then required) when ``self.text`` is set.
        """
        self.count += sum(lengths_ref)
        self.count_seq += len(lengths_ref)

        # T2m motion encoder; embeddings come back in the batch's own order
        # so they stay aligned with the text embeddings appended below.
        self.gtmotion_embeddings.extend(
            self._encode_motions_in_batch_order(feats_ref, lengths_ref))
        self.recmotion_embeddings.extend(
            self._encode_motions_in_batch_order(feats_rst, lengths_rst))

        # T2m text encoder
        if self.text:
            # NOTE(review): unlike the motions, text inputs are not re-sorted
            # here; if the text encoder packs sequences with
            # enforce_sorted=True, callers must pass length-sorted batches.
            text_emb = self.t2m_textencoder(word_embs, pos_ohot, text_lengths)
            text_embeddings = torch.flatten(text_emb, start_dim=1).detach()
            self.text_embeddings.append(text_embeddings)

    def _encode_motions_in_batch_order(self, feats: Tensor, lengths):
        """Encode ``feats``; return one ``[1, D]`` embedding per sample, in input order.

        Samples are fed to the encoders sorted by descending length
        (presumably for packed-RNN input — confirm against the motion encoder)
        and the results are scattered back to their original batch positions.
        """
        # .copy(): the reversed argsort view has negative strides, which
        # torch indexing does not accept.
        align_idx = np.argsort(lengths)[::-1].copy()
        sorted_lengths = np.array(lengths)[align_idx]
        embeddings = self.get_motion_embeddings(feats[align_idx],
                                                sorted_lengths)
        in_order = [None] * len(sorted_lengths)
        for sorted_pos, batch_pos in enumerate(align_idx):
            in_order[batch_pos] = embeddings[sorted_pos:sorted_pos + 1]
        return in_order

    def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
        """Encode motion features with the frozen T2M movement/motion encoders.

        Args:
            feats: motion features; the last 4 channels of the final dim are
                dropped before the movement encoder (presumably foot-contact
                flags — TODO confirm).
            lengths: per-sample frame counts, aligned with ``feats`` (a list or
                numpy array); callers in this class pass them length-sorted
                in descending order.

        Returns:
            ``[bs, nlatent*ndim]`` embeddings (flattened from dim 1), detached.
        """
        unit_len = self.cfg.DATASET.HUMANML3D.UNIT_LEN
        # Frame counts -> number of UNIT_LEN-frame units (floor division,
        # applied exactly once).
        m_lens = torch.div(torch.tensor(lengths),
                           unit_len,
                           rounding_mode="floor")
        mov = self.t2m_moveencoder(feats[..., :-4]).detach()
        emb = self.t2m_motionencoder(mov, m_lens)
        # [bs, nlatent*ndim] <= [bs, nlatent, ndim]
        return torch.flatten(emb, start_dim=1).detach()