# General-purpose utilities: config-driven instantiation, tensor shape helpers,
# and path handling. (Header reconstructed; original lines were paste residue.)
from __future__ import annotations

import ast
import functools
import importlib
import os
from functools import partial
from inspect import isfunction

import fsspec
import torch
from einops import repeat
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores the requested mode.

    Bind this over a model's ``train`` method to freeze its train/eval state;
    the ``mode`` argument is accepted for signature compatibility but unused.
    """
    return self
def get_string_from_tuple(s):
    """If *s* is the textual form of a tuple (e.g. ``"('a', 'b')"``), return
    its first element; otherwise return *s* unchanged.

    Any parse failure (not a tuple literal, empty input, non-string input)
    falls back to returning *s* as-is.
    """
    try:
        # check if the string starts and ends with parentheses
        if s and s[0] == "(" and s[-1] == ")":
            # ast.literal_eval only evaluates literals — safe replacement for
            # eval() on text of unknown origin
            t = ast.literal_eval(s)
            # check it actually parsed to a non-empty tuple
            if isinstance(t, tuple) and t:
                return t[0]
    except (ValueError, SyntaxError, TypeError):
        # malformed literal or non-indexable input: fall through and return s
        pass
    return s
def is_power_of_two(n):
    """True iff *n* is a positive power of two (1, 2, 4, 8, ...)."""
    if n <= 0:
        return False
    # a power of two has exactly one set bit, so n & (n - 1) clears it to 0
    return (n & (n - 1)) == 0
def autocast(f, enabled=True):
    """Wrap callable *f* so every call runs under CUDA automatic mixed precision.

    The wrapper re-enters ``torch.cuda.amp.autocast`` per call, inheriting the
    current autocast GPU dtype and cache setting.

    Args:
        f: the callable to wrap.
        enabled: whether autocast is active inside the wrapper.

    Returns:
        A wrapper with the same signature and metadata as *f*.
    """

    @functools.wraps(f)  # fix: preserve f's __name__/__doc__ for introspection
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast
def load_partial_from_config(config):
    """Build a ``functools.partial`` of the configured target.

    Resolves ``config["target"]`` (a dotted import path) and pre-binds any
    ``config["params"]`` as keyword arguments.
    """
    target = get_obj_from_str(config["target"])
    params = config.get("params", dict())
    return partial(target, **params)
def repeat_as_img_seq(x, num_frames):
    """Repeat conditioning *x* ``num_frames`` times along the batch axis.

    Lists are repeated element-wise (``[a, b] -> [a]*t + [b]*t``); tensors of
    shape ``(b, ...)`` become ``(b * t, ...)``; ``None`` passes through.
    """
    if x is None:
        return None
    if isinstance(x, list):
        return [item for item in x for _ in range(num_frames)]
    expanded = x.unsqueeze(1)
    return repeat(expanded, "b 1 ... -> (b t) ...", t=num_frames)
def partialclass(cls, *args, **kwargs):
    """Return a subclass of *cls* whose ``__init__`` is pre-bound with
    *args*/*kwargs*, analogous to ``functools.partial`` for classes."""
    bound_init = functools.partialmethod(cls.__init__, *args, **kwargs)
    return type("NewCls", (cls,), {"__init__": bound_init})
def make_path_absolute(path):
    """Convert a local filesystem path/URL to an absolute local path.

    Non-local URLs (e.g. ``s3://...``) are returned unchanged.
    """
    fs, p = fsspec.core.url_to_fs(path)
    # fix: fs.protocol may be a single string or a tuple of aliases —
    # newer fsspec LocalFileSystem reports ("file", "local"), which the old
    # `fs.protocol == "file"` comparison missed.
    protocols = (fs.protocol,) if isinstance(fs.protocol, str) else fs.protocol
    if "file" in protocols:
        return os.path.abspath(p)
    return path
def ismap(x):
    """True for a 4-D tensor whose channel dimension (index 1) exceeds 3."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 4 and x.shape[1] > 3
def isimage(x):
    """True for a 4-D tensor with 1 or 3 channels at dimension 1."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 4 and x.shape[1] in (1, 3)
def isheatmap(x):
    """True for a rank-2 tensor (a bare H x W map)."""
    return isinstance(x, torch.Tensor) and x.ndim == 2
def isneighbors(x):
    """True for a 5-D tensor with 1 or 3 channels at dimension 2."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 5 and x.shape[2] in (1, 3)
def exists(x):
    """True when *x* is anything other than None (falsy values count)."""
    return x is not None
def expand_dims_like(x, y):
    """Append size-1 trailing dims to *x* until it has as many dims as *y*."""
    while x.dim() != y.dim():
        x = x[..., None]
    return x
def default(val, d):
    """Return *val* when it exists, otherwise the fallback *d*.

    A plain-function fallback is invoked lazily to produce the value.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.

    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
def count_params(model, verbose=False):
    """Return the total parameter count of *model*; optionally print it in millions."""
    total_params = 0
    for p in model.parameters():
        total_params += p.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params")
    return total_params
def instantiate_from_config(config):
    """Instantiate an object from ``{"target": dotted.path, "params": {...}}``.

    The sentinel strings ``"__is_first_stage__"`` and ``"__is_unconditional__"``
    map to None; any other config lacking ``target`` raises KeyError.
    """
    if "target" in config:
        cls = get_obj_from_str(config["target"])
        return cls(**config.get("params", dict()))
    if config in ("__is_first_stage__", "__is_unconditional__"):
        return None
    raise KeyError("Expected key `target` to instantiate")
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path like ``"pkg.mod.Attr"`` to the named attribute.

    Args:
        string: dotted import path; everything before the last dot is the module.
        reload: when True, force ``importlib.reload`` of the module first.
        invalidate_cache: when True, call ``importlib.invalidate_caches()`` so
            freshly created modules are discoverable.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_name))
    return getattr(importlib.import_module(module_name, package=None), attr_name)
def append_zero(x):
    """Append a single zero (same dtype/device as *x*) to the 1-D tensor *x*."""
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"Input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    expanded = x
    for _ in range(dims_to_append):
        expanded = expanded[..., None]
    return expanded
def get_configs_path() -> str:
    """Locate the `configs` directory relative to this module.

    Checks ``<this_dir>/configs`` then ``<this_dir>/../configs`` and returns
    the first existing directory as an absolute path.

    Raises:
        FileNotFoundError: if neither candidate directory exists.
    """
    this_dir = os.path.dirname(__file__)
    candidates = (
        os.path.join(this_dir, "configs"),
        os.path.join(this_dir, "..", "configs"),
    )
    for candidate in candidates:
        resolved = os.path.abspath(candidate)
        if os.path.isdir(resolved):
            return resolved
    raise FileNotFoundError(f"Could not find configs in {candidates}")