import os
import random
import imageio
import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig, ListConfig, OmegaConf


def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch RNGs for reproducibility.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def str_to_dtype(x: str):
    """
    Map a precision string ("fp32", "fp16", "bf16") to the corresponding torch dtype.
    """
    if x == "fp32":
        return torch.float32
    elif x == "fp16":
        return torch.float16
    elif x == "bf16":
        return torch.bfloat16
    else:
        raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")


def batch_func(func, *args):
    """
    Apply ``func`` to every tensor argument whose leading (batch) dimension is 2;
    all other arguments are passed through unchanged.
    """
    batch = []
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
            batch.append(func(arg))
        else:
            batch.append(arg)
    return batch


def merge_args(args1, args2):
    """
    Merge an OmegaConf config (args2) into an argparse Namespace (args1).

    Every key in args2 must already exist in args1; ListConfig/DictConfig
    values are converted to plain Python containers before assignment.
    """
    if args2 is None:
        return args1
    for k in args2._content.keys():
        if k in args1.__dict__:
            v = getattr(args2, k)
            if isinstance(v, (ListConfig, DictConfig)):
                v = OmegaConf.to_object(v)
            setattr(args1, k, v)
        else:
            raise RuntimeError(f"Unknown argument {k}")
    return args1


def all_exists(paths):
    return all(os.path.exists(path) for path in paths)


def save_video(video, output_path, fps):
    """
    Save a video to disk. In a distributed run, only rank 0 writes the file.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    imageio.mimwrite(output_path, video, fps=fps)
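

if __name__ == "__main__":
    # Minimal usage sketch. The hyperparameter names, values, and output path
    # below are illustrative assumptions, not part of the utilities themselves;
    # writing .mp4 output assumes the imageio-ffmpeg backend is installed.
    import argparse

    set_seed(42)
    print(str_to_dtype("bf16"))  # torch.bfloat16

    # batch_func only transforms tensors whose leading dimension is 2
    # (e.g. conditional/unconditional pairs); other arguments pass through.
    x = torch.randn(2, 4)
    y = torch.randn(3, 4)
    doubled_x, untouched_y = batch_func(lambda t: t * 2, x, y)

    # Merge an OmegaConf config into argparse defaults; every config key
    # must already exist on the Namespace.
    args = argparse.Namespace(lr=1e-4, num_frames=16)
    args = merge_args(args, OmegaConf.create({"lr": 2e-4}))
    print(args.lr)  # 0.0002, overridden by the config

    # Write a short random video; only rank 0 writes when torch.distributed is active.
    frames = (np.random.rand(16, 64, 64, 3) * 255).astype(np.uint8)
    save_video(frames, "outputs/demo.mp4", fps=8)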