import importlib
import itertools
import os.path as osp
import re
from functools import wraps
from inspect import isfunction

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from safetensors import safe_open
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm


def exists(x):
    return x is not None


def append_dims(x, target_dims) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]
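
# Usage sketch (illustrative, not part of the original module): broadcast a
# per-sample scalar over an image batch.
#   sigma = torch.rand(4)                # shape (4,)
#   x = torch.randn(4, 3, 64, 64)
#   sigma = append_dims(sigma, x.ndim)   # shape (4, 1, 1, 1); broadcasts against x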


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def expand_to_batch_size(x, bs):
    if isinstance(x, list):
        x = [xi.repeat(bs, *([1] * (len(xi.shape) - 1))) for xi in x]
    else:
        x = x.repeat(bs, *([1] * (len(x.shape) - 1)))
    return x
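
# Usage sketch (illustrative): tile a single conditioning tensor, or a list of
# them, along the batch axis.
#   c = torch.randn(1, 77, 768)
#   c = expand_to_batch_size(c, 4)       # shape (4, 77, 768)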


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
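
# Usage sketch (a minimal example using torch.nn.Linear as the target; real
# configs would point at project classes instead):
#   config = {
#       "target": "torch.nn.Linear",
#       "params": {"in_features": 16, "out_features": 8},
#   }
#   layer = instantiate_from_config(config)   # equivalent to torch.nn.Linear(16, 8)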


def scaled_resize(x: torch.Tensor, scale_factor, interpolation_mode="bicubic"):
    return F.interpolate(x, scale_factor=scale_factor, mode=interpolation_mode)


def get_crop_scale(h, w, bgh, bgw):
    """Fractional (height, width) crop scales that bring a background of size
    (bgh, bgw) to the same aspect ratio as a generated image of size (h, w)."""
    gen_aspect = w / h
    bg_aspect = bgw / bgh
    if gen_aspect > bg_aspect:
        # generation is wider: keep the background's full width, crop its height
        cw = 1.0
        ch = (h / w) * (bgw / bgh)
    else:
        # generation is taller (or equal): keep the full height, crop the width
        ch = 1.0
        cw = (w / h) * (bgh / bgw)
    return ch, cw


def warp_resize(x: torch.Tensor, target_size, interpolation_mode="bicubic"):
    assert len(x.shape) == 4
    return F.interpolate(x, size=target_size, mode=interpolation_mode)


def resize_and_crop(x: torch.Tensor, ch, cw, th, tw):
    # Crop a (ch, cw)-scaled region anchored at the top-left corner, then
    # resize it to (th, tw).
    _, _, h, w = x.shape
    return tf.resized_crop(x, 0, 0, int(ch * h), int(cw * w), size=[th, tw])
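
# Usage sketch (illustrative): crop a 1280x720 background to the aspect ratio
# of a square 512x512 generation, then resize it to the generation's resolution.
#   bg = torch.randn(1, 3, 720, 1280)
#   ch, cw = get_crop_scale(512, 512, bgh=720, bgw=1280)   # ch = 1.0, cw = 0.5625
#   bg = resize_and_crop(bg, ch, cw, th=512, tw=512)       # shape (1, 3, 512, 512)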


def fitting_weights(model, sd):
    """Adapt a state dict `sd` to `model` when shapes disagree, by cyclically
    tiling old weights along the first two axes and renormalizing reused
    input columns."""
    n_params = len([name for name, _ in
                    itertools.chain(model.named_parameters(),
                                    model.named_buffers())])
    for name, param in tqdm(
            itertools.chain(model.named_parameters(),
                            model.named_buffers()),
            desc="Fitting old weights to new weights",
            total=n_params
    ):
        if name not in sd:
            continue
        old_shape = sd[name].shape
        new_shape = param.shape
        assert len(old_shape) == len(new_shape)
        if len(new_shape) > 2:
            # we only modify first two axes
            assert new_shape[2:] == old_shape[2:]
        # assumes first axis corresponds to output dim
        if new_shape != old_shape:
            old_param = sd[name]
            device = old_param.device
            if len(new_shape) == 1:
                # vectorized 1D case: tile the old vector cyclically
                new_param = old_param[torch.arange(new_shape[0], device=device) % old_shape[0]]
            else:
                # vectorized 2D+ case: tile cyclically along the first two axes
                i_indices = torch.arange(new_shape[0], device=device)[:, None] % old_shape[0]
                j_indices = torch.arange(new_shape[1], device=device)[None, :] % old_shape[1]
                # advanced indexing extracts all values at once
                new_param = old_param[i_indices, j_indices]
                # count how many times each old column is used
                n_used_old = torch.bincount(
                    torch.arange(new_shape[1], device=device) % old_shape[1],
                    minlength=old_shape[1]
                )
                # map the counts onto the new columns
                n_used_new = n_used_old[torch.arange(new_shape[1], device=device) % old_shape[1]]
                # reshape for broadcasting over any remaining axes
                n_used_new = n_used_new.reshape(1, new_shape[1])
                while len(n_used_new.shape) < len(new_shape):
                    n_used_new = n_used_new.unsqueeze(-1)
                # divide reused columns by their reuse count so duplicated
                # inputs do not inflate the output scale
                new_param = new_param / n_used_new
            sd[name] = new_param
    return sd
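
# Usage sketch (illustrative; the path is hypothetical and `model` is assumed
# to be an instantiated torch.nn.Module): e.g. adapting a checkpoint whose
# first conv expects 4 input channels to a model whose first conv takes 8.
#   sd = load_weights("old_model.ckpt")
#   sd = fitting_weights(model, sd)
#   model.load_state_dict(sd, strict=False)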


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


VALID_FORMATS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]


def load_weights(path, weights_only=True):
    ext = osp.splitext(path)[-1]
    assert ext in VALID_FORMATS, f"Invalid checkpoint format {ext}"
    if ext == ".safetensors":
        # use the context manager so the file handle is closed after reading
        with safe_open(path, framework="pt", device="cpu") as safe_sd:
            sd = {key: safe_sd.get_tensor(key) for key in safe_sd.keys()}
    else:
        sd = torch.load(path, map_location="cpu", weights_only=weights_only)
    if "state_dict" in sd.keys():
        sd = sd["state_dict"]
    return sd
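
# Usage sketch (illustrative; the path is hypothetical):
#   sd = load_weights("checkpoints/model.safetensors")
#   missing, unexpected = model.load_state_dict(sd, strict=False)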


def delete_states(sd, delete_keys: list[str] = (), skip_keys: list[str] = ()):
    """Delete entries whose keys match any regex in `delete_keys`, unless the
    key also matches a regex in `skip_keys`."""
    for k in list(sd.keys()):
        if any(re.match(sk, k) is not None for sk in skip_keys):
            continue
        if any(re.match(ik, k) is not None for ik in delete_keys):
            del sd[k]
    return sd
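
# Usage sketch (illustrative; the key prefixes below are hypothetical):
# drop all first-stage weights except the decoder.
#   sd = delete_states(sd,
#                      delete_keys=[r"first_stage_model\."],
#                      skip_keys=[r"first_stage_model\.decoder\."])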


def autocast(f, enabled=True):
    @wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast


def checkpoint_wrapper(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # checkpoint by default; a module opts out with `self.checkpoint = False`
        if not hasattr(self, "checkpoint") or self.checkpoint:
            def bound_func(*args, **kwargs):
                return func(self, *args, **kwargs)

            return checkpoint(bound_func, *args, use_reentrant=False, **kwargs)
        else:
            return func(self, *args, **kwargs)

    return wrapper
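
# Usage sketch (illustrative): both decorators wrap instance methods; the
# class below is hypothetical. Activation checkpointing is toggled
# per-instance via the `checkpoint` attribute.
#   class Block(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.checkpoint = True      # set False to disable checkpointing
#           self.proj = torch.nn.Linear(64, 64)
#
#       @checkpoint_wrapper
#       def forward(self, x):
#           return self.proj(x)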