|
import os
from typing import List

import numpy as np
import spacy
import torch
from bert_score import score as score_bert
from torch import Tensor
from torchmetrics import Metric

from mGPT.config import instantiate_from_config

# Wildcard import: presumably the source of euclidean_distance_matrix and
# calculate_top_k used by M2TMetrics (verify against .utils).
from .utils import *
|
|
|
class M2TMetrics(Metric):
    """Evaluation metric for motion-to-text (captioning) models.

    ``update`` encodes the reference motions, the ground-truth captions and
    the predicted captions with the frozen T2M evaluator networks and caches
    the resulting embeddings together with the raw caption strings.
    ``compute`` turns that cache into:

    * ``Matching_score`` / ``R_precision_top_k`` for the predicted captions
      and the ``gt_``-prefixed counterparts for the ground-truth captions,
    * ``Bleu_1`` ... ``Bleu_{bleu_k}``, ``ROUGE_L`` and ``CIDEr``,
    * ``Bert_F1``.
    """

    def __init__(self,
                 cfg,
                 w_vectorizer,
                 dataname='humanml3d',
                 top_k=3,
                 bleu_k=4,
                 R_size=32,
                 max_text_len=40,
                 diversity_times=300,
                 dist_sync_on_step=True,
                 unit_length=4,
                 **kwargs):
        """Register metric states and load the evaluator networks.

        Args:
            cfg: Project configuration. This class reads
                ``cfg.model.params.task``, ``cfg.METRIC.TM2T.*`` and
                ``cfg.DATASET.HUMANML3D.UNIT_LEN``.
            w_vectorizer: Object indexed with ``"word/POS"`` tokens that
                returns a ``(word_embedding, pos_one_hot)`` pair.
            dataname: ``"kit"`` selects the ``kit`` evaluator checkpoint; any
                other value selects the ``t2m`` checkpoint.
            top_k: Number of R-precision ranks to report (top-1 ... top-k).
            bleu_k: Highest BLEU n-gram order to report (1 ... bleu_k).
            R_size: Number of samples per retrieval group.
            max_text_len: Maximum number of caption tokens kept (before the
                ``sos``/``eos`` markers are added).
            diversity_times: Stored on the instance; not used by this class.
            dist_sync_on_step: Forwarded to ``Metric.__init__``.
            unit_length: Divisor applied to motion lengths in ``update``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.cfg = cfg
        self.dataname = dataname
        self.w_vectorizer = w_vectorizer
        self.name = "matching, fid, and diversity scores"

        self.max_text_len = max_text_len
        self.top_k = top_k
        self.bleu_k = bleu_k
        self.R_size = R_size
        self.diversity_times = diversity_times
        self.unit_length = unit_length

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        # Names of every scalar state returned by ``compute``.
        self.metrics = []

        # --- text/motion matching ------------------------------------------
        self.add_state("Matching_score",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.add_state("gt_Matching_score",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.Matching_metrics = ["Matching_score", "gt_Matching_score"]
        for prefix in ("R_precision_top_", "gt_R_precision_top_"):
            for k in range(1, top_k + 1):
                state_name = f"{prefix}{k}"
                self.add_state(
                    state_name,
                    default=torch.tensor(0.0),
                    dist_reduce_fx="sum",
                )
                self.Matching_metrics.append(state_name)

        self.metrics.extend(self.Matching_metrics)

        # --- NLG scores ----------------------------------------------------
        # One state per BLEU order that ``compute`` fills in, i.e. bleu_k
        # (not top_k) of them, so the sanity-check result and the real
        # result expose the same keys.
        for k in range(1, bleu_k + 1):
            self.add_state(
                f"Bleu_{k}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            self.metrics.append(f"Bleu_{k}")

        self.add_state("ROUGE_L",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.metrics.append("ROUGE_L")

        self.add_state("CIDEr",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.metrics.append("CIDEr")

        # --- cached batches ------------------------------------------------
        # Raw captions are plain Python lists (not metric states), so they
        # are cleared by hand at the end of ``compute``.
        self.pred_texts = []
        self.gt_texts = []
        self.add_state("predtext_embeddings", default=[])
        self.add_state("gttext_embeddings", default=[])
        self.add_state("gtmotion_embeddings", default=[])

        # Frozen text / movement / motion encoders used for the embeddings.
        self._get_t2m_evaluator(cfg)

        # spaCy pipeline used to tokenise and POS-tag predicted captions.
        self.nlp = spacy.load('en_core_web_sm')

        if self.cfg.model.params.task == 'm2t':
            # Imported lazily so the dependency is only needed for m2t.
            from nlgmetricverse import NLGMetricverse, load_metric

            # Load every BLEU order that ``compute`` reads back
            # (bleu_1 ... bleu_{bleu_k}); loading only a subset makes the
            # lookup of the missing orders fail.
            metrics = [
                load_metric("bleu",
                            resulting_name=f"bleu_{k}",
                            compute_kwargs={"max_order": k})
                for k in range(1, bleu_k + 1)
            ]
            metrics.append(load_metric("rouge"))
            metrics.append(load_metric("cider"))
            self.nlg_evaluator = NLGMetricverse(metrics)

    def _get_t2m_evaluator(self, cfg):
        """Build the T2M text, movement and motion encoders and freeze them.

        The three networks are instantiated from ``cfg.METRIC.TM2T``, their
        weights are read from ``finest.tar`` under ``cfg.METRIC.TM2T.t2m_path``
        and they are put in eval mode with gradients disabled.
        """
        self.t2m_textencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_motionencoder)

        # Only the KIT dataset has its own checkpoint directory; everything
        # else falls back to the "t2m" one.
        dataname = "kit" if self.dataname == "kit" else "t2m"

        checkpoint_path = os.path.join(cfg.METRIC.TM2T.t2m_path, dataname,
                                       "text_mot_match/model/finest.tar")
        t2m_checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # The evaluator networks are never trained by this metric.
        for encoder in (self.t2m_textencoder, self.t2m_moveencoder,
                        self.t2m_motionencoder):
            encoder.eval()
            for p in encoder.parameters():
                p.requires_grad = False

    def _process_text(self, sentence):
        """Tokenise ``sentence`` and return ``(word_list, pos_list)``.

        Hyphens are removed before tagging, non-alphabetic tokens are
        dropped, and nouns/verbs (except the word ``left``) are replaced by
        their lemma. Both lists have one entry per kept token.
        """
        sentence = sentence.replace('-', '')
        doc = self.nlp(sentence)
        word_list = []
        pos_list = []
        for token in doc:
            word = token.text
            if not word.isalpha():
                continue
            if token.pos_ in ('NOUN', 'VERB') and word != 'left':
                word_list.append(token.lemma_)
            else:
                word_list.append(word)
            pos_list.append(token.pos_)
        return word_list, pos_list

    def _get_text_embeddings(self, texts):
        """Encode raw caption strings with the T2M text encoder.

        Args:
            texts: Iterable of caption strings.

        Returns:
            The text-encoder output with one row per caption, in the same
            order as ``texts``.
        """
        word_embs = []
        pos_ohot = []
        text_lengths = []
        for sentence in texts:
            word_list, pos_list = self._process_text(sentence.strip())
            t_tokens = [
                f'{word}/{pos}' for word, pos in zip(word_list, pos_list)
            ]

            # Crop to max_text_len, wrap in sos/eos, then pad with unk up to
            # max_text_len + 2 (no padding is added after a crop).
            tokens = (['sos/OTHER'] + t_tokens[:self.max_text_len] +
                      ['eos/OTHER'])
            sent_len = len(tokens)
            tokens = tokens + ['unk/OTHER'
                               ] * (self.max_text_len + 2 - sent_len)

            pos_one_hots = []
            word_embeddings = []
            for token in tokens:
                word_emb, pos_oh = self.w_vectorizer[token]
                pos_one_hots.append(torch.tensor(pos_oh).float()[None])
                word_embeddings.append(torch.tensor(word_emb).float()[None])
            text_lengths.append(sent_len)
            pos_ohot.append(torch.cat(pos_one_hots, dim=0)[None])
            word_embs.append(torch.cat(word_embeddings, dim=0)[None])

        # ``.to(self.Matching_score)`` moves to the metric's device and also
        # adopts that state's dtype.
        word_embs = torch.cat(word_embs, dim=0).to(self.Matching_score)
        pos_ohot = torch.cat(pos_ohot, dim=0).to(self.Matching_score)
        text_lengths = torch.tensor(text_lengths).to(self.Matching_score)

        # The encoder is fed captions sorted by decreasing length.
        align_idx = np.argsort(text_lengths.data.tolist())[::-1].copy()

        text_embeddings = self.t2m_textencoder(word_embs[align_idx],
                                               pos_ohot[align_idx],
                                               text_lengths[align_idx])

        # Undo the sort so row i corresponds to texts[i] again.
        original_text_embeddings = text_embeddings.clone()
        for sorted_pos, original_pos in enumerate(align_idx):
            original_text_embeddings[original_pos] = text_embeddings[
                sorted_pos]

        return original_text_embeddings

    @staticmethod
    def _score_value(entry):
        """Return the number held by one NLG evaluator result entry.

        An entry that is a dict is unwrapped through its ``'score'`` key (the
        way the CIDEr entry is read in ``compute``); anything else is
        returned unchanged.
        """
        if isinstance(entry, dict):
            return entry['score']
        return entry

    def _accumulate_matching(self, text_embs, motion_embs, num_groups,
                             score_state):
        """Accumulate matching distance and top-k hits over retrieval groups.

        Rows are processed in consecutive groups of ``R_size``. For every
        group the trace of the text-to-motion distance matrix is added in
        place to the state named ``score_state``.

        Returns:
            A tensor of length ``top_k`` with the summed top-k hit counts.
        """
        top_k_hits = torch.zeros((self.top_k, ))
        running_score = getattr(self, score_state)
        for group in range(num_groups):
            rows = slice(group * self.R_size, (group + 1) * self.R_size)
            dist_mat = euclidean_distance_matrix(
                text_embs[rows], motion_embs[rows]).nan_to_num()

            # Diagonal entries are the distances of the matching pairs.
            running_score += dist_mat.trace()
            ranking = torch.argsort(dist_mat, dim=1)
            top_k_hits += calculate_top_k(ranking,
                                          top_k=self.top_k).sum(axis=0)
        return top_k_hits

    @torch.no_grad()
    def compute(self, sanity_flag):
        """Compute all metrics from the cached batches and reset the cache.

        Args:
            sanity_flag: When truthy, return the current (untouched) state
                values without evaluating anything.

        Returns:
            Dict mapping metric name to a tensor.

        Raises:
            AssertionError: If fewer than ``R_size`` sequences were cached.
            AttributeError: If ``cfg.model.params.task`` was not ``'m2t'``,
                because ``nlg_evaluator`` is only created for that task.
        """
        count_seq = self.count_seq.item()

        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # Sanity-check runs only need the keys, not real values.
        if sanity_flag:
            return metrics

        # Shuffle once and apply the same permutation to all three caches so
        # that matching rows stay aligned.
        shuffle_idx = torch.randperm(count_seq)
        all_motions = torch.cat(self.gtmotion_embeddings,
                                axis=0).cpu()[shuffle_idx, :]
        all_gttexts = torch.cat(self.gttext_embeddings,
                                axis=0).cpu()[shuffle_idx, :]
        all_predtexts = torch.cat(self.predtext_embeddings,
                                  axis=0).cpu()[shuffle_idx, :]

        print("Computing metrics...")

        assert count_seq >= self.R_size, (
            "need at least R_size sequences to compute R-precision")
        num_groups = count_seq // self.R_size
        # Sequences beyond the last full group are not scored.
        R_count = num_groups * self.R_size

        # Predicted captions vs. reference motions.
        top_k_mat = self._accumulate_matching(all_predtexts, all_motions,
                                              num_groups, "Matching_score")
        metrics["Matching_score"] = self.Matching_score / R_count
        for k in range(self.top_k):
            metrics[f"R_precision_top_{k + 1}"] = top_k_mat[k] / R_count

        # Ground-truth captions vs. reference motions.
        top_k_mat = self._accumulate_matching(all_gttexts, all_motions,
                                              num_groups,
                                              "gt_Matching_score")
        metrics["gt_Matching_score"] = self.gt_Matching_score / R_count
        for k in range(self.top_k):
            metrics[f"gt_R_precision_top_{k + 1}"] = top_k_mat[k] / R_count

        # NLG metrics on the raw caption strings.
        scores = self.nlg_evaluator(predictions=self.pred_texts,
                                    references=self.gt_texts)
        for k in range(1, self.bleu_k + 1):
            metrics[f"Bleu_{k}"] = torch.tensor(
                self._score_value(scores[f'bleu_{k}']), device=self.device)

        metrics["ROUGE_L"] = torch.tensor(scores["rouge"]["rougeL"],
                                          device=self.device)
        metrics["CIDEr"] = torch.tensor(scores["cider"]['score'],
                                        device=self.device)

        # BERTScore; only the F1 component is reported.
        P, R, F1 = score_bert(self.pred_texts,
                              self.gt_texts,
                              lang='en',
                              rescale_with_baseline=True,
                              idf=True,
                              device=self.device,
                              verbose=False)

        metrics["Bert_F1"] = F1.mean()

        # ``reset`` clears the registered states; the caption lists are
        # plain attributes and have to be cleared separately.
        self.reset()
        self.gt_texts = []
        self.pred_texts = []

        return {**metrics}

    @torch.no_grad()
    def update(self,
               feats_ref: Tensor,
               pred_texts: List[str],
               gt_texts: List[str],
               lengths: List[int],
               word_embs: Tensor = None,
               pos_ohot: Tensor = None,
               text_lengths: Tensor = None):
        """Cache embeddings and captions for one batch.

        Args:
            feats_ref: Reference motion features, indexed by sample along
                the first dimension.
            pred_texts: Predicted captions, one per sample.
            gt_texts: Reference captions handed to the NLG metrics.
            lengths: Motion length of every sample.
            word_embs: Word embeddings of the ground-truth captions, passed
                to the text encoder.
            pos_ohot: POS one-hot features of the ground-truth captions.
            text_lengths: Token counts of the ground-truth captions.
        """
        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # Motion encoder: samples are ordered by decreasing motion length.
        m_lens = torch.tensor(lengths, device=feats_ref.device)
        align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
        feats_ref = feats_ref[align_idx]
        m_lens = m_lens[align_idx]
        m_lens = torch.div(m_lens,
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")
        # The last four feature channels are not fed to the movement encoder.
        ref_mov = self.t2m_moveencoder(feats_ref[..., :-4]).detach()
        # NOTE(review): the lengths are floor-divided a second time here
        # (by UNIT_LEN above, then by unit_length). Confirm that this
        # matches the temporal down-sampling of the movement encoder.
        m_lens = m_lens // self.unit_length
        ref_emb = self.t2m_motionencoder(ref_mov, m_lens)
        gtmotion_embeddings = torch.flatten(ref_emb, start_dim=1).detach()
        self.gtmotion_embeddings.append(gtmotion_embeddings)

        # Text encoder: re-order both caption embeddings with the motion
        # permutation so row i of every cache refers to the same sample.
        gttext_emb = self.t2m_textencoder(word_embs, pos_ohot,
                                          text_lengths)[align_idx]
        gttext_embeddings = torch.flatten(gttext_emb, start_dim=1).detach()
        predtext_emb = self._get_text_embeddings(pred_texts)[align_idx]
        predtext_embeddings = torch.flatten(predtext_emb, start_dim=1).detach()

        self.gttext_embeddings.append(gttext_embeddings)
        self.predtext_embeddings.append(predtext_embeddings)

        self.pred_texts.extend(pred_texts)
        self.gt_texts.extend(gt_texts)
|
|