import json
import logging
import os
from collections import defaultdict
from typing import Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf

logger = logging.getLogger(__name__)


def cast_dict_to_tensors(d, device="cpu"):
    """Recursively move array-like values of a (possibly nested) dict onto ``device``.

    Dicts are traversed recursively. ``np.ndarray`` values become float32
    tensors on ``device``; existing ``torch.Tensor`` values are moved to
    ``device`` keeping their dtype; every other value is returned unchanged.
    """
    if isinstance(d, dict):
        return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
    if isinstance(d, np.ndarray):
        # .float() casts to torch.float32 whatever the array's original dtype.
        return torch.from_numpy(d).float().to(device)
    if isinstance(d, torch.Tensor):
        return d.to(device)
    return d


def rgba(c: str):
    """Convert a matplotlib color spec (name, hex string, ...) to an RGBA float tuple."""
    # Imported lazily so matplotlib is only required when colors are used.
    from matplotlib import colors as mcolors

    return mcolors.to_rgba(c)


def rgb(c: str):
    """Convert a matplotlib color spec (name, hex string, ...) to an RGB float tuple."""
    # Imported lazily so matplotlib is only required when colors are used.
    from matplotlib import colors as mcolors

    return mcolors.to_rgb(c)


# Split the Lightning checkpoint into separate
# per-module state_dict files for faster loading.
def extract_ckpt(run_dir, ckpt_name="last"):
    """Split a Lightning checkpoint into one state_dict file per top-level module.

    Reads ``<run_dir>/logs/checkpoints/<ckpt_name>.ckpt`` and, for every
    top-level key prefix of its ``"state_dict"`` entry (the text before the
    first "."), writes ``<run_dir>/<ckpt_name>_weights/<prefix>.pt`` holding
    that module's entries with the prefix stripped and tensors on CPU.
    """
    ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
    extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
    os.makedirs(extracted_path, exist_ok=True)

    new_path_template = os.path.join(extracted_path, "{}.pt")

    # map_location="cpu" lets checkpoints saved from GPU runs be split on
    # CPU-only machines; every tensor is written out on CPU anyway.
    # NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
    ckpt_dict = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt_dict["state_dict"]

    # Group entries by top-level module in a single pass, e.g.
    # "text_encoder.layer.0.weight" -> ("text_encoder", "layer.0.weight").
    # Prefixes might be ['motion_encoder', 'text_encoder', 'motion_decoder'].
    sub_state_dicts = defaultdict(dict)
    for key, value in state_dict.items():
        module_name, _, param_name = key.partition(".")
        sub_state_dicts[module_name][param_name] = value.cpu()

    for module_name, sub_state_dict in sub_state_dicts.items():
        torch.save(sub_state_dict, new_path_template.format(module_name))


def save_config(cfg: DictConfig) -> str:
    """Write ``cfg`` (interpolations resolved) to ``<cfg.run_dir>/config.json``.

    Returns:
        The path of the written JSON file.
    """
    path = os.path.join(cfg.run_dir, "config.json")
    config = OmegaConf.to_container(cfg, resolve=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    return path


def write_json(data, p):
    """Serialize ``data`` as JSON (2-space indent) to the file at ``p``."""
    with open(p, "w", encoding="utf-8") as fp:
        json.dump(data, fp, indent=2)


def read_json(p):
    """Parse and return the JSON content of the file at ``p``."""
    with open(p, "r", encoding="utf-8") as fp:
        return json.load(fp)


def read_config(run_dir: str, return_json=False) -> Union[DictConfig, dict]:
    """Load ``<run_dir>/config.json``.

    Args:
        run_dir: Directory containing ``config.json``.
        return_json: If true, return the raw parsed JSON instead of a config.

    Returns:
        The parsed JSON when ``return_json`` is true; otherwise an OmegaConf
        config built from it whose ``run_dir`` is set to ``run_dir``.
    """
    config = read_json(os.path.join(run_dir, "config.json"))
    if return_json:
        return config
    cfg = OmegaConf.create(config)
    # Overrides any stored run_dir with the directory actually read from --
    # presumably so a moved run directory still resolves; confirm with callers.
    cfg.run_dir = run_dir
    return cfg