# NOTE(review): removed non-Python extraction artifacts ("Spaces:" / "Runtime error")
# that preceded the module and would have broken parsing.
import gguf
import torch
import os
import json
import safetensors.torch
import backend.misc.checkpoint_pickle
from backend.operations_gguf import ParameterGGUF
def read_arbitrary_config(directory):
    """Load and parse ``config.json`` found directly inside *directory*.

    Raises FileNotFoundError when the file does not exist; otherwise
    returns the parsed JSON content as a Python object.
    """
    config_path = os.path.join(directory, 'config.json')
    if os.path.exists(config_path):
        with open(config_path, 'rt', encoding='utf-8') as fp:
            return json.load(fp)
    raise FileNotFoundError(f"No config.json file found in the directory: {directory}")
def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint state dict from the file at *ckpt*.

    Dispatches on the file extension:
      - ``.safetensors``: loaded via safetensors.
      - ``.gguf``: each tensor is wrapped in ``ParameterGGUF``.
      - anything else: treated as a torch pickle checkpoint.

    Args:
        ckpt: Path to the checkpoint file.
        safe_load: For pickle checkpoints, use ``weights_only=True`` when the
            installed torch supports it; falls back to the restricted
            unpickler otherwise.
        device: Target device for loaded tensors; defaults to CPU.

    Returns:
        A dict mapping parameter names to tensors/parameters.
    """
    if device is None:
        device = torch.device("cpu")
    if ckpt.lower().endswith(".safetensors"):
        # NOTE(review): device.type drops any device index ("cuda:1" -> "cuda");
        # confirm whether callers rely on multi-GPU placement here.
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    elif ckpt.lower().endswith(".gguf"):
        reader = gguf.GGUFReader(ckpt)
        sd = {}
        for tensor in reader.tensors:
            sd[str(tensor.name)] = ParameterGGUF(tensor)
    else:
        # Fixed idiom: `'x' not in y` instead of `not 'x' in y`; the two
        # nested `if safe_load:` checks are flattened into one guard.
        if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
            print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            # Restricted pickle module guards against arbitrary code execution
            # from untrusted checkpoints.
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        # Lightning-style checkpoints nest the weights under "state_dict".
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
    return sd
def set_attr(obj, attr, value):
    """Assign *value* at the dotted path *attr* under *obj*, wrapped in a
    non-trainable ``torch.nn.Parameter``."""
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    setattr(target, leaf, torch.nn.Parameter(value, requires_grad=False))
def set_attr_raw(obj, attr, value):
    """Assign *value* directly (no Parameter wrapping) at the dotted path
    *attr* under *obj*."""
    path = attr.split(".")
    target = obj
    for segment in path[:-1]:
        target = getattr(target, segment)
    setattr(target, path[-1], value)
def copy_to_param(obj, attr, value):
    """Copy *value* in place into the tensor stored at dotted path *attr*,
    preserving the existing parameter object (only its data changes)."""
    path = attr.split(".")
    target = obj
    for segment in path[:-1]:
        target = getattr(target, segment)
    existing = getattr(target, path[-1])
    existing.data.copy_(value)
def get_attr(obj, attr):
    """Resolve the dotted attribute path *attr* starting at *obj* and return
    the final value."""
    node = obj
    for segment in attr.split("."):
        node = getattr(node, segment)
    return node
def get_attr_with_parent(obj, attr):
    """Resolve dotted path *attr* like ``get_attr`` but also return the
    immediate parent and the leaf attribute name.

    Returns (parent, name, value) such that ``getattr(parent, name)`` is
    ``value``.
    """
    parent, name = obj, None
    current = obj
    for name in attr.split("."):
        parent = current
        current = getattr(current, name)
    return parent, name, current
def calculate_parameters(sd, prefix=""):
    """Total element count of all tensors in state dict *sd* whose key
    starts with *prefix* (default: every entry)."""
    return sum(v.nelement() for k, v in sd.items() if k.startswith(prefix))
def tensor2parameter(x):
    """Return *x* as a frozen ``torch.nn.Parameter``; pass it through
    unchanged when it already is one."""
    if not isinstance(x, torch.nn.Parameter):
        x = torch.nn.Parameter(x, requires_grad=False)
    return x
def fp16_fix(x): | |
# An interesting trick to avoid fp16 overflow | |
# Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114 | |
# Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180 | |
if x.dtype in [torch.float16]: | |
return x.clip(-32768.0, 32768.0) | |
return x | |
def nested_compute_size(obj):
    """Recursively total the byte size of every tensor reachable inside
    *obj* (dicts, lists, and tuples are traversed; other leaves count 0)."""
    if isinstance(obj, torch.Tensor):
        return obj.nelement() * obj.element_size()
    if isinstance(obj, dict):
        return sum(nested_compute_size(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(item) for item in obj)
    return 0
def nested_move_to_device(obj, device):
    """Recursively move every tensor found in *obj* to *device*.

    Dicts and lists are updated in place (same container object returned);
    tuples are rebuilt since they are immutable. Non-tensor leaves are
    returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        for key in obj:
            obj[key] = nested_move_to_device(obj[key], device)
    elif isinstance(obj, list):
        obj[:] = [nested_move_to_device(item, device) for item in obj]
    elif isinstance(obj, tuple):
        obj = tuple(nested_move_to_device(item, device) for item in obj)
    return obj
def get_state_dict_after_quant(model, prefix=''):
    """Return a cloned state dict of *model* with *prefix* prepended to
    every key.

    Any bitsandbytes weight that is not yet quantized (exposes a falsy
    ``bnb_quantized`` flag) is first bounced through CUDA and back to its
    original device, which triggers its quantization before export.
    """
    for module in model.modules():
        weight = getattr(module, 'weight', None)
        if weight is not None and hasattr(weight, 'bnb_quantized') and not weight.bnb_quantized:
            home = weight.device
            module.cuda()
            module.to(home)
    return {prefix + key: tensor.clone() for key, tensor in model.state_dict().items()}