File size: 2,016 Bytes
6837c8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from gen_utils import extract_ckpt
import hydra
import os
from hydra.utils import instantiate
from gen_utils import read_config
from model_utils import collate_x_dict
import torch
from tmr_model import TMR_textencoder


def load_model_from_cfg(cfg, ckpt_name="last", device="cuda", eval_mode=True):
    """Instantiate a model from a hydra config and load its saved weights.

    Parameters
    ----------
    cfg : config object
        Hydra/omegaconf config with ``run_dir`` and ``model`` attributes.
    ckpt_name : str
        Checkpoint name; per-module weights are read from
        ``{run_dir}/{ckpt_name}_weights``.
    device : str
        Device the model is moved to; also used as ``map_location`` when
        loading the state dicts.
    eval_mode : bool
        If True, the model is switched to eval mode before returning.

    Returns
    -------
    The instantiated model with checkpoint weights loaded, on ``device``.
    """
    import src.prepare  # noqa  -- imported for its side effects

    run_dir = cfg.run_dir
    model = hydra.utils.instantiate(cfg.model)

    # Each sub-module (motion_encoder / text_encoder / text_decoder) is stored
    # as a separate <module_name>.pt file inside the weights directory.
    pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")

    # Extract per-module weight files from the raw checkpoint if not done yet.
    if not os.path.exists(pt_path):
        extract_ckpt(run_dir, ckpt_name)

    for fname in os.listdir(pt_path):
        module_name, ext = os.path.splitext(fname)
        if ext != ".pt":
            continue

        module = getattr(model, module_name, None)
        if module is None:
            # Weight file with no matching sub-module on the model: skip it.
            continue

        module_path = os.path.join(pt_path, fname)
        # map_location avoids device-mismatch errors when the checkpoint was
        # saved on a different device than the one we are loading onto.
        state_dict = torch.load(module_path, map_location=device)
        module.load_state_dict(state_dict)

    model = model.to(device)
    if eval_mode:
        model = model.eval()
    return model

# def get_tmr_model(run_dir):
#     from gen_utils import read_config
#     cfg = read_config(run_dir+'/tmr')
#     import ipdb;ipdb.set_trace()
#     text_model = instantiate(cfg.data.text_to_token_emb, device='cuda')
#     model = load_model_from_cfg(cfg, 'last', eval_mode=True, device='cuda')
#     return text_model, model

def get_tmr_model(run_dir, device="cuda"):
    """Build a TMR text encoder and load its pretrained weights.

    Parameters
    ----------
    run_dir : str
        Run directory; weights are read from
        ``{run_dir}/tmr/last_weights/text_encoder.pt``.
    device : str
        Device to load the weights onto and move the model to.
        Defaults to ``"cuda"``, preserving the previous hard-coded behavior.

    Returns
    -------
    TMR_textencoder in eval mode, on ``device``.
    """
    # Architecture hyper-parameters matching the pretrained checkpoint.
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)
    state_dict = torch.load(
        f"{run_dir}/tmr/last_weights/text_encoder.pt", map_location=device
    )
    # strict=False: the checkpoint may contain keys beyond the text encoder
    # (e.g. motion-side weights) -- load the matching transformer weights only.
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model.to(device)