motionfix-demo / gen_utils.py
atnikos's picture
first tmr retrieval efffort
38c4910
import torch
import numpy as np
import logging
import os
logger = logging.getLogger(__name__)
def cast_dict_to_tensors(d, device="cpu"):
    """Recursively convert the array leaves of ``d`` to tensors on ``device``.

    * dicts are rebuilt, converting every value in turn;
    * NumPy arrays become float tensors (``.float()``) on ``device``;
    * torch tensors are moved to ``device`` with their dtype untouched;
    * any other object is returned unchanged.
    """
    if isinstance(d, dict):
        converted = {}
        for key, value in d.items():
            converted[key] = cast_dict_to_tensors(value, device)
        return converted

    if isinstance(d, np.ndarray):
        as_tensor = torch.from_numpy(d)
        return as_tensor.float().to(device)

    if isinstance(d, torch.Tensor):
        return d.to(device)

    # Not a container or array type we know about: pass it through.
    return d
def rgba(c: str):
    """Return the RGBA value matplotlib assigns to the colour spec ``c``."""
    # Imported inside the function, so matplotlib is only loaded when called.
    import matplotlib.colors as color_utils

    return color_utils.to_rgba(c)
def rgb(c: str):
    """Return the RGB value matplotlib assigns to the colour spec ``c``."""
    # Imported inside the function, so matplotlib is only loaded when called.
    import matplotlib.colors as color_utils

    return color_utils.to_rgb(c)
# split the lightning checkpoint into
# separate state_dict modules for faster loading
def extract_ckpt(run_dir, ckpt_name="last"):
    """Split a checkpoint's state dict into one ``.pt`` file per sub-module.

    Reads ``<run_dir>/logs/checkpoints/<ckpt_name>.ckpt``, groups the entries
    of its ``"state_dict"`` by the first dot-separated component of each key
    (e.g. ``motion_encoder``, ``text_encoder``, ``motion_decoder``) and saves
    every group to ``<run_dir>/<ckpt_name>_weights/<module>.pt``. Inside each
    saved dict the module prefix is stripped from the keys and all tensors
    live on the CPU.

    Args:
        run_dir: Run directory that contains ``logs/checkpoints``.
        ckpt_name: Checkpoint file name without the ``.ckpt`` extension.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
    extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
    os.makedirs(extracted_path, exist_ok=True)

    # Load directly onto the CPU: every tensor is moved there anyway, and this
    # keeps checkpoints written on a GPU loadable on CPU-only machines.
    # NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
    ckpt_dict = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt_dict["state_dict"]

    # Group the parameters in a single pass instead of rescanning the whole
    # state dict once per module.
    per_module = {}
    for full_key, tensor in state_dict.items():
        # "a.b.c" -> ("a", "b.c"); a key without a dot maps to sub-key "".
        module_name, _, sub_key = full_key.partition(".")
        per_module.setdefault(module_name, {})[sub_key] = tensor.cpu()

    for module_name, sub_state_dict in per_module.items():
        torch.save(sub_state_dict, os.path.join(extracted_path, f"{module_name}.pt"))
import os
import json
from omegaconf import DictConfig, OmegaConf
def save_config(cfg: DictConfig) -> str:
    """Write ``cfg`` as JSON to ``config.json`` inside ``cfg.run_dir``.

    The config is first converted to plain Python containers with
    ``resolve=True`` and then serialised with a 4-space indent.

    Returns:
        The path of the file that was written.
    """
    out_path = os.path.join(cfg.run_dir, "config.json")
    plain_config = OmegaConf.to_container(cfg, resolve=True)
    with open(out_path, "w") as handle:
        handle.write(json.dumps(plain_config, indent=4))
    return out_path
def write_json(data, p):
    """Serialise ``data`` as JSON with a 2-space indent into the file ``p``."""
    from json import dump

    with open(p, 'w') as out_file:
        dump(data, out_file, indent=2)
def read_json(p):
    """Parse the JSON file at ``p`` and return its contents."""
    from json import load

    with open(p, 'r') as in_file:
        return load(in_file)
def read_config(run_dir: str, return_json=False) -> DictConfig:
    """Load ``config.json`` from ``run_dir``.

    When ``return_json`` is truthy the parsed JSON object is returned as-is.
    Otherwise it is wrapped with ``OmegaConf.create`` and its ``run_dir``
    entry is overwritten with the ``run_dir`` argument before returning.
    """
    config_file = os.path.join(run_dir, "config.json")
    with open(config_file, "r") as handle:
        raw_config = json.load(handle)

    if not return_json:
        cfg = OmegaConf.create(raw_config)
        # Point the config at the directory it was actually read from.
        cfg.run_dir = run_dir
        return cfg
    return raw_config