import ast import os.path import torch from modules import devices, shared from .hnutil import find_self from .shared import version_flag lazy_load = False # when this is enabled, HNs will be loaded when required. if not hasattr(devices, 'cond_cast_unet'): raise RuntimeError("Cannot find cond_cast_unet attribute, please update your webui version!") class DynamicDict(dict): # Brief dict that dynamically unloads Hypernetworks if required. def __init__(self, **kwargs): super().__init__(**kwargs) self.current = None self.hash = None self.dict = {**kwargs} def prepare(self, key, value): if lazy_load and self.current is not None and ( key != self.current): # or filename is identical, but somehow hash is changed? self.current.to('cpu') self.current = value if self.current is not None: self.current.to(devices.device) def __getitem__(self, item): value = self.dict[item] self.prepare(item, value) return value def __setitem__(self, key, value): if key in self.dict: return self.dict[key] = value def __contains__(self, item): return item in self.dict available_opts = DynamicDict() # string -> HN itself. # Behavior definition. # [[], [], []] -> sequential processing # [{"A" : 0.8, "B" : 0.1}] -> parallel processing. with weighted sum in this case, A = 8/9 effect, B = 1/9 effect. # [("A", 0.2), ("B", 0.4)] -> tuple is used to specify strength. # [{"A", "B", "C"}] -> parallel, but having same effects (set) # ["A", "B", []] -> sequential processing # [{"A":0.6}, "B", "C"] -> sequential, dict with single value will be considered as strength modification. # [["A"], {"B"}, "C"] -> singletons are equal to items without covers, nested singleton will not be parsed, because its inefficient. # {{'Aa' : 0.2, 'Ab' : 0.8} : 0.8, 'B' : 0.1} (X) -> {"{'Aa' : 0.2, 'Ab' : 0.8}" : 0.8, 'B' : 0.1} (O), When you want complex setups in parallel, you need to cover them with "". You can use backslash too. # Testing parsing function. def test_parsing(string=None): def test(arg): print(arg) try: obj = str(Forward.parse(arg)) print(obj) except Exception as e: print(e) if string: test(string) else: for strings in ["[[], [], []]", "[{\"A\" : 0.8, \"B\" : 0.1}]", '[("A", 0.2), ("B", 0.4)]', '[{"A", "B", "C"}]', '[{"A":0.6}, "B", "C"]', '[["A"], {"B"}, "C"]', '{"{\'Aa\' : 0.2, \'Ab\' : 0.8}" : 0.8, \'B\' : 0.1}']: test(strings) class Forward: def __init__(self, **kwargs): self.name = "defaultForward" if 'name' not in kwargs else kwargs['name'] pass def __call__(self, *args, **kwargs): raise NotImplementedError def set_multiplier(self, *args, **kwargs): pass def extra_name(self): if version_flag: return "" found = find_self(self) if found is not None: return f" " return f" " @staticmethod def parse(arg, name=None): arg = Forward.unpack(arg) arg = Forward.eval(arg) if Forward.isSingleTon(arg): return SingularForward(*Forward.parseSingleTon(arg)) elif Forward.isParallel(arg): return ParallelForward(Forward.parseParallel(arg), name=name) elif Forward.isSequential(arg): return SequentialForward(Forward.parseSequential(arg), name=name) raise ValueError(f"Cannot parse {arg} into sequences!") @staticmethod def unpack(arg): # stop using ({({{((a))}})}) please if len(arg) == 1 and type(arg) in (set, list, tuple): return Forward.unpack(list(arg)[0]) if len(arg) == 1 and type(arg) is dict: key = list(arg.keys())[0] if arg[key] == 1: return Forward.unpack(key) return arg @staticmethod def eval(arg): # from "{something}", parse as etc form. if arg is None: raise ValueError("None cannot be evaluated!") try: newarg = ast.literal_eval(arg) if type(arg) is str and arg.startswith(("{", "[", "(")) and newarg is not None: if not newarg: raise RuntimeError(f"Cannot eval false object {arg}!") return newarg except ValueError: return arg return arg @staticmethod def isSingleTon( arg): # Very strict. This applies strength to HN, which cannot happen in combined networks. Only weighting is allowed in complex process. if type(arg) is str and not arg.startswith(('[', '(', '{')): # Strict. only accept str return True elif type( arg) is dict: # Strict. only accept {str : int/float} - Strength modification can only happen for str. return len(arg) == 1 and all(type(value) in (int, float) for value in arg.values()) and all( type(k) is str for k in arg) elif type(arg) in (list, set): return len(arg) == 1 and all(type(x) is str for x in arg) elif type(arg) is tuple: return len(arg) == 2 and type(arg[0]) is str and type(arg[1]) in (int, float) return False @staticmethod def parseSingleTon(sequence): # accepts sequence, returns str, float pair. This is Strict. if type(sequence) in (list, dict, set): assert len(sequence) == 1, f"SingularForward only accepts singletons, but given {sequence}!" key = list(sequence)[0] if type(sequence) is dict: assert type(key) is str, f"Strength modification only accepts single Hypernetwork, but given {key}!" return key, sequence[key] else: key = list(key)[0] return key, 1 elif type(sequence) is tuple: assert len(sequence) == 2, f"Tuple with non-couple {sequence} encountered in SingularForward!" assert type( sequence[0]) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence[0]}!" assert type(sequence[1]) in (int, float), f"Strength tuple only accepts Numbers, but given {sequence[1]}!" return sequence[0], sequence[1] else: assert type( sequence) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence}!" return sequence, 1 @staticmethod def isParallel( arg): # Parallel, or Sequential processing is not strict, it can have {"String covered sequence or just HN String" : weight, ... if type(arg) in (dict, set) and len(arg) > 1: if type(arg) is set: return all(type(key) is str for key in arg), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}" else: arg: dict return all(type(key) is str for key in arg.keys()), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}" else: return False @staticmethod def parseParallel(sequence): # accepts sequence, returns {"Name or sequence" : weight...} assert len(sequence) > 1, f"Length of sequence {sequence} was not enough for parallel!" if type(sequence) is set: # only allows hashable types. otherwise it should be supplied as string cover assert all(type(key) in (str, tuple) for key in sequence), f"All keys should be Hypernetwork Name/Sequence for Set but given :{sequence}" return {key: 1 / len(sequence) for key in sequence} elif type(sequence) is dict: assert all(type(key) in (str, tuple) for key in sequence.keys()), f"All keys should be Hypernetwork Name/Sequence for Dict but given :{sequence}" assert all(type(value) in (int, float) for value in sequence.values()), f"All values should be int/float for Dict but given :{sequence}" return sequence else: raise ValueError(f"Cannot parse parallel sequence {sequence}!") @staticmethod def isSequential(arg): if type(arg) is list and len(arg) > 0: return True return False @staticmethod def parseSequential(sequence): # accepts sequence, only checks if its list, then returns sequence. if type(sequence) is list and len(sequence) > 0: return sequence else: raise ValueError(f"Cannot parse non-list sequence {sequence}!") def shorthash(self): return '0000000000' from .hypernetwork import Hypernetwork def find_non_hash_key(target): closest = [x for x in shared.hypernetworks if x.rsplit('(', 1)[0] == target or x == target] if closest: return shared.hypernetworks[closest[0]] raise KeyError(f"{target} is not found in Hypernetworks!") class SingularForward(Forward): def __init__(self, processor, strength): assert processor != 'defaultForward', "Cannot use name defaultForward!" super(SingularForward, self).__init__() self.name = processor self.processor = processor self.strength = strength # parse. We expect parsing Singletons or (k,v) pair here, which is HN Name and Strength. hn = Hypernetwork() try: hn.load(find_non_hash_key(self.processor)) except: global lazy_load lazy_load = True print("Encountered CUDA Memory Error, will unload HNs, speed might go down severely!") hn.load(find_non_hash_key(self.processor)) available_opts[self.processor] = hn # assert self.processor in available_opts, f"Hypernetwork named {processor} is not ready!" assert 0 <= self.strength <= 1, "Strength must be between 0 and 1!" print(f"SingularForward <{self.name}, {self.strength}>") def __call__(self, context_k, context_v=None, layer=None): if self.processor in available_opts: context_layers = available_opts[self.processor].layers.get(context_k.shape[2], None) if context_v is None: context_v = context_k if context_layers is None: return context_k, context_v #if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'): # layer.hyper_k = context_layers[0], layer.hyper_v = context_layers[1] return devices.cond_cast_unet(context_layers[0](devices.cond_cast_float(context_k), multiplier=self.strength)),\ devices.cond_cast_unet(context_layers[1](devices.cond_cast_float(context_v), multiplier=self.strength)) # define forward_strength, which invokes HNModule with specified strength. # Note : we share same HN if it is called multiple time, which means you might not be able to train it via this structure. raise KeyError(f"Key {self.processor} is not found in cached Hypernetworks!") def __str__(self): return "SingularForward>" + str(self.processor) class ParallelForward(Forward): def __init__(self, sequence, name=None): self.name = "ParallelForwardHypernet" if name is None else name self.callers = {} self.weights = {} super(ParallelForward, self).__init__() # parse for keys in sequence: self.callers[keys] = Forward.parse(keys) self.weights[keys] = sequence[keys] / sum(sequence.values()) print(str(self)) def __call__(self, context, context_v=None, layer=None): ctx_k, ctx_v = torch.zeros_like(context, device=context.device), torch.zeros_like(context, device=context.device) for key in self.callers: k, v = self.callers[key](context, context_v, layer=layer) ctx_k += k * self.weights[key] ctx_v += v * self.weights[key] return ctx_k, ctx_v def __str__(self): return "ParallelForward>" + str({str(k): str(v) for (k, v) in self.callers.items()}) class SequentialForward(Forward): def __init__(self, sequence, name=None): self.name = "SequentialForwardHypernet" if name is None else name self.callers = [] super(SequentialForward, self).__init__() for keys in sequence: self.callers.append(Forward.parse(keys)) print(str(self)) def __call__(self, context, context_v=None, layer=None): if context_v is None: context_v = context for keys in self.callers: context, context_v = keys(context, context_v, layer=layer) return context, context_v def __str__(self): return "SequentialForward>" + str([str(x) for x in self.callers]) class EmptyForward(Forward): def __init__(self): super().__init__() self.name = None def __call__(self, context, context_v=None, layer=None): if context_v is None: context_v = context return context, context_v def __str__(self): return "EmptyForward" def load(filename): with open(filename, 'r') as file: return Forward.parse(file.read(), name=os.path.basename(filename))