import glob
import inspect
import os
import sys
import traceback
import torch
from torch.nn.init import normal_, xavier_uniform_, zeros_, xavier_normal_, kaiming_uniform_, kaiming_normal_
try:
from modules.hashes import sha256
except (ImportError, ModuleNotFoundError):
print("modules.hashes is not found, will use backup module from extension!")
from .hashes_backup import sha256
import modules.hypernetworks.hypernetwork
from modules import devices, shared, sd_models
from .hnutil import parse_dropout_structure, find_self
from .shared import version_flag
def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"):
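    """Initialize a Linear or LayerNorm layer in place with the chosen scheme.

    LayerNorm layers are always normal-initialized; Linear layers use the scheme named
    by weight_init (Normal, XavierUniform, XavierNormal, KaimingUniform, KaimingNormal),
    with the Kaiming variants choosing their nonlinearity from activation_func.
    """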
w, b = layer.weight.data, layer.bias.data
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=normal_std)
normal_(b, mean=0.0, std=0)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
elif weight_init == 'XavierNormal':
xavier_normal_(w)
zeros_(b)
elif weight_init == 'KaimingUniform':
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
elif weight_init == 'KaimingNormal':
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
else:
raise KeyError(f"Key {weight_init} is not defined as initialization!")
class ResBlock(torch.nn.Module):
"""Residual Block"""
def __init__(self, n_inputs, n_outputs, activation_func, weight_init, add_layer_norm, dropout_p, normal_std, device=None, state_dict=None, **kwargs):
super().__init__()
self.n_outputs = n_outputs
self.upsample_layer = None
self.upsample = kwargs.get("upsample_model", None)
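        # When n_inputs != n_outputs the skip path must be resized: "Linear" uses a learned
        # bias-free projection, otherwise forward() falls back to nearest-exact interpolation.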
if self.upsample == "Linear":
self.upsample_layer = torch.nn.Linear(n_inputs, n_outputs, bias=False)
linears = [torch.nn.Linear(n_inputs, n_outputs)]
init_weight(linears[0], weight_init, normal_std, activation_func)
if add_layer_norm:
linears.append(torch.nn.LayerNorm(n_outputs))
init_weight(linears[1], weight_init, normal_std, activation_func)
if dropout_p > 0:
linears.append(torch.nn.Dropout(p=dropout_p))
if activation_func == "linear" or activation_func is None:
pass
elif activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
self.load_state_dict(state_dict)
if device is not None:
self.to(device)
def trainables(self, train=False):
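        """Switch sub-layers into train/eval mode and return the Linear/LayerNorm weights and biases to optimize."""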
layer_structure = []
for layer in self.linear:
if train:
layer.train()
else:
layer.eval()
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer_structure += [layer.weight, layer.bias]
return layer_structure
def forward(self, x, **kwargs):
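        """Residual forward pass: resize x to n_outputs (projection or interpolation) and add the MLP branch's output."""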
if self.upsample_layer is None:
interpolated = torch.nn.functional.interpolate(x, size=self.n_outputs, mode="nearest-exact")
else:
interpolated = self.upsample_layer(x)
return interpolated + self.linear(x)
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
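    # Also expose every activation class from torch.nn.modules.activation under its lowercase
    # class name, so activations such as "gelu" or "mish" resolve without being listed above.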
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, activate_output=False, dropout_structure=None, device=None, generation_seed=None, normal_std=0.01, **kwargs):
super().__init__()
self.skip_connection = skip_connection = kwargs.get('skip_connection', False)
upsample_linear = kwargs.get('upsample_linear', None)
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
        # Instead of raising an error, just warn: the first dropout value is never used anyway.
if not (skip_connection or dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0):
print("Dropout sequence does not starts or ends with zero.")
# assert skip_connection or dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0, "Dropout Sequence should start and end with probability 0!"
assert dropout_structure is None or len(dropout_structure) == len(layer_structure), "Dropout Sequence should match length with layer structure!"
linears = []
if skip_connection:
if generation_seed is not None:
torch.manual_seed(generation_seed)
for i in range(len(layer_structure) - 1):
if skip_connection:
n_inputs, n_outputs = int(dim * layer_structure[i]), int(dim * layer_structure[i+1])
dropout_p = dropout_structure[i+1]
if activation_func is None:
activation_func = "linear"
linears.append(ResBlock(n_inputs, n_outputs, activation_func, weight_init, add_layer_norm, dropout_p, normal_std, device, upsample_model=upsample_linear))
continue
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
            # Add an activation function after every layer except the last (unless activate_output is set)
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
            # Every dropout probability has been parsed into dropout_structure by now and is applied here.
            # Since dropout only follows a layer, the dropout structure should start and end with 0.
if dropout_structure is not None and dropout_structure[i+1] > 0:
assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
            # Example: [1, 2, 1] gets no dropout when last_layer_dropout is False; [1, 2, 2, 1] maps to [0, 0.3, 0, 0], or [0, 0.3, 0.3, 0] when last_layer_dropout is True.
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
self.fix_old_state_dict(state_dict)
self.load_state_dict(state_dict)
elif not skip_connection:
if generation_seed is not None:
torch.manual_seed(generation_seed)
for layer in self.linear:
                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                    # Reuse the shared init_weight helper instead of repeating every scheme inline.
                    init_weight(layer, weight_init, normal_std, activation_func)
if device is None:
self.to(devices.device)
else:
self.to(device)
def fix_old_state_dict(self, state_dict):
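        """Rename keys from the legacy linear1/linear2 layout to the Sequential-based linear.N layout."""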
changes = {
'linear1.bias': 'linear.0.bias',
'linear1.weight': 'linear.0.weight',
'linear2.bias': 'linear.1.bias',
'linear2.weight': 'linear.1.weight',
}
for fr, to in changes.items():
x = state_dict.get(fr, None)
if x is None:
continue
del state_dict[fr]
state_dict[to] = x
def forward(self, x, multiplier=None):
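        """Residual forward pass: x plus the MLP output scaled by the hypernetwork strength
        (the per-call multiplier when numeric, otherwise the instance or class-level
        multiplier; scaling is normally skipped while training)."""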
if self.skip_connection:
if self.training:
return self.linear(x)
else:
resnet_result = self.linear(x)
residual = resnet_result - x
if multiplier is None or not isinstance(multiplier, (int, float)):
multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier
return x + multiplier * residual # interpolate
if multiplier is None or not isinstance(multiplier, (int, float)):
return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1)
return x + self.linear(x) * multiplier
def trainables(self, train=False):
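        """Return the weight/bias tensors to optimize, switching every sub-layer into the
        requested train/eval mode (ResBlocks delegate to their own trainables())."""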
layer_structure = []
self.train(train)
for layer in self.linear:
if train:
layer.train()
else:
layer.eval()
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer_structure += [layer.weight, layer.bias]
elif type(layer) == ResBlock:
layer_structure += layer.trainables(train)
return layer_structure
    def set_train(self, mode=True):
self.train(mode)
for layer in self.linear:
if mode:
layer.train(mode)
else:
layer.eval()
class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
self.filename = None
self.name = name
self.layers = {}
self.step = 0
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.activation_func = activation_func
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
self.activate_output = activate_output
        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
self.optimizer_name = None
self.optimizer_state_dict = None
        self.dropout_structure = kwargs.get('dropout_structure') if use_dropout else None
self.optional_info = kwargs.get('optional_info', None)
self.skip_connection = kwargs.get('skip_connection', False)
self.upsample_linear = kwargs.get('upsample_linear', None)
self.training = False
generation_seed = kwargs.get('generation_seed', None)
normal_std = kwargs.get('normal_std', 0.01)
if self.dropout_structure is None:
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std, skip_connection=self.skip_connection,
upsample_linear=self.upsample_linear),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std, skip_connection=self.skip_connection,
upsample_linear=self.upsample_linear),
)
self.eval()
def weights(self, train=False):
self.training = train
res = []
for k, layers in self.layers.items():
for layer in layers:
res += layer.trainables(train)
return res
def eval(self):
self.training = False
for k, layers in self.layers.items():
for layer in layers:
layer.eval()
layer.set_train(False)
def train(self, mode=True):
self.training = mode
for k, layers in self.layers.items():
for layer in layers:
layer.set_train(mode)
def detach_grad(self):
for k, layers in self.layers.items():
for layer in layers:
layer.requires_grad_(False)
def shorthash(self):
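        """Return the first 10 hex digits of the file's sha256, used to pair saved optimizer state with this hypernetwork."""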
sha256v = sha256(self.filename, f'hypernet/{self.name}')
return sha256v[0:10]
def extra_name(self):
if version_flag:
return ""
found = find_self(self)
if found is not None:
return f" <hypernet:{found}:1.0>"
return f" <hypernet:{self.name}:1.0>"
def save(self, filename):
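        """Serialize layer weights and training metadata to `filename`; optimizer state, when
        enabled and available, is written alongside as `filename + '.optim'`."""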
state_dict = {}
optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['weight_initialization'] = self.weight_init
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
state_dict['use_dropout'] = self.use_dropout
state_dict['dropout_structure'] = self.dropout_structure
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
state_dict['optional_info'] = self.optional_info if self.optional_info else None
state_dict['skip_connection'] = self.skip_connection
state_dict['upsample_linear'] = self.upsample_linear
if self.optimizer_name is not None:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
            optimizer_saved_dict['hash'] = self.shorthash()  # load() compares this hash to decide whether the optimizer state belongs to this hypernetwork
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename):
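        """Restore weights and metadata from a .pt file, plus the saved optimizer state when
        its recorded hash matches this hypernetwork."""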
self.filename = filename
if self.name is None:
self.name = os.path.splitext(os.path.basename(filename))[0]
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
print(self.layer_structure)
optional_info = state_dict.get('optional_info', None)
if optional_info is not None:
self.optional_info = optional_info
self.activation_func = state_dict.get('activation_func', None)
self.weight_init = state_dict.get('weight_initialization', 'Normal')
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.dropout_structure = state_dict.get('dropout_structure', None)
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
self.activate_output = state_dict.get('activate_output', True)
self.last_layer_dropout = state_dict.get('last_layer_dropout', False) # Silent fix for HNs before 4918eb6
self.skip_connection = state_dict.get('skip_connection', False)
self.upsample_linear = state_dict.get('upsample_linear', False)
        # The dropout structure should have the same length as the layer structure; every value should be in [0, 1), and the last value must be 0.
if self.dropout_structure is None:
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
if hasattr(shared.opts, 'print_hypernet_extra') and shared.opts.print_hypernet_extra:
if optional_info is not None:
print(f"INFO:\n {optional_info}\n")
print(f"Activation function is {self.activation_func}")
print(f"Weight initialization is {self.weight_init}")
print(f"Layer norm is set to {self.add_layer_norm}")
print(f"Dropout usage is set to {self.use_dropout}")
print(f"Activate last layer is set to {self.activate_output}")
print(f"Dropout structure is set to {self.dropout_structure}")
        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
if optimizer_saved_dict.get('hash', None) == self.shorthash() or optimizer_saved_dict.get('hash', None) == sd_models.model_hash(filename):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print("Loaded existing optimizer from checkpoint")
print(f"Optimizer name is {self.optimizer_name}")
else:
print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.activate_output, self.dropout_structure, skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.activate_output, self.dropout_structure, skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
)
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
self.eval()
def to(self, device):
for k, layers in self.layers.items():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
for k, layers in self.layers.items():
for layer in layers:
layer.multiplier = multiplier
return self
def __call__(self, context, *args, **kwargs):
return self.forward(context, *args, **kwargs)
def forward(self, context, context_v=None, layer=None):
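        """Transform the cross-attention key/value contexts with the module pair matching the
        context's embedding width; contexts pass through unchanged when no pair exists."""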
context_layers = self.layers.get(context.shape[2], None)
if context_v is None:
context_v = context
if context_layers is None:
return context, context_v
if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'):
layer.hyper_k = context_layers[0]
layer.hyper_v = context_layers[1]
transform_k, transform_v = context_layers[0](context), context_layers[1](context_v)
return transform_k, transform_v
def list_hypernetworks(path):
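    """Collect hypernetwork files (*.pt and *.hns) under `path`, keyed by basename; duplicate
    .pt basenames get a numeric suffix, and "None" is never listed."""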
res = {}
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
        # Disambiguate duplicate basenames as "name(1)", "name(2)", ...
        base_name = name
        idx = 0
        while name in res:
            idx += 1
            name = f"{base_name}({idx})"
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
for filename in glob.iglob(os.path.join(path, '**/*.hns'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
if name != "None":
res[name] = filename
return res
def find_closest_first(keyset, target):
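    """Return the first key whose name, ignoring any "(n)" disambiguation suffix, equals target."""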
for keys in keyset:
if target == keys.rsplit('(', 1)[0]:
return keys
return None
def load_hypernetwork(filename):
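    """Load a hypernetwork (.pt) or hypernetwork structure (.hns) by its listed name, storing it
    in shared.loaded_hypernetwork when that attribute exists and returning it otherwise; an
    unknown or "None" name unloads the current hypernetwork instead."""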
hypernetwork = None
path = shared.hypernetworks.get(filename, None)
if path is None:
filename = find_closest_first(shared.hypernetworks.keys(), filename)
path = shared.hypernetworks.get(filename, None)
print(path)
# Prevent any file named "None.pt" from being loaded.
if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
if path.endswith(".pt"):
try:
hypernetwork = Hypernetwork()
hypernetwork.load(path)
if hasattr(shared, 'loaded_hypernetwork'):
shared.loaded_hypernetwork = hypernetwork
else:
return hypernetwork
except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
elif path.endswith(".hns"):
# Load Hypernetwork processing
try:
from .hypernetworks import load as load_hns
if hasattr(shared, 'loaded_hypernetwork'):
shared.loaded_hypernetwork = load_hns(path)
else:
hypernetwork = load_hns(path)
print(f"Loaded Hypernetwork Structure {path}")
return hypernetwork
except Exception:
print(f"Error loading hypernetwork processing file {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
else:
print(f"Tried to load unknown file extension: {filename}")
else:
if hasattr(shared, 'loaded_hypernetwork'):
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None
return hypernetwork
def apply_hypernetwork(hypernetwork, context, layer=None):
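    """Produce the (context_k, context_v) pair for cross-attention by running the context through
    the size-matched module pair, or return the context unchanged when no hypernetwork applies."""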
if hypernetwork is None:
return context, context
if isinstance(hypernetwork, Hypernetwork):
        hypernetwork_layers = hypernetwork.layers.get(context.shape[2], None)
if hypernetwork_layers is None:
return context, context
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context)))
context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context)))
return context_k, context_v
context_k, context_v = hypernetwork(context, layer=layer)
return context_k, context_v
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
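    """Like apply_hypernetwork, but transforms separately supplied key and value contexts."""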
if hypernetwork is None:
return context_k, context_v
if isinstance(hypernetwork, Hypernetwork):
        hypernetwork_layers = hypernetwork.layers.get(context_k.shape[2], None)
if hypernetwork_layers is None:
return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
return context_k, context_v
context_k, context_v = hypernetwork(context_k, context_v, layer=layer)
return context_k, context_v
def apply_strength(value=None):
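    """Set the global hypernetwork strength, defaulting to the sd_hypernetwork_strength option."""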
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
def apply_hypernetwork_strength(p, x, xs):
apply_strength(x)
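# Monkey-patch the webui's hypernetwork module (and the bundled X/Y grid script, when present)
# so they use the extended implementations defined above.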
modules.hypernetworks.hypernetwork.list_hypernetworks = list_hypernetworks
modules.hypernetworks.hypernetwork.load_hypernetwork = load_hypernetwork
if hasattr(modules.hypernetworks.hypernetwork, 'apply_hypernetwork'):
modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
else:
modules.hypernetworks.hypernetwork.apply_single_hypernetwork = apply_single_hypernetwork
if hasattr(modules.hypernetworks.hypernetwork, 'apply_strength'):
modules.hypernetworks.hypernetwork.apply_strength = apply_strength
modules.hypernetworks.hypernetwork.Hypernetwork = Hypernetwork
modules.hypernetworks.hypernetwork.HypernetworkModule = HypernetworkModule
try:
import scripts.xy_grid
if hasattr(scripts.xy_grid, 'apply_hypernetwork_strength'):
scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength
except (ModuleNotFoundError, ImportError):
pass