| from functools import reduce |
| from inspect import isfunction |
| from math import ceil, floor, log2, pi |
| from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange |
| from torch import Generator, Tensor |
| from typing_extensions import TypeGuard |
|
|
# Generic type variable used by the helper signatures in this module
# (e.g. exists, iff, default, to_list).
T = TypeVar("T")
|
|
|
|
def exists(val: Optional[T]) -> TypeGuard[T]:
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    is_missing = val is None
    return not is_missing
|
|
|
|
def iff(condition: bool, value: T) -> Optional[T]:
    """Return ``value`` when ``condition`` is truthy, otherwise ``None``."""
    if not condition:
        return None
    return value
|
|
|
|
def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
    """Return True only for ``list`` or ``tuple`` instances (subclasses included).

    Other iterables such as strings, sets or ranges are not considered sequences here.
    """
    return isinstance(obj, (list, tuple))
|
|
|
|
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    If ``d`` is a plain Python function (as detected by ``inspect.isfunction``,
    which includes lambdas), it is called with no arguments to produce the
    fallback lazily. Any other ``d`` is returned unchanged. That includes
    classes, builtins and other callables that are not Python functions.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Coerce ``val`` into a list.

    A ``list`` is returned as-is (the same object, not a copy). A ``tuple``
    is converted into a new list. Any other value, including strings, is
    wrapped in a one-element list.
    """
    if isinstance(val, list):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val]
|
|
|
|
def prod(vals: Sequence[int]) -> int:
    """Return the product of all items in ``vals``.

    The product of an empty sequence is 1 (the multiplicative identity),
    matching ``math.prod``.
    """
    # Seeding the fold with 1 makes the empty case well-defined; without an
    # initializer, functools.reduce raises TypeError on an empty sequence.
    return reduce(lambda acc, item: acc * item, vals, 1)
|
|
|
|
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to ``x``.

    The candidates are ``2 ** floor(log2(x))`` and ``2 ** ceil(log2(x))``.
    On a tie the lower power wins.

    ``x`` must be positive, otherwise ``math.log2`` raises ``ValueError``.
    When the nearest power has a negative exponent (``x <= 0.75``), the
    result is a float rather than an int.
    """
    exp = log2(x)
    lower, upper = floor(exp), ceil(exp)
    # Only prefer the upper power when it is strictly closer; ties go low.
    if abs(x - 2 ** upper) < abs(x - 2 ** lower):
        best = upper
    else:
        best = lower
    return 2 ** int(best)
|
|
def rand_bool(shape, proba, device=None):
    """Sample a boolean tensor of ``shape`` whose entries are True with probability ``proba``.

    The values ``proba == 1`` and ``proba == 0`` are short-circuited to
    all-True and all-False tensors, so no random numbers are drawn. Any
    other value draws independent Bernoulli samples via ``torch.bernoulli``.
    The tensor is created on ``device`` (``None`` uses torch's default device).
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    probabilities = torch.full(shape, proba, device=device)
    samples = torch.bernoulli(probabilities)
    return samples.to(torch.bool)
|
|
|
|
| """ |
| Kwargs Utils |
| """ |
|
|
|
|
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split ``d`` into two dicts by whether each key starts with ``prefix``.

    Returns ``(with_prefix, without_prefix)``. Keys are copied unchanged and
    each output keeps the iteration order of ``d``. Keys must support
    ``str.startswith`` (i.e. be strings).
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key in d.keys():
        bucket = with_prefix if key.startswith(prefix) else without_prefix
        bucket[key] = d[key]
    return with_prefix, without_prefix
|
|
|
|
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Partition ``d`` into entries whose keys start with ``prefix`` and the rest.

    Returns ``(matched, rest)``. Unless ``keep_prefix`` is True, the leading
    ``prefix`` is sliced off each key in ``matched``. Keys in ``rest`` are
    left untouched.
    """
    matched, rest = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        cut = len(prefix)
        matched = {key[cut:]: value for key, value in matched.items()}
    return matched, rest
|
|
|
|
def prefix_dict(prefix: str, d: Dict) -> Dict:
    """Return a new dict mapping ``prefix + str(key)`` to each value of ``d``.

    Insertion order is preserved. If two keys stringify to the same text,
    the later entry wins.
    """
    renamed: Dict = {}
    for key, value in d.items():
        renamed[prefix + str(key)] = value
    return renamed
|
|