Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import random | |
| import imageio | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| from omegaconf import DictConfig, ListConfig, OmegaConf | |
def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Enable or disable gradient tracking on every parameter of ``model``.

    Args:
        model: Module whose parameters (as yielded by ``model.parameters()``)
            are updated in place.
        flag: Value given to each parameter's ``requires_grad`` setting.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators with ``seed``.

    The seed is also written, as a string, to ``os.environ["PYTHONHASHSEED"]``.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Same order as before: NumPy, then torch (CPU entry point), then torch CUDA.
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
def str_to_dtype(x: str):
    """
    Translate a precision name into the matching torch dtype.

    Args:
        x: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16`` respectively.

    Raises:
        RuntimeError: If ``x`` is not one of the supported names.
    """
    supported = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    for name, dtype in supported:
        if x == name:
            return dtype
    raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
def batch_func(func, *args):
    """
    Apply ``func`` to every tensor argument whose leading dimension has size 2.

    Any other argument (non-tensors, tensors with a different leading size,
    and 0-dim tensors) is passed through unchanged.

    Args:
        func: Callable invoked with a single tensor.
        *args: Arbitrary positional values to process.

    Returns:
        A list with one entry per argument, in the same order.
    """
    # NOTE(review): the hard-coded size 2 presumably identifies a paired batch
    # (e.g. conditional/unconditional) -- confirm against callers.
    return [
        func(arg)
        # ``dim() > 0`` guards ``shape[0]``, which raises IndexError on 0-dim tensors.
        if isinstance(arg, torch.Tensor) and arg.dim() > 0 and arg.shape[0] == 2
        else arg
        for arg in args
    ]
def merge_args(args1, args2):
    """
    Copy every entry of ``args2`` onto the matching attribute of ``args1``.

    Keys are taken from ``args2._content`` and values are read with
    ``getattr``; ``ListConfig``/``DictConfig`` values are converted with
    ``OmegaConf.to_object`` before being assigned.
    NOTE(review): presumably ``args1`` is an argparse Namespace and ``args2``
    an OmegaConf DictConfig -- confirm against callers.

    Args:
        args1: Object whose attributes are overwritten in place.
        args2: Source of override values, or ``None`` for no overrides.

    Returns:
        ``args1`` (the same object that was passed in).

    Raises:
        RuntimeError: If a key of ``args2`` is not already present in
            ``args1.__dict__``.
    """
    if args2 is None:
        return args1
    for key in args2._content.keys():
        # Reject keys that args1 does not already define.
        if key not in args1.__dict__:
            raise RuntimeError(f"Unknown argument {key}")
        value = getattr(args2, key)
        if isinstance(value, (ListConfig, DictConfig)):
            value = OmegaConf.to_object(value)
        setattr(args1, key, value)
    return args1
def all_exists(paths):
    """
    Return True when ``os.path.exists`` is true for every entry of ``paths``.

    An empty iterable yields True.
    """
    return all(map(os.path.exists, paths))
def save_video(video, output_path, fps):
    """
    Write ``video`` to ``output_path`` with ``imageio.mimwrite``.

    When ``torch.distributed`` is initialized, only rank 0 writes; every other
    rank returns immediately. The parent directory of ``output_path`` is
    created if it does not exist.

    Args:
        video: Frames handed to ``imageio.mimwrite`` unchanged.
        output_path: Destination file path.
        fps: Frame rate forwarded to ``imageio.mimwrite``.
    """
    # ``is_available`` is checked first because builds without distributed
    # support do not expose the rest of the ``torch.distributed`` API.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    out_dir = os.path.dirname(output_path)
    # A bare filename has an empty dirname, and ``os.makedirs("")`` raises
    # FileNotFoundError, so only create the directory when there is one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimwrite(output_path, video, fps=fps)