from math import ceil
from pathlib import Path
import os
import re

from beartype.typing import Tuple
from einops import rearrange, repeat
from toolz import valmap
import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
import yaml

from .typing import typecheck

def count_parameters(model):
    # number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def path_mkdir(path):
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path

def path_exists(path):
    # coerce to Path and fail fast if the path does not exist
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(path)
    return path

def load_yaml(path, default_path=None):
    path = path_exists(path)
    with open(path, mode='r') as fp:
        cfg_s = yaml.load(fp, Loader=yaml.FullLoader)

    if default_path is not None:
        default_path = path_exists(default_path)
        with open(default_path, mode='r') as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)
    else:
        # fall back to a default.yml sitting next to the config, if any
        default_path = path.parent / 'default.yml'
        if default_path.exists():
            with open(default_path, mode='r') as fp:
                cfg = yaml.load(fp, Loader=yaml.FullLoader)
        else:
            cfg = {}

    # entries from the loaded config override the defaults
    update_recursive(cfg, cfg_s)
    return cfg

def dump_yaml(cfg, path):
    with open(path, mode='w') as f:
        return yaml.safe_dump(cfg, f)

def update_recursive(dict1, dict2):
    ''' Recursively update dict1 with the entries of dict2, in place.

    Args:
        dict1 (dict): dictionary to be updated
        dict2 (dict): dictionary whose entries should be used
    '''
    for k, v in dict2.items():
        if k not in dict1:
            dict1[k] = dict()
        if isinstance(v, dict):
            update_recursive(dict1[k], v)
        else:
            dict1[k] = v

def load_latest_checkpoint(checkpoint_dir):
    checkpoint_dir = Path(checkpoint_dir)
    pattern = re.compile(r".+\.ckpt\.(\d+)\.pt")

    # pick the checkpoint file with the highest epoch number
    max_epoch = -1
    latest_checkpoint = None
    for filename in os.listdir(checkpoint_dir):
        match = pattern.match(filename)
        if match:
            num_epoch = int(match.group(1))
            if num_epoch > max_epoch:
                max_epoch = num_epoch
                latest_checkpoint = checkpoint_dir / filename

    if not exists(latest_checkpoint):
        raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")

    checkpoint = torch.load(latest_checkpoint)
    return checkpoint, latest_checkpoint

def torch_to(inp, device, non_blocking=False):
    # recursively move tensors (possibly nested in lists, tuples or dicts) to device
    nb = non_blocking  # set to True when doing distributed jobs
    if isinstance(inp, torch.Tensor):
        return inp.to(device, non_blocking=nb)
    elif isinstance(inp, (list, tuple)):
        return type(inp)(map(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp))
    elif isinstance(inp, dict):
        return valmap(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp)
    else:
        raise NotImplementedError

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def first(it):
    return it[0]

def identity(t, *args, **kwargs):
    return t

def divisible_by(num, den):
    return (num % den) == 0

def is_odd(n):
    return not divisible_by(n, 2)

def is_empty(x):
    return len(x) == 0

def is_tensor_empty(t: Tensor):
    return t.numel() == 0

def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    for param in module.parameters():
        param.requires_grad = requires_grad

def l1norm(t):
    return F.normalize(t, dim = -1, p = 1)

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def safe_cat(tensors, dim):
    # concatenate, ignoring None entries; returns None if nothing remains
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None
    elif len(tensors) == 1:
        return first(tensors)

    return torch.cat(tensors, dim = dim)

def pad_at_dim(t, padding, dim = -1, value = 0):
    # translate (left, right) padding on `dim` into F.pad's last-dim-first layout
    ndim = t.ndim
    right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * right_dims
    return F.pad(t, (*zeros, *padding), value = value)
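
# Usage sketch (illustrative, not part of the original module): the dict merge
# semantics of update_recursive and the dim handling of pad_at_dim. All names
# and shapes below are made up for the example.
def _demo_helpers():
    # dict2 wins on conflicts, untouched keys of dict1 survive (in-place merge)
    cfg = {'optim': {'lr': 1e-3, 'wd': 0.1}, 'seed': 0}
    update_recursive(cfg, {'optim': {'lr': 3e-4}})
    assert cfg == {'optim': {'lr': 3e-4, 'wd': 0.1}, 'seed': 0}

    # pad dim 1 of a (2, 5) tensor by one element on the left
    t = torch.ones(2, 5)
    padded = pad_at_dim(t, (1, 0), dim = 1, value = 0)
    assert padded.shape == (2, 6) and (padded[:, 0] == 0).all()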
def pad_to_length(t, length, dim = -1, value = 0, right = True):
    curr_length = t.shape[dim]
    remainder = length - curr_length

    if remainder <= 0:
        return t

    padding = (0, remainder) if right else (remainder, 0)
    return pad_at_dim(t, padding, dim = dim, value = value)

def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    if not exists(mask):
        return tensor.mean(dim = dim)

    mask = rearrange(mask, '... -> ... 1')
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    num = tensor.sum(dim = dim)
    den = total_el.float().clamp(min = eps)
    mean = num / den
    mean = mean.masked_fill(total_el == 0, 0.)
    return mean

def cycle(dl):
    # loop over a dataloader forever
    while True:
        for data in dl:
            yield data

def maybe_del(d: dict, *keys):
    # delete keys from a dict, silently skipping missing ones
    for key in keys:
        if key not in d:
            continue

        del d[key]

# tensor helper functions

@typecheck
def discretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    lo, hi = continuous_range
    assert hi > lo

    t = (t - lo) / (hi - lo)
    t *= num_discrete
    t -= 0.5

    return t.round().long().clamp(min = 0, max = num_discrete - 1)

@typecheck
def undiscretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    lo, hi = continuous_range
    assert hi > lo

    t = t.float()

    t += 0.5
    t /= num_discrete
    return t * (hi - lo) + lo

@typecheck
def gaussian_blur_1d(
    t: Tensor,
    *,
    sigma: float = 1.,
    kernel_size: int = 5
) -> Tensor:
    _, _, channels, device, dtype = *t.shape, t.device, t.dtype

    width = int(ceil(sigma * kernel_size))
    width += (width + 1) % 2  # ensure odd kernel width
    half_width = width // 2

    distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)

    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
    gaussian = l1norm(gaussian)

    # depthwise convolution: one gaussian kernel per channel
    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)

    t = rearrange(t, 'b n c -> b c n')
    out = F.conv1d(t, kernel, padding = half_width, groups = channels)

    return rearrange(out, 'b c n -> b n c')

@typecheck
def scatter_mean(
    tgt: Tensor,
    indices: Tensor,
    src: Tensor,
    *,
    dim: int = -1,
    eps: float = 1e-5
):
    """
    todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
    """
    num = tgt.scatter_add(dim, indices, src)
    den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
    return num / den.clamp(min = eps)

def prob_mask_like(shape, prob, device):
    # boolean mask where each element is True with probability `prob`
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
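
# Usage sketch (illustrative, not part of the original module). discretize bins
# values from continuous_range into num_discrete integer codes, and undiscretize
# maps each code back to its bin center, so a round trip is off by at most half
# a bin width, i.e. (hi - lo) / num_discrete / 2. The range and bin count below
# are assumptions chosen for the example.
def _demo_discretize_roundtrip():
    t = torch.rand(4, 3) * 2 - 1  # values in [-1, 1)
    codes = discretize(t, continuous_range = (-1., 1.), num_discrete = 128)
    recon = undiscretize(codes, continuous_range = (-1., 1.), num_discrete = 128)
    half_bin = (1. - (-1.)) / 128 / 2
    assert (t - recon).abs().max() <= half_bin + 1e-6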