# TMR / load.py
# Mathis Petrovich
# device
# 5e4fa5e
# raw history blame
# No virus
# 1.32 kB
import os
import orjson
import torch
import numpy as np
from model import TMR_textencoder
# Directory containing the per-split "<split>.keyids" lists and
# "<split>_motion_embs_unit.npy" arrays read by the loaders in this file.
EMBS = "data/unit_motion_embs"
def load_json(path):
    """Read the file at ``path`` as raw bytes and return it parsed by orjson.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever ``orjson.loads`` produces for the file's content.
    """
    with open(path, "rb") as stream:
        raw = stream.read()
    return orjson.loads(raw)
def load_keyids(split):
    """Load the key ids listed in ``<EMBS>/<split>.keyids``.

    Args:
        split: Split name used to build the file name.

    Returns:
        A NumPy array with one entry per line of the file, each entry
        stripped of surrounding whitespace.
    """
    keyids_file = os.path.join(EMBS, f"{split}.keyids")
    with open(keyids_file) as handle:
        # Iterating the handle yields the same lines as readlines().
        stripped = [line.strip() for line in handle]
    return np.array(stripped)
def load_keyids_splits(splits):
    """Load the key ids of several splits.

    Args:
        splits: Iterable of split names.

    Returns:
        A dict mapping each split name to the result of ``load_keyids``
        for that split.
    """
    keyids_by_split = {}
    for name in splits:
        keyids_by_split[name] = load_keyids(name)
    return keyids_by_split
def load_unit_motion_embs(split, device):
    """Load ``<EMBS>/<split>_motion_embs_unit.npy`` as a tensor on ``device``.

    Args:
        split: Split name used to build the file name.
        device: Target passed to ``Tensor.to``.

    Returns:
        The array stored in the ``.npy`` file, converted with
        ``torch.from_numpy`` and moved to ``device``.
    """
    # NOTE(review): the "_unit" suffix presumably means the embeddings are
    # already unit-normalized; nothing here checks that -- confirm upstream.
    filename = f"{split}_motion_embs_unit.npy"
    array = np.load(os.path.join(EMBS, filename))
    return torch.from_numpy(array).to(device)
def load_unit_motion_embs_splits(splits, device):
    """Load the motion-embedding tensors of several splits.

    Args:
        splits: Iterable of split names.
        device: Target forwarded to ``load_unit_motion_embs``.

    Returns:
        A dict mapping each split name to the tensor returned by
        ``load_unit_motion_embs`` for that split.
    """
    embs_by_split = {}
    for name in splits:
        embs_by_split[name] = load_unit_motion_embs(name, device)
    return embs_by_split
def load_model(device):
    """Build the TMR text encoder, load its checkpoint and return it for inference.

    The encoder is constructed with a fixed set of hyper-parameters, its
    weights are read from ``data/textencoder.pt``, and the module is switched
    to evaluation mode before being moved to ``device``.

    Args:
        device: Target used both as ``map_location`` when reading the
            checkpoint and as the destination of ``Module.to``.

    Returns:
        The ``TMR_textencoder`` instance in eval mode, located on ``device``.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be placed at this path. On torch versions that support it,
    # consider passing weights_only=True.
    state_dict = torch.load("data/textencoder.pt", map_location=device)

    # Load values for the transformer only: strict=False lets load_state_dict
    # accept a checkpoint whose keys do not exactly match the model's, instead
    # of raising on missing or unexpected keys.
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model.to(device)