bill-jiang's picture
Init
4409449
raw
history blame
No virus
10.5 kB
from typing import List
import os
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import pairwise_euclidean_distance
from .utils import *
from mGPT.config import instantiate_from_config
class TM2TMetrics(Metric):
def __init__(self,
cfg,
dataname='humanml3d',
top_k=3,
R_size=32,
diversity_times=300,
dist_sync_on_step=True,
**kwargs):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.cfg = cfg
self.dataname = dataname
self.name = "matching, fid, and diversity scores"
self.top_k = top_k
self.R_size = R_size
self.text = 'lm' in cfg.TRAIN.STAGE and cfg.model.params.task == 't2m'
self.diversity_times = diversity_times
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("count_seq",
default=torch.tensor(0),
dist_reduce_fx="sum")
self.metrics = []
# Matching scores
if self.text:
self.add_state("Matching_score",
default=torch.tensor(0.0),
dist_reduce_fx="sum")
self.add_state("gt_Matching_score",
default=torch.tensor(0.0),
dist_reduce_fx="sum")
self.Matching_metrics = ["Matching_score", "gt_Matching_score"]
for k in range(1, top_k + 1):
self.add_state(
f"R_precision_top_{str(k)}",
default=torch.tensor(0.0),
dist_reduce_fx="sum",
)
self.Matching_metrics.append(f"R_precision_top_{str(k)}")
for k in range(1, top_k + 1):
self.add_state(
f"gt_R_precision_top_{str(k)}",
default=torch.tensor(0.0),
dist_reduce_fx="sum",
)
self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}")
self.metrics.extend(self.Matching_metrics)
# Fid
self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.metrics.append("FID")
# Diversity
self.add_state("Diversity",
default=torch.tensor(0.0),
dist_reduce_fx="sum")
self.add_state("gt_Diversity",
default=torch.tensor(0.0),
dist_reduce_fx="sum")
self.metrics.extend(["Diversity", "gt_Diversity"])
# Chached batches
self.add_state("text_embeddings", default=[], dist_reduce_fx=None)
self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None)
self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None)
# T2M Evaluator
self._get_t2m_evaluator(cfg)
def _get_t2m_evaluator(self, cfg):
"""
load T2M text encoder and motion encoder for evaluating
"""
# init module
self.t2m_textencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_textencoder)
self.t2m_moveencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_moveencoder)
self.t2m_motionencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_motionencoder)
# load pretrianed
if self.dataname == "kit":
dataname = "kit"
else:
dataname = "t2m"
t2m_checkpoint = torch.load(os.path.join(
cfg.METRIC.TM2T.t2m_path, dataname, "text_mot_match/model/finest.tar"),
map_location="cpu")
self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
self.t2m_moveencoder.load_state_dict(
t2m_checkpoint["movement_encoder"])
self.t2m_motionencoder.load_state_dict(
t2m_checkpoint["motion_encoder"])
# freeze params
self.t2m_textencoder.eval()
self.t2m_moveencoder.eval()
self.t2m_motionencoder.eval()
for p in self.t2m_textencoder.parameters():
p.requires_grad = False
for p in self.t2m_moveencoder.parameters():
p.requires_grad = False
for p in self.t2m_motionencoder.parameters():
p.requires_grad = False
@torch.no_grad()
def compute(self, sanity_flag):
count = self.count.item()
count_seq = self.count_seq.item()
# Init metrics dict
metrics = {metric: getattr(self, metric) for metric in self.metrics}
# Jump in sanity check stage
if sanity_flag:
return metrics
# Cat cached batches and shuffle
shuffle_idx = torch.randperm(count_seq)
all_genmotions = torch.cat(self.recmotion_embeddings,
axis=0).cpu()[shuffle_idx, :]
all_gtmotions = torch.cat(self.gtmotion_embeddings,
axis=0).cpu()[shuffle_idx, :]
# Compute text related metrics
if self.text:
all_texts = torch.cat(self.text_embeddings,
axis=0).cpu()[shuffle_idx, :]
# Compute r-precision
assert count_seq > self.R_size
top_k_mat = torch.zeros((self.top_k, ))
for i in range(count_seq // self.R_size):
# [bs=32, 1*256]
group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size]
# [bs=32, 1*256]
group_motions = all_genmotions[i * self.R_size:(i + 1) *
self.R_size]
# dist_mat = pairwise_euclidean_distance(group_texts, group_motions)
# [bs=32, 32]
dist_mat = euclidean_distance_matrix(
group_texts, group_motions).nan_to_num()
# print(dist_mat[:5])
self.Matching_score += dist_mat.trace()
argsmax = torch.argsort(dist_mat, dim=1)
top_k_mat += calculate_top_k(argsmax,
top_k=self.top_k).sum(axis=0)
R_count = count_seq // self.R_size * self.R_size
metrics["Matching_score"] = self.Matching_score / R_count
for k in range(self.top_k):
metrics[f"R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count
# Compute r-precision with gt
assert count_seq > self.R_size
top_k_mat = torch.zeros((self.top_k, ))
for i in range(count_seq // self.R_size):
# [bs=32, 1*256]
group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size]
# [bs=32, 1*256]
group_motions = all_gtmotions[i * self.R_size:(i + 1) *
self.R_size]
# [bs=32, 32]
dist_mat = euclidean_distance_matrix(
group_texts, group_motions).nan_to_num()
# match score
self.gt_Matching_score += dist_mat.trace()
argsmax = torch.argsort(dist_mat, dim=1)
top_k_mat += calculate_top_k(argsmax,
top_k=self.top_k).sum(axis=0)
metrics["gt_Matching_score"] = self.gt_Matching_score / R_count
for k in range(self.top_k):
metrics[f"gt_R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count
# tensor -> numpy for FID
all_genmotions = all_genmotions.numpy()
all_gtmotions = all_gtmotions.numpy()
# Compute fid
mu, cov = calculate_activation_statistics_np(all_genmotions)
gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions)
metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov)
# Compute diversity
assert count_seq > self.diversity_times
metrics["Diversity"] = calculate_diversity_np(all_genmotions,
self.diversity_times)
metrics["gt_Diversity"] = calculate_diversity_np(
all_gtmotions, self.diversity_times)
# Reset
self.reset()
return {**metrics}
@torch.no_grad()
def update(self,
feats_ref: Tensor,
feats_rst: Tensor,
lengths_ref: List[int],
lengths_rst: List[int],
word_embs: Tensor = None,
pos_ohot: Tensor = None,
text_lengths: Tensor = None):
self.count += sum(lengths_ref)
self.count_seq += len(lengths_ref)
# T2m motion encoder
align_idx = np.argsort(lengths_ref)[::-1].copy()
feats_ref = feats_ref[align_idx]
lengths_ref = np.array(lengths_ref)[align_idx]
gtmotion_embeddings = self.get_motion_embeddings(
feats_ref, lengths_ref)
cache = [0] * len(lengths_ref)
for i in range(len(lengths_ref)):
cache[align_idx[i]] = gtmotion_embeddings[i:i + 1]
self.gtmotion_embeddings.extend(cache)
align_idx = np.argsort(lengths_rst)[::-1].copy()
feats_rst = feats_rst[align_idx]
lengths_rst = np.array(lengths_rst)[align_idx]
recmotion_embeddings = self.get_motion_embeddings(
feats_rst, lengths_rst)
cache = [0] * len(lengths_rst)
for i in range(len(lengths_rst)):
cache[align_idx[i]] = recmotion_embeddings[i:i + 1]
self.recmotion_embeddings.extend(cache)
# T2m text encoder
if self.text:
text_emb = self.t2m_textencoder(word_embs, pos_ohot, text_lengths)
text_embeddings = torch.flatten(text_emb, start_dim=1).detach()
self.text_embeddings.append(text_embeddings)
def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
m_lens = torch.tensor(lengths)
m_lens = torch.div(m_lens,
self.cfg.DATASET.HUMANML3D.UNIT_LEN,
rounding_mode="floor")
m_lens = m_lens // self.cfg.DATASET.HUMANML3D.UNIT_LEN
mov = self.t2m_moveencoder(feats[..., :-4]).detach()
emb = self.t2m_motionencoder(mov, m_lens)
# [bs, nlatent*ndim] <= [bs, nlatent, ndim]
return torch.flatten(emb, start_dim=1).detach()