Spaces:
Running
Running
File size: 1,320 Bytes
83f52e6 5e4fa5e 83f52e6 5e4fa5e 83f52e6 5e4fa5e 83f52e6 5e4fa5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os
import orjson
import torch
import numpy as np
from model import TMR_textencoder
# Directory that holds the per-split "<split>.keyids" and
# "<split>_motion_embs_unit.npy" files read by the loaders in this module.
EMBS = "data/unit_motion_embs"
def load_json(path):
    """Read the file at ``path`` in binary mode and return its orjson-decoded content."""
    with open(path, "rb") as handle:
        raw = handle.read()
    return orjson.loads(raw)
def load_keyids(split):
    """Return the key ids of ``split`` as a NumPy array of strings.

    The ids are read from ``<EMBS>/<split>.keyids``, one entry per line,
    with surrounding whitespace stripped from each line.
    """
    keyids_file = os.path.join(EMBS, f"{split}.keyids")
    with open(keyids_file) as handle:
        stripped = [line.strip() for line in handle]
    return np.array(stripped)
def load_keyids_splits(splits):
    """Return a dict mapping each entry of ``splits`` to its ``load_keyids`` result."""
    keyids_by_split = {}
    for name in splits:
        keyids_by_split[name] = load_keyids(name)
    return keyids_by_split
def load_unit_motion_embs(split, device):
    """Load the motion embeddings of ``split`` as a torch tensor on ``device``.

    The array is read with ``np.load`` from
    ``<EMBS>/<split>_motion_embs_unit.npy``, wrapped with
    ``torch.from_numpy`` and then moved to ``device``.
    """
    npy_file = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
    array = np.load(npy_file)
    return torch.from_numpy(array).to(device)
def load_unit_motion_embs_splits(splits, device):
    """Return a dict mapping each entry of ``splits`` to its embeddings tensor.

    Every tensor is produced by ``load_unit_motion_embs(split, device)``.
    """
    embs_by_split = {}
    for name in splits:
        embs_by_split[name] = load_unit_motion_embs(name, device)
    return embs_by_split
def load_model(device):
    """Build the text encoder, load its checkpoint and return it on ``device``.

    A ``TMR_textencoder`` is constructed with the fixed hyper-parameters
    below, the state dict stored in ``data/textencoder.pt`` is loaded into
    it, and the model is switched to evaluation mode before being moved to
    ``device``.

    Args:
        device: Used as ``map_location`` for ``torch.load`` and as the
            target of the final ``model.to`` call.

    Returns:
        The model in eval mode, moved to ``device``.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)
    # NOTE(review): torch.load unpickles the checkpoint, so it must only be
    # pointed at a trusted file; consider weights_only=True where the
    # installed PyTorch version supports it.
    state_dict = torch.load("data/textencoder.pt", map_location=device)
    # load values for the transformer only
    # strict=False: keys that do not match between the checkpoint and the
    # model are tolerated instead of raising.
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model.to(device)
|