Spaces:
Running
Running
import torch | |
import numpy as np | |
import logging | |
import os | |
logger = logging.getLogger(__name__) | |
def cast_dict_to_tensors(d, device="cpu"):
    """Recursively convert the array-like leaves of ``d`` into tensors on ``device``.

    Args:
        d: A (possibly nested) dict, a NumPy array, a torch tensor, or any
            other object.
        device: Target device passed to ``.to(...)``.

    Returns:
        - for a dict: a new dict with the same keys and converted values;
        - for a NumPy array: a float tensor on ``device``;
        - for a tensor: the tensor moved to ``device`` (dtype untouched);
        - for anything else: the object itself, unchanged.
    """
    if isinstance(d, dict):
        converted = {}
        for key, value in d.items():
            converted[key] = cast_dict_to_tensors(value, device)
        return converted

    if isinstance(d, np.ndarray):
        # NumPy arrays are always cast to float32, whatever their dtype.
        as_tensor = torch.from_numpy(d)
        return as_tensor.float().to(device)

    if isinstance(d, torch.Tensor):
        return d.to(device)

    # Non-array leaves (strings, numbers, lists, ...) pass through as-is.
    return d
def rgba(c: str):
    """Convert the color specification ``c`` to RGBA.

    Delegates to ``matplotlib.colors.to_rgba``; matplotlib is imported
    inside the function, so it is only needed when this is called.
    """
    from matplotlib.colors import to_rgba

    return to_rgba(c)
def rgb(c: str):
    """Convert the color specification ``c`` to RGB (no alpha channel).

    Delegates to ``matplotlib.colors.to_rgb``; matplotlib is imported
    inside the function, so it is only needed when this is called.
    """
    from matplotlib.colors import to_rgb

    return to_rgb(c)
# Split the Lightning checkpoint into
# separate state_dict modules for faster loading.
def extract_ckpt(run_dir, ckpt_name="last"):
    """Split a checkpoint's ``state_dict`` into one weights file per module.

    Reads ``<run_dir>/logs/checkpoints/<ckpt_name>.ckpt``, groups the entries
    of its ``"state_dict"`` by the part of each key before the first ``"."``
    and saves every group to ``<run_dir>/<ckpt_name>_weights/<module>.pt``,
    with that leading module prefix stripped from the keys.

    Args:
        run_dir: Directory that contains ``logs/checkpoints``; the output
            folder is created inside it if it does not exist.
        ckpt_name: Checkpoint file name without the ``.ckpt`` extension.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
    extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
    os.makedirs(extracted_path, exist_ok=True)

    # Every tensor ends up on CPU anyway, so map storages to CPU while
    # loading; this lets checkpoints saved on a GPU load on CPU-only hosts.
    ckpt_dict = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt_dict["state_dict"]

    # Group in a single pass instead of rescanning the state dict per module.
    # e.g. {'motion_encoder': {...}, 'text_encoder': {...}, 'motion_decoder': {...}}
    per_module = {}
    for key, value in state_dict.items():
        module_name, _, sub_key = key.partition(".")
        per_module.setdefault(module_name, {})[sub_key] = value.cpu()

    for module_name, sub_state_dict in per_module.items():
        path = os.path.join(extracted_path, f"{module_name}.pt")
        torch.save(sub_state_dict, path)
import os | |
import json | |
from omegaconf import DictConfig, OmegaConf | |
def save_config(cfg: DictConfig) -> str:
    """Write ``cfg`` as JSON to ``<cfg.run_dir>/config.json``.

    The config is converted with ``OmegaConf.to_container(..., resolve=True)``
    and serialized with a 4-space indent.

    Args:
        cfg: Configuration to save; must expose a ``run_dir`` attribute.

    Returns:
        The path of the written file.
    """
    path = os.path.join(cfg.run_dir, "config.json")
    container = OmegaConf.to_container(cfg, resolve=True)
    with open(path, "w") as out_file:
        out_file.write(json.dumps(container, indent=4))
    return path
def write_json(data, p):
    """Serialize ``data`` as JSON (2-space indent) into the file at ``p``.

    Args:
        data: JSON-serializable object.
        p: Destination file path; the file is opened in ``'w'`` mode, so
            any existing content is replaced.
    """
    import json

    with open(p, 'w') as out_file:
        json.dump(data, out_file, indent=2)
def read_json(p):
    """Load and return the JSON document stored in the file at ``p``.

    Args:
        p: Path of the JSON file to read.

    Returns:
        The decoded JSON content, as produced by ``json.load``.
    """
    import json

    with open(p, 'r') as in_file:
        return json.load(in_file)
def read_config(run_dir: str, return_json=False) -> DictConfig:
    """Load ``<run_dir>/config.json``.

    Args:
        run_dir: Directory containing ``config.json``.
        return_json: When true, return the decoded JSON as-is instead of
            wrapping it in an OmegaConf config.

    Returns:
        The raw decoded JSON when ``return_json`` is true; otherwise the
        result of ``OmegaConf.create`` on it, with ``run_dir`` overwritten
        by the ``run_dir`` argument.
    """
    config_file = os.path.join(run_dir, "config.json")
    with open(config_file, "r") as handle:
        raw = json.load(handle)

    if not return_json:
        loaded = OmegaConf.create(raw)
        # Point the config at the directory it was actually read from.
        loaded.run_dir = run_dir
        return loaded

    return raw