Spaces:
Runtime error
Runtime error
| import gguf | |
| import torch | |
| import os | |
| import json | |
| import safetensors.torch | |
| import backend.misc.checkpoint_pickle | |
| from backend.operations_gguf import ParameterGGUF | |
def read_arbitrary_config(directory):
    """Load and parse ``config.json`` found directly inside *directory*.

    Returns the decoded JSON object.
    Raises FileNotFoundError when no config.json exists in the directory.
    """
    path = os.path.join(directory, 'config.json')
    if not os.path.exists(path):
        raise FileNotFoundError(f"No config.json file found in the directory: {directory}")
    with open(path, 'rt', encoding='utf-8') as fp:
        return json.load(fp)
def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file into a state dict.

    Dispatches on file extension:
      - ``.safetensors`` -> safetensors loader
      - ``.gguf``        -> gguf reader, tensors wrapped as ParameterGGUF
      - anything else    -> ``torch.load`` (pickle)

    Args:
        ckpt: path to the checkpoint file.
        safe_load: for pickle checkpoints, use ``weights_only=True`` when the
            installed torch supports it; otherwise falls back to unsafe loading
            through the project's restricted pickle module.
        device: target device; defaults to CPU.

    Returns:
        A state dict mapping tensor names to tensors/parameters. For pickle
        checkpoints that wrap their weights in a ``state_dict`` key, the inner
        dict is returned.
    """
    if device is None:
        device = torch.device("cpu")
    name = ckpt.lower()
    if name.endswith(".safetensors"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    elif name.endswith(".gguf"):
        reader = gguf.GGUFReader(ckpt)
        sd = {str(tensor.name): ParameterGGUF(tensor) for tensor in reader.tensors}
    else:
        # Old torch versions lack the weights_only kwarg; degrade to unsafe load
        # with a warning rather than crashing.
        if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
            print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            # NOTE: pickle loading of untrusted checkpoints is inherently risky;
            # the restricted pickle module mitigates but does not eliminate this.
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
    return sd
def set_attr(obj, attr, value):
    """Set the dotted attribute path *attr* on *obj* to a frozen nn.Parameter wrapping *value*."""
    *parents, leaf = attr.split(".")
    target = obj
    for segment in parents:
        target = getattr(target, segment)
    setattr(target, leaf, torch.nn.Parameter(value, requires_grad=False))
def set_attr_raw(obj, attr, value):
    """Assign *value* verbatim (no Parameter wrapping) at the dotted path *attr* under *obj*."""
    *parents, leaf = attr.split(".")
    target = obj
    for segment in parents:
        target = getattr(target, segment)
    setattr(target, leaf, value)
def copy_to_param(obj, attr, value):
    """Copy *value* in place into the ``.data`` of the existing tensor at dotted path *attr*.

    Unlike set_attr, the parameter object itself is kept; only its storage is updated.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for segment in parents:
        target = getattr(target, segment)
    getattr(target, leaf).data.copy_(value)
def get_attr(obj, attr):
    """Resolve and return the value at the dotted attribute path *attr*, starting from *obj*."""
    node = obj
    for part in attr.split("."):
        node = getattr(node, part)
    return node
def get_attr_with_parent(obj, attr):
    """Resolve dotted path *attr* like get_attr, additionally returning the immediate parent.

    Returns ``(parent, final_name, value)`` where ``getattr(parent, final_name) is value``.
    """
    parent, name, node = obj, None, obj
    for name in attr.split("."):
        parent = node
        node = getattr(node, name)
    return parent, name, node
def calculate_parameters(sd, prefix=""):
    """Total element count across all tensors in state dict *sd* whose key starts with *prefix*."""
    return sum(tensor.nelement() for key, tensor in sd.items() if key.startswith(prefix))
def tensor2parameter(x):
    """Return *x* unchanged if it is already an nn.Parameter; otherwise wrap it frozen."""
    return x if isinstance(x, torch.nn.Parameter) else torch.nn.Parameter(x, requires_grad=False)
| def fp16_fix(x): | |
| # An interesting trick to avoid fp16 overflow | |
| # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114 | |
| # Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180 | |
| if x.dtype in [torch.float16]: | |
| return x.clip(-32768.0, 32768.0) | |
| return x | |
def nested_compute_size(obj):
    """Recursively sum the byte footprint of every tensor nested inside *obj*.

    Handles dicts, lists, tuples, and tensors; anything else contributes 0 bytes.
    """
    if isinstance(obj, torch.Tensor):
        return obj.nelement() * obj.element_size()
    if isinstance(obj, dict):
        return sum(nested_compute_size(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(item) for item in obj)
    return 0
def nested_move_to_device(obj, device):
    """Recursively move every tensor inside *obj* to *device*.

    Dicts and lists are mutated in place (and returned); tuples are rebuilt,
    since they are immutable; non-container, non-tensor values pass through.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        for key in obj:
            obj[key] = nested_move_to_device(obj[key], device)
    elif isinstance(obj, list):
        for idx, item in enumerate(obj):
            obj[idx] = nested_move_to_device(item, device)
    elif isinstance(obj, tuple):
        return tuple(nested_move_to_device(item, device) for item in obj)
    return obj
def get_state_dict_after_quant(model, prefix=''):
    """Return a cloned state dict of *model* with every key prefixed by *prefix*.

    Before snapshotting, any module whose weight carries a falsy ``bnb_quantized``
    flag (bitsandbytes) is bounced through CUDA and back to its original device,
    which triggers the pending quantization.
    NOTE(review): that branch needs a CUDA device available — confirm with callers.
    """
    for module in model.modules():
        weight = getattr(module, 'weight', None)
        if hasattr(weight, 'bnb_quantized') and not weight.bnb_quantized:
            home = weight.device
            module.cuda()
            module.to(home)
    return {prefix + key: tensor.clone() for key, tensor in model.state_dict().items()}