# move-it/eval_t2m.py
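# Evaluate a text-to-motion transformer on the 272-dim HumanML3D representation:
# motions are decoded through a causal TAE and scored with FID, diversity,
# R-precision (top-1/2/3) and matching score via a retrieval-style evaluator.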
import os
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import json
import sys
from models.llama_model import LLaMAHF, LLaMAHFConfig
import options.option_transformer as option_trans
import utils.utils_model as utils_model
import utils.eval_trans as eval_trans
from humanml3d_272 import dataset_eval_tae
import models.tae as tae
import warnings
warnings.filterwarnings('ignore')
os.environ["TOKENIZERS_PARALLELISM"] = "false"
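# Work from inside the Evaluator_272 directory so its local modules are importable.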
os.chdir('Evaluator_272')
sys.path.insert(0, os.getcwd())
comp_device = torch.device('cuda')
##### ---- Exp dirs ---- #####
args = option_trans.get_args_parser()
torch.manual_seed(args.seed)
args.out_dir = os.path.join(args.out_dir, args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)
##### ---- Logger ---- #####
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
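# Validation dataloader for the TAE evaluation dataset (batch size 32).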
val_loader = dataset_eval_tae.DATALoader(args.dataname, True, 32)
##### ---- Network ---- #####
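# Frozen SentenceT5-XXL encoder used to embed the text prompts.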
from sentence_transformers import SentenceTransformer
t5_model = SentenceTransformer('../sentencet5-xxl/')
t5_model.eval()
for p in t5_model.parameters():
    p.requires_grad = False
tokenize_model = t5_model
# Causal TAE
clip_range = [-30, 20]
net = tae.Causal_HumanTAE(
    hidden_size=args.hidden_size,
    down_t=args.down_t,
    stride_t=args.stride_t,
    depth=args.depth,
    dilation_growth_rate=args.dilation_growth_rate,
    activation='relu',
    latent_dim=args.latent_dim,
    clip_range=clip_range,
)
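# Text-to-motion transformer: LLaMA-style backbone with a diffusion head on top.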
config = LLaMAHFConfig.from_name('Normal_size')
config.block_size = 78
trans_encoder = LLaMAHF(config, args.num_diffusion_head_layers, args.latent_dim, comp_device)
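# Load the pretrained causal TAE weights.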
print('loading checkpoint from {}'.format(args.resume_pth))
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.to(comp_device)
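# Load the transformer checkpoint, stripping any DataParallel 'module.' prefix.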
if args.resume_trans is not None:
    print('loading transformer checkpoint from {}'.format(args.resume_trans))
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    new_ckpt_trans = {}
    for key in ckpt['trans'].keys():
        if key.split('.')[0] == 'module':
            new_key = '.'.join(key.split('.')[1:])
        else:
            new_key = key
        new_ckpt_trans[new_key] = ckpt['trans'][key]
    trans_encoder.load_state_dict(new_ckpt_trans, strict=True)
trans_encoder.eval()
trans_encoder.to(comp_device)
# ---- Load the retrieval evaluator (text & motion encoders) ----
from transformers import AutoTokenizer, AutoModel
from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder
from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder
from collections import OrderedDict
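# TEMOS-style evaluator from the MLD codebase: a DistilBERT text encoder and an
# ACTOR-style motion encoder whose embeddings drive the retrieval-based metrics.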
modelpath = 'distilbert-base-uncased'
textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4, latent_dim=256)
motionencoder = ActorAgnosticEncoder(nfeats=272, vae=True, num_layers=4, latent_dim=256, max_len=300)
ckpt_path = 'epoch=99.ckpt'
print(f'Loading evaluator checkpoint from {ckpt_path}')
ckpt = torch.load(ckpt_path, map_location='cpu')
# load textencoder
textencoder_ckpt = {}
for k, v in ckpt['state_dict'].items():
    if k.split(".")[0] == "textencoder":
        name = k.replace("textencoder.", "")
        textencoder_ckpt[name] = v
textencoder.load_state_dict(textencoder_ckpt, strict=True)
textencoder.eval()
textencoder.to(comp_device)
# load motionencoder
motionencoder_ckpt = {}
for k, v in ckpt['state_dict'].items():
    if k.split(".")[0] == "motionencoder":
        name = k.replace("motionencoder.", "")
        motionencoder_ckpt[name] = v
motionencoder.load_state_dict(motionencoder_ckpt, strict=True)
motionencoder.eval()
motionencoder.to(comp_device)
#--------------------------------
evaluator = [textencoder, motionencoder]
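# Metric accumulators; a single evaluation run appends one value to each list.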
fid = []
div = []
top1 = []
top2 = []
top3 = []
matching = []
mpjpe = []
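# Run a single evaluation pass over the validation set (the trailing 4.0 is
# presumably a guidance scale forwarded to generation).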
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, logger = eval_trans.evaluation_transformer_272_single(
    val_loader, net, trans_encoder, tokenize_model, logger, evaluator, 4.0)
fid.append(best_fid)
div.append(best_div)
top1.append(best_top1)
top2.append(best_top2)
top3.append(best_top3)
matching.append(best_matching)
logger.info('final result:')
logger.info(f'fid: {fid}')
logger.info(f'div: {div}')
logger.info(f'top1: {top1}')
logger.info(f'top2: {top2}')
logger.info(f'top3: {top3}')
logger.info(f'MM-dist (matching score) : {matching}')