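# Utilities for loading TMR (text-to-motion retrieval) checkpoints:
# - load_model_from_cfg: rebuilds a model from its Hydra config and loads each
#   sub-module's weights (motion_encoder / text_encoder / text_decoder).
# - get_tmr_model: loads only the TMR text encoder from text_encoder.pt.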
import os

import hydra
import torch
from hydra.utils import instantiate

from gen_utils import extract_ckpt, read_config
from model_utils import collate_x_dict
from tmr_model import TMR_textencoder

def load_model_from_cfg(cfg, ckpt_name="last", device="cuda", eval_mode=True):
    import src.prepare  # noqa - imported only for its side effects

    run_dir = cfg.run_dir
    model = hydra.utils.instantiate(cfg.model)

    # Load the sub-modules one by one:
    # motion_encoder / text_encoder / text_decoder
    pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
    if not os.path.exists(pt_path):
        # Split the checkpoint into per-module weight files if not done already
        extract_ckpt(run_dir, ckpt_name)

    for fname in os.listdir(pt_path):
        module_name, ext = os.path.splitext(fname)
        if ext != ".pt":
            continue
        # Only load weights for sub-modules that exist on this model
        module = getattr(model, module_name, None)
        if module is None:
            continue
        module_path = os.path.join(pt_path, fname)
        state_dict = torch.load(module_path, map_location=device)
        module.load_state_dict(state_dict)

    model = model.to(device)
    if eval_mode:
        model = model.eval()
    return model
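
# Usage sketch (hypothetical run directory; assumes read_config returns the
# training config saved there, as in the commented-out variant below):
#   cfg = read_config("outputs/my_run")
#   model = load_model_from_cfg(cfg, ckpt_name="last", device="cuda")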

# Alternative: build both the text model and the full model from the saved config.
# def get_tmr_model(run_dir):
#     cfg = read_config(run_dir + "/tmr")
#     text_model = instantiate(cfg.data.text_to_token_emb, device="cuda")
#     model = load_model_from_cfg(cfg, "last", eval_mode=True, device="cuda")
#     return text_model, model

def get_tmr_model(run_dir):
    # Hyperparameters of the trained TMR text encoder
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)
    state_dict = torch.load(
        f"{run_dir}/tmr/last_weights/text_encoder.pt", map_location="cuda"
    )
    # Load values for the transformer only; strict=False ignores missing/unexpected keys
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model.to("cuda")
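
# Usage sketch (hypothetical path; expects <run_dir>/tmr/last_weights/text_encoder.pt):
#   text_encoder = get_tmr_model("outputs/my_run")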