import glob
import inspect
import os
import sys
import traceback

import torch
from torch.nn.init import normal_, xavier_uniform_, zeros_, xavier_normal_, kaiming_uniform_, kaiming_normal_

try:
    from modules.hashes import sha256
except (ImportError, ModuleNotFoundError):
    print("modules.hashes is not found, will use backup module from extension!")
    from .hashes_backup import sha256

import modules.hypernetworks.hypernetwork
from modules import devices, shared, sd_models

from .hnutil import parse_dropout_structure, find_self
from .shared import version_flag


def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"):
    w, b = layer.weight.data, layer.bias.data
    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
        normal_(w, mean=0.0, std=normal_std)
        normal_(b, mean=0.0, std=0)
    elif weight_init == 'XavierUniform':
        xavier_uniform_(w)
        zeros_(b)
    elif weight_init == 'XavierNormal':
        xavier_normal_(w)
        zeros_(b)
    elif weight_init == 'KaimingUniform':
        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
        zeros_(b)
    elif weight_init == 'KaimingNormal':
        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
        zeros_(b)
    else:
        raise KeyError(f"Key {weight_init} is not defined as initialization!")


class ResBlock(torch.nn.Module):
    """Residual Block"""

    def __init__(self, n_inputs, n_outputs, activation_func, weight_init, add_layer_norm, dropout_p, normal_std,
                 device=None, state_dict=None, **kwargs):
        super().__init__()
        self.n_outputs = n_outputs
        self.upsample_layer = None
        self.upsample = kwargs.get("upsample_model", None)
        if self.upsample == "Linear":
            self.upsample_layer = torch.nn.Linear(n_inputs, n_outputs, bias=False)
        linears = [torch.nn.Linear(n_inputs, n_outputs)]
        init_weight(linears[0], weight_init, normal_std, activation_func)
        if add_layer_norm:
            linears.append(torch.nn.LayerNorm(n_outputs))
            init_weight(linears[1], weight_init, normal_std, activation_func)
        if dropout_p > 0:
            linears.append(torch.nn.Dropout(p=dropout_p))
        if activation_func == "linear" or activation_func is None:
            pass
        elif activation_func in HypernetworkModule.activation_dict:
            linears.append(HypernetworkModule.activation_dict[activation_func]())
        else:
            raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
        self.linear = torch.nn.Sequential(*linears)
        if state_dict is not None:
            self.load_state_dict(state_dict)
        if device is not None:
            self.to(device)

    def trainables(self, train=False):
        layer_structure = []
        for layer in self.linear:
            if train:
                layer.train()
            else:
                layer.eval()
            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                layer_structure += [layer.weight, layer.bias]
        return layer_structure

    def forward(self, x, **kwargs):
        if self.upsample_layer is None:
            interpolated = torch.nn.functional.interpolate(x, size=self.n_outputs, mode="nearest-exact")
        else:
            interpolated = self.upsample_layer(x)
        return interpolated + self.linear(x)
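
# Illustrative sketch (kept as comments so importing this module stays side-effect free):
# how ResBlock reconciles differing input/output widths for the residual sum.
# The sizes and shapes below are assumed purely for illustration.
#
#     block = ResBlock(n_inputs=768, n_outputs=1536, activation_func="relu",
#                      weight_init="Normal", add_layer_norm=False,
#                      dropout_p=0.0, normal_std=0.01)
#     x = torch.randn(2, 77, 768)   # (batch, tokens, features)
#     y = block(x)                  # -> (2, 77, 1536)
#
# Without upsample_model="Linear", x is stretched from 768 to 1536 features by
# nearest-exact interpolation before being added to block.linear(x); with it, a
# bias-free Linear layer performs the width change instead.
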

class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
    activation_dict = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    # Also expose every activation class defined in torch.nn.modules.activation under its lowercase name.
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation)
                            if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, activate_output=False, dropout_structure=None, device=None,
                 generation_seed=None, normal_std=0.01, **kwargs):
        super().__init__()
        self.skip_connection = skip_connection = kwargs.get('skip_connection', False)
        upsample_linear = kwargs.get('upsample_linear', None)
        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
        # Instead of throwing an error, warn; the first value is never used anyway.
        if not (skip_connection or dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0):
            print("Dropout sequence does not start or end with zero.")
        # assert skip_connection or dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0, "Dropout Sequence should start and end with probability 0!"
        assert dropout_structure is None or len(dropout_structure) == len(layer_structure), "Dropout Sequence should match length with layer structure!"

        linears = []
        if skip_connection:
            if generation_seed is not None:
                torch.manual_seed(generation_seed)
        for i in range(len(layer_structure) - 1):
            if skip_connection:
                n_inputs, n_outputs = int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])
                dropout_p = dropout_structure[i + 1]
                if activation_func is None:
                    activation_func = "linear"
                linears.append(ResBlock(n_inputs, n_outputs, activation_func, weight_init, add_layer_norm,
                                        dropout_p, normal_std, device, upsample_model=upsample_linear))
                continue

            # Add a fully-connected layer
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))

            # Add an activation func except last layer
            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                pass
            elif activation_func in self.activation_dict:
                linears.append(self.activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

            # Add layer normalization
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))

            # Everything should be now parsed into dropout structure, and applied here.
            # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
            if dropout_structure is not None and dropout_structure[i + 1] > 0:
                assert 0 < dropout_structure[i + 1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
                linears.append(torch.nn.Dropout(p=dropout_structure[i + 1]))
            # Code explanation: with [1, 2, 1] the dropout is omitted when last_layer_dropout is False;
            # with [1, 2, 2, 1] the structure becomes [0, 0.3, 0, 0], and [0, 0.3, 0.3, 0] when it is True.

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        elif not skip_connection:
            if generation_seed is not None:
                torch.manual_seed(generation_seed)
            for layer in self.linear:
                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                    w, b = layer.weight.data, layer.bias.data
                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                        normal_(w, mean=0.0, std=normal_std)
                        normal_(b, mean=0.0, std=0)
                    elif weight_init == 'XavierUniform':
                        xavier_uniform_(w)
                        zeros_(b)
                    elif weight_init == 'XavierNormal':
                        xavier_normal_(w)
                        zeros_(b)
                    elif weight_init == 'KaimingUniform':
                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    elif weight_init == 'KaimingNormal':
                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    else:
                        raise KeyError(f"Key {weight_init} is not defined as initialization!")

        if device is None:
            self.to(devices.device)
        else:
            self.to(device)
    def fix_old_state_dict(self, state_dict):
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }

        for fr, to in changes.items():
            x = state_dict.get(fr, None)
            if x is None:
                continue
            del state_dict[fr]
            state_dict[to] = x

    def forward(self, x, multiplier=None):
        if self.skip_connection:
            if self.training:
                return self.linear(x)
            else:
                resnet_result = self.linear(x)
                residual = resnet_result - x
                if multiplier is None or not isinstance(multiplier, (int, float)):
                    multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier
                return x + multiplier * residual  # interpolate
        if multiplier is None or not isinstance(multiplier, (int, float)):
            return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1)
        return x + self.linear(x) * multiplier

    def trainables(self, train=False):
        layer_structure = []
        self.train(train)
        for layer in self.linear:
            if train:
                layer.train()
            else:
                layer.eval()
            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                layer_structure += [layer.weight, layer.bias]
            elif type(layer) == ResBlock:
                layer_structure += layer.trainables(train)
        return layer_structure

    def set_train(self, mode=True):
        self.train(mode)
        for layer in self.linear:
            if mode:
                layer.train(mode)
            else:
                layer.eval()
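
# Illustrative sketch (kept as comments so importing this module stays side-effect free):
# a HypernetworkModule transforms cross-attention context vectors of a fixed width.
# The width 768 and tensor shapes below are assumptions chosen for illustration.
#
#     module = HypernetworkModule(768, layer_structure=[1, 2, 1],
#                                 activation_func="relu", weight_init="Normal")
#     ctx = torch.randn(2, 77, 768)   # (batch, tokens, context dim)
#     out = module(ctx)               # same shape as ctx
#
# In eval mode the residual self.linear(ctx) is scaled by the strength stored in
# HypernetworkModule.multiplier (or the per-instance multiplier when version_flag
# is False) before being added back to ctx; during training the scale is 1.
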

class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
                 add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
        self.activate_output = activate_output
        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
        self.optimizer_name = None
        self.optimizer_state_dict = None
        self.dropout_structure = kwargs['dropout_structure'] if 'dropout_structure' in kwargs and use_dropout else None
        self.optional_info = kwargs.get('optional_info', None)
        self.skip_connection = kwargs.get('skip_connection', False)
        self.upsample_linear = kwargs.get('upsample_linear', None)
        self.training = False
        generation_seed = kwargs.get('generation_seed', None)
        normal_std = kwargs.get('normal_std', 0.01)
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure,
                                   generation_seed=generation_seed, normal_std=normal_std,
                                   skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure,
                                   generation_seed=generation_seed, normal_std=normal_std,
                                   skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
            )
        self.eval()

    def weights(self, train=False):
        self.training = train
        res = []
        for k, layers in self.layers.items():
            for layer in layers:
                res += layer.trainables(train)
        return res

    def eval(self):
        self.training = False
        for k, layers in self.layers.items():
            for layer in layers:
                layer.eval()
                layer.set_train(False)

    def train(self, mode=True):
        self.training = mode
        for k, layers in self.layers.items():
            for layer in layers:
                layer.set_train(mode)

    def detach_grad(self):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.requires_grad_(False)

    def shorthash(self):
        sha256v = sha256(self.filename, f'hypernet/{self.name}')
        return sha256v[0:10]

    def extra_name(self):
        if version_flag:
            return ""
        found = find_self(self)
        # Assumption: the returned suffix uses webui's "<hypernet:NAME>" prompt-tag syntax.
        if found is not None:
            return f" <hypernet:{found}>"
        return f" <hypernet:{self.name}>"

    def save(self, filename):
        state_dict = {}
        optimizer_saved_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['activation_func'] = self.activation_func
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['weight_initialization'] = self.weight_init
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        state_dict['activate_output'] = self.activate_output
        state_dict['use_dropout'] = self.use_dropout
        state_dict['dropout_structure'] = self.dropout_structure
        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
        state_dict['optional_info'] = self.optional_info if self.optional_info else None
        state_dict['skip_connection'] = self.skip_connection
        state_dict['upsample_linear'] = self.upsample_linear

        if self.optimizer_name is not None:
            optimizer_saved_dict['optimizer_name'] = self.optimizer_name

        torch.save(state_dict, filename)
        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
            optimizer_saved_dict['hash'] = self.shorthash()  # this is necessary
            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
            torch.save(optimizer_saved_dict, filename + '.optim')
    def load(self, filename):
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        print(self.layer_structure)
        optional_info = state_dict.get('optional_info', None)
        if optional_info is not None:
            self.optional_info = optional_info
        self.activation_func = state_dict.get('activation_func', None)
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.dropout_structure = state_dict.get('dropout_structure', None)
        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
        self.activate_output = state_dict.get('activate_output', True)
        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)  # Silent fix for HNs before 4918eb6
        self.skip_connection = state_dict.get('skip_connection', False)
        self.upsample_linear = state_dict.get('upsample_linear', False)
        # Dropout structure should have the same length as the layer structure; every entry should be in [0, 1) and the last entry must be 0.
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)

        if hasattr(shared.opts, 'print_hypernet_extra') and shared.opts.print_hypernet_extra:
            if optional_info is not None:
                print(f"INFO:\n {optional_info}\n")
            print(f"Activation function is {self.activation_func}")
            print(f"Weight initialization is {self.weight_init}")
            print(f"Layer norm is set to {self.add_layer_norm}")
            print(f"Dropout usage is set to {self.use_dropout}")
            print(f"Activate last layer is set to {self.activate_output}")
            print(f"Dropout structure is set to {self.dropout_structure}")

        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
        self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')

        if optimizer_saved_dict.get('hash', None) == self.shorthash() or optimizer_saved_dict.get('hash', None) == sd_models.model_hash(filename):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
            print("Loaded existing optimizer from checkpoint")
            print(f"Optimizer name is {self.optimizer_name}")
        else:
            print("No saved optimizer exists in checkpoint")

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure,
                                       skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure,
                                       skip_connection=self.skip_connection, upsample_linear=self.upsample_linear),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
        self.eval()

    def to(self, device):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.to(device)
        return self

    def set_multiplier(self, multiplier):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.multiplier = multiplier
        return self

    def __call__(self, context, *args, **kwargs):
        return self.forward(context, *args, **kwargs)

    def forward(self, context, context_v=None, layer=None):
        context_layers = self.layers.get(context.shape[2], None)
        if context_v is None:
            context_v = context
        if context_layers is None:
            return context, context_v
        if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'):
            layer.hyper_k = context_layers[0]
            layer.hyper_v = context_layers[1]
        transform_k, transform_v = context_layers[0](context), context_layers[1](context_v)
        return transform_k, transform_v
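
# Illustrative sketch (kept as comments; the name, sizes, and file path are hypothetical,
# and shared.opts is only available inside webui): building a Hypernetwork and applying it.
#
#     hn = Hypernetwork(name="example", enable_sizes=[768],
#                       layer_structure=[1, 2, 1], activation_func="relu",
#                       weight_init="Normal")
#     ctx = torch.randn(2, 77, 768)
#     ctx_k, ctx_v = hn(ctx)        # picks the module pair keyed by ctx.shape[2]
#     hn.save("example.pt")         # the ".optim" file is written only when optimizer state exists
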

def list_hypernetworks(path):
    res = {}
    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
        name = os.path.splitext(os.path.basename(filename))[0]
        idx = 0
        while name in res:
            idx += 1
            name = name + f"({idx})"
        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            res[name] = filename
    for filename in glob.iglob(os.path.join(path, '**/*.hns'), recursive=True):
        name = os.path.splitext(os.path.basename(filename))[0]
        if name != "None":
            res[name] = filename
    return res


def find_closest_first(keyset, target):
    for keys in keyset:
        if target == keys.rsplit('(', 1)[0]:
            return keys
    return None


def load_hypernetwork(filename):
    hypernetwork = None
    path = shared.hypernetworks.get(filename, None)
    if path is None:
        filename = find_closest_first(shared.hypernetworks.keys(), filename)
        path = shared.hypernetworks.get(filename, None)
        print(path)
    # Prevent any file named "None.pt" from being loaded.
    if path is not None and filename != "None":
        print(f"Loading hypernetwork {filename}")
        if path.endswith(".pt"):
            try:
                hypernetwork = Hypernetwork()
                hypernetwork.load(path)
                if hasattr(shared, 'loaded_hypernetwork'):
                    shared.loaded_hypernetwork = hypernetwork
                else:
                    return hypernetwork
            except Exception:
                print(f"Error loading hypernetwork {path}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
        elif path.endswith(".hns"):
            # Load Hypernetwork processing
            try:
                from .hypernetworks import load as load_hns
                if hasattr(shared, 'loaded_hypernetwork'):
                    shared.loaded_hypernetwork = load_hns(path)
                else:
                    hypernetwork = load_hns(path)
                print(f"Loaded Hypernetwork Structure {path}")
                return hypernetwork
            except Exception:
                print(f"Error loading hypernetwork processing file {path}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
        else:
            print(f"Tried to load unknown file extension: {filename}")
    else:
        if hasattr(shared, 'loaded_hypernetwork'):
            if shared.loaded_hypernetwork is not None:
                print("Unloading hypernetwork")
            shared.loaded_hypernetwork = None
    return hypernetwork


def apply_hypernetwork(hypernetwork, context, layer=None):
    if hypernetwork is None:
        return context, context
    if isinstance(hypernetwork, Hypernetwork):
        hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
        if hypernetwork_layers is None:
            return context, context
        if layer is not None:
            layer.hyper_k = hypernetwork_layers[0]
            layer.hyper_v = hypernetwork_layers[1]
        context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context)))
        context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context)))
        return context_k, context_v
    context_k, context_v = hypernetwork(context, layer=layer)
    return context_k, context_v


def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    if hypernetwork is None:
        return context_k, context_v
    if isinstance(hypernetwork, Hypernetwork):
        hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
        if hypernetwork_layers is None:
            return context_k, context_v
        if layer is not None:
            layer.hyper_k = hypernetwork_layers[0]
            layer.hyper_v = hypernetwork_layers[1]
        context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
        context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
        return context_k, context_v
    context_k, context_v = hypernetwork(context_k, context_v, layer=layer)
    return context_k, context_v

def apply_strength(value=None):
    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength


def apply_hypernetwork_strength(p, x, xs):
    apply_strength(x)


# Monkeypatch the webui hypernetwork module so it uses this extension's implementations.
modules.hypernetworks.hypernetwork.list_hypernetworks = list_hypernetworks
modules.hypernetworks.hypernetwork.load_hypernetwork = load_hypernetwork
if hasattr(modules.hypernetworks.hypernetwork, 'apply_hypernetwork'):
    modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork
else:
    modules.hypernetworks.hypernetwork.apply_single_hypernetwork = apply_single_hypernetwork
if hasattr(modules.hypernetworks.hypernetwork, 'apply_strength'):
    modules.hypernetworks.hypernetwork.apply_strength = apply_strength
modules.hypernetworks.hypernetwork.Hypernetwork = Hypernetwork
modules.hypernetworks.hypernetwork.HypernetworkModule = HypernetworkModule

try:
    import scripts.xy_grid
    if hasattr(scripts.xy_grid, 'apply_hypernetwork_strength'):
        scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength
except (ModuleNotFoundError, ImportError):
    pass
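
# Illustrative sketch (comments only; `context` is a placeholder tensor): how the
# patched entry points above are typically exercised after load_hypernetwork() has
# populated shared.loaded_hypernetwork. This assumes webui builds that expose either
# apply_hypernetwork or the split apply_single_hypernetwork entry point.
#
#     apply_strength(0.8)   # sets HypernetworkModule.multiplier for inference
#     context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
#     # or, on builds with the split entry point, applied to k/v separately:
#     context_k, context_v = apply_single_hypernetwork(shared.loaded_hypernetwork, context_k, context_v)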