# motionfix-demo / retrieval_loader.py
# atnikos's picture
# fix retrieval placeholders
# 6837c8b
from gen_utils import extract_ckpt
import hydra
import os
from hydra.utils import instantiate
from gen_utils import read_config
from model_utils import collate_x_dict
import torch
from tmr_model import TMR_textencoder
def load_model_from_cfg(cfg, ckpt_name="last", device="cuda", eval_mode=True):
    """Instantiate ``cfg.model`` and load its per-module weight files.

    Weights are read from ``<cfg.run_dir>/<ckpt_name>_weights``. Every
    ``<name>.pt`` file in that directory is loaded into the attribute
    ``model.<name>``; files with another extension, or whose stem does not
    match an attribute of the model, are skipped.

    Args:
        cfg: Config object exposing ``run_dir`` and ``model`` (the latter is
            passed to ``hydra.utils.instantiate``).
        ckpt_name: Checkpoint name used to build the weights directory name.
        device: Device the weights are mapped to and the model is moved to.
        eval_mode: When true, the model is switched to eval mode before it
            is returned.

    Returns:
        The instantiated model, with weights loaded, on ``device``.
    """
    # Imported only for its side effects.
    # NOTE(review): presumably performs project-wide setup -- confirm.
    import src.prepare  # noqa

    run_dir = cfg.run_dir
    model = hydra.utils.instantiate(cfg.model)

    # Loading modules one by one
    # motion_encoder / text_encoder / text_decoder
    pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
    if not os.path.exists(pt_path):
        # extract_ckpt is expected to populate pt_path from the checkpoint.
        extract_ckpt(run_dir, ckpt_name)

    for fname in os.listdir(pt_path):
        module_name, ext = os.path.splitext(fname)
        if ext != ".pt":
            continue

        module = getattr(model, module_name, None)
        if module is None:
            continue

        module_path = os.path.join(pt_path, fname)
        # map_location makes the load honour ``device``: without it, tensors
        # saved on a GPU cannot be deserialised on a CPU-only host.
        state_dict = torch.load(module_path, map_location=device)
        module.load_state_dict(state_dict)

    model = model.to(device)
    if eval_mode:
        model = model.eval()
    return model
# def get_tmr_model(run_dir):
# from gen_utils import read_config
# cfg = read_config(run_dir+'/tmr')
# import ipdb;ipdb.set_trace()
# text_model = instantiate(cfg.data.text_to_token_emb, device='cuda')
# model = load_model_from_cfg(cfg, 'last', eval_mode=True, device='cuda')
# return text_model, model
def get_tmr_model(run_dir, device="cuda"):
    """Build the TMR text encoder and load its weights from ``run_dir``.

    The weights are read from ``<run_dir>/tmr/last_weights/text_encoder.pt``.

    Args:
        run_dir: Directory that contains the ``tmr/last_weights`` folder.
        device: Device the weights are mapped to and the model is moved to.
            Defaults to ``"cuda"``, the value that was previously hard-coded.

    Returns:
        The ``TMR_textencoder`` instance in eval mode, on ``device``.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)

    weights_path = os.path.join(
        run_dir, "tmr", "last_weights", "text_encoder.pt"
    )
    state_dict = torch.load(weights_path, map_location=device)

    # load values for the transformer only
    # strict=False: keys missing from (or extra in) the file are ignored
    # instead of raising.
    model.load_state_dict(state_dict, strict=False)

    model = model.eval()
    return model.to(device)