Spaces:
Runtime error
Runtime error
| import ast | |
| import os.path | |
| import torch | |
| from modules import devices, shared | |
| from .hnutil import find_self | |
| from .shared import version_flag | |
| lazy_load = False # when this is enabled, HNs will be loaded when required. | |
| if not hasattr(devices, 'cond_cast_unet'): | |
| raise RuntimeError("Cannot find cond_cast_unet attribute, please update your webui version!") | |
| class DynamicDict(dict): # Brief dict that dynamically unloads Hypernetworks if required. | |
| def __init__(self, **kwargs): | |
| super().__init__(**kwargs) | |
| self.current = None | |
| self.hash = None | |
| self.dict = {**kwargs} | |
| def prepare(self, key, value): | |
| if lazy_load and self.current is not None and ( | |
| key != self.current): # or filename is identical, but somehow hash is changed? | |
| self.current.to('cpu') | |
| self.current = value | |
| if self.current is not None: | |
| self.current.to(devices.device) | |
| def __getitem__(self, item): | |
| value = self.dict[item] | |
| self.prepare(item, value) | |
| return value | |
| def __setitem__(self, key, value): | |
| if key in self.dict: | |
| return | |
| self.dict[key] = value | |
| def __contains__(self, item): | |
| return item in self.dict | |
| available_opts = DynamicDict() # string -> HN itself. | |
| # Behavior definition. | |
| # [[], [], []] -> sequential processing | |
| # [{"A" : 0.8, "B" : 0.1}] -> parallel processing. with weighted sum in this case, A = 8/9 effect, B = 1/9 effect. | |
| # [("A", 0.2), ("B", 0.4)] -> tuple is used to specify strength. | |
| # [{"A", "B", "C"}] -> parallel, but having same effects (set) | |
| # ["A", "B", []] -> sequential processing | |
| # [{"A":0.6}, "B", "C"] -> sequential, dict with single value will be considered as strength modification. | |
| # [["A"], {"B"}, "C"] -> singletons are equal to items without covers, nested singleton will not be parsed, because its inefficient. | |
| # {{'Aa' : 0.2, 'Ab' : 0.8} : 0.8, 'B' : 0.1} (X) -> {"{'Aa' : 0.2, 'Ab' : 0.8}" : 0.8, 'B' : 0.1} (O), When you want complex setups in parallel, you need to cover them with "". You can use backslash too. | |
| # Testing parsing function. | |
| def test_parsing(string=None): | |
| def test(arg): | |
| print(arg) | |
| try: | |
| obj = str(Forward.parse(arg)) | |
| print(obj) | |
| except Exception as e: | |
| print(e) | |
| if string: | |
| test(string) | |
| else: | |
| for strings in ["[[], [], []]", "[{\"A\" : 0.8, \"B\" : 0.1}]", '[("A", 0.2), ("B", 0.4)]', '[{"A", "B", "C"}]', | |
| '[{"A":0.6}, "B", "C"]', '[["A"], {"B"}, "C"]', | |
| '{"{\'Aa\' : 0.2, \'Ab\' : 0.8}" : 0.8, \'B\' : 0.1}']: | |
| test(strings) | |
| class Forward: | |
| def __init__(self, **kwargs): | |
| self.name = "defaultForward" if 'name' not in kwargs else kwargs['name'] | |
| pass | |
| def __call__(self, *args, **kwargs): | |
| raise NotImplementedError | |
| def set_multiplier(self, *args, **kwargs): | |
| pass | |
| def extra_name(self): | |
| if version_flag: | |
| return "" | |
| found = find_self(self) | |
| if found is not None: | |
| return f" <hypernet:{found}:1.0>" | |
| return f" <hypernet:{self.name}:1.0>" | |
| def parse(arg, name=None): | |
| arg = Forward.unpack(arg) | |
| arg = Forward.eval(arg) | |
| if Forward.isSingleTon(arg): | |
| return SingularForward(*Forward.parseSingleTon(arg)) | |
| elif Forward.isParallel(arg): | |
| return ParallelForward(Forward.parseParallel(arg), name=name) | |
| elif Forward.isSequential(arg): | |
| return SequentialForward(Forward.parseSequential(arg), name=name) | |
| raise ValueError(f"Cannot parse {arg} into sequences!") | |
| def unpack(arg): # stop using ({({{((a))}})}) please | |
| if len(arg) == 1 and type(arg) in (set, list, tuple): | |
| return Forward.unpack(list(arg)[0]) | |
| if len(arg) == 1 and type(arg) is dict: | |
| key = list(arg.keys())[0] | |
| if arg[key] == 1: | |
| return Forward.unpack(key) | |
| return arg | |
| def eval(arg): # from "{something}", parse as etc form. | |
| if arg is None: | |
| raise ValueError("None cannot be evaluated!") | |
| try: | |
| newarg = ast.literal_eval(arg) | |
| if type(arg) is str and arg.startswith(("{", "[", "(")) and newarg is not None: | |
| if not newarg: | |
| raise RuntimeError(f"Cannot eval false object {arg}!") | |
| return newarg | |
| except ValueError: | |
| return arg | |
| return arg | |
| def isSingleTon( | |
| arg): # Very strict. This applies strength to HN, which cannot happen in combined networks. Only weighting is allowed in complex process. | |
| if type(arg) is str and not arg.startswith(('[', '(', '{')): # Strict. only accept str | |
| return True | |
| elif type( | |
| arg) is dict: # Strict. only accept {str : int/float} - Strength modification can only happen for str. | |
| return len(arg) == 1 and all(type(value) in (int, float) for value in arg.values()) and all( | |
| type(k) is str for k in arg) | |
| elif type(arg) in (list, set): | |
| return len(arg) == 1 and all(type(x) is str for x in arg) | |
| elif type(arg) is tuple: | |
| return len(arg) == 2 and type(arg[0]) is str and type(arg[1]) in (int, float) | |
| return False | |
| def parseSingleTon(sequence): # accepts sequence, returns str, float pair. This is Strict. | |
| if type(sequence) in (list, dict, set): | |
| assert len(sequence) == 1, f"SingularForward only accepts singletons, but given {sequence}!" | |
| key = list(sequence)[0] | |
| if type(sequence) is dict: | |
| assert type(key) is str, f"Strength modification only accepts single Hypernetwork, but given {key}!" | |
| return key, sequence[key] | |
| else: | |
| key = list(key)[0] | |
| return key, 1 | |
| elif type(sequence) is tuple: | |
| assert len(sequence) == 2, f"Tuple with non-couple {sequence} encountered in SingularForward!" | |
| assert type( | |
| sequence[0]) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence[0]}!" | |
| assert type(sequence[1]) in (int, float), f"Strength tuple only accepts Numbers, but given {sequence[1]}!" | |
| return sequence[0], sequence[1] | |
| else: | |
| assert type( | |
| sequence) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence}!" | |
| return sequence, 1 | |
| def isParallel( | |
| arg): # Parallel, or Sequential processing is not strict, it can have {"String covered sequence or just HN String" : weight, ... | |
| if type(arg) in (dict, set) and len(arg) > 1: | |
| if type(arg) is set: | |
| return all(type(key) is str for key in | |
| arg), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}" | |
| else: | |
| arg: dict | |
| return all(type(key) is str for key in | |
| arg.keys()), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}" | |
| else: | |
| return False | |
| def parseParallel(sequence): # accepts sequence, returns {"Name or sequence" : weight...} | |
| assert len(sequence) > 1, f"Length of sequence {sequence} was not enough for parallel!" | |
| if type(sequence) is set: # only allows hashable types. otherwise it should be supplied as string cover | |
| assert all(type(key) in (str, tuple) for key in | |
| sequence), f"All keys should be Hypernetwork Name/Sequence for Set but given :{sequence}" | |
| return {key: 1 / len(sequence) for key in sequence} | |
| elif type(sequence) is dict: | |
| assert all(type(key) in (str, tuple) for key in | |
| sequence.keys()), f"All keys should be Hypernetwork Name/Sequence for Dict but given :{sequence}" | |
| assert all(type(value) in (int, float) for value in | |
| sequence.values()), f"All values should be int/float for Dict but given :{sequence}" | |
| return sequence | |
| else: | |
| raise ValueError(f"Cannot parse parallel sequence {sequence}!") | |
| def isSequential(arg): | |
| if type(arg) is list and len(arg) > 0: | |
| return True | |
| return False | |
| def parseSequential(sequence): # accepts sequence, only checks if its list, then returns sequence. | |
| if type(sequence) is list and len(sequence) > 0: | |
| return sequence | |
| else: | |
| raise ValueError(f"Cannot parse non-list sequence {sequence}!") | |
| def shorthash(self): | |
| return '0000000000' | |
| from .hypernetwork import Hypernetwork | |
| def find_non_hash_key(target): | |
| closest = [x for x in shared.hypernetworks if x.rsplit('(', 1)[0] == target or x == target] | |
| if closest: | |
| return shared.hypernetworks[closest[0]] | |
| raise KeyError(f"{target} is not found in Hypernetworks!") | |
| class SingularForward(Forward): | |
| def __init__(self, processor, strength): | |
| assert processor != 'defaultForward', "Cannot use name defaultForward!" | |
| super(SingularForward, self).__init__() | |
| self.name = processor | |
| self.processor = processor | |
| self.strength = strength | |
| # parse. We expect parsing Singletons or (k,v) pair here, which is HN Name and Strength. | |
| hn = Hypernetwork() | |
| try: | |
| hn.load(find_non_hash_key(self.processor)) | |
| except: | |
| global lazy_load | |
| lazy_load = True | |
| print("Encountered CUDA Memory Error, will unload HNs, speed might go down severely!") | |
| hn.load(find_non_hash_key(self.processor)) | |
| available_opts[self.processor] = hn | |
| # assert self.processor in available_opts, f"Hypernetwork named {processor} is not ready!" | |
| assert 0 <= self.strength <= 1, "Strength must be between 0 and 1!" | |
| print(f"SingularForward <{self.name}, {self.strength}>") | |
| def __call__(self, context_k, context_v=None, layer=None): | |
| if self.processor in available_opts: | |
| context_layers = available_opts[self.processor].layers.get(context_k.shape[2], None) | |
| if context_v is None: | |
| context_v = context_k | |
| if context_layers is None: | |
| return context_k, context_v | |
| #if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'): | |
| # layer.hyper_k = context_layers[0], layer.hyper_v = context_layers[1] | |
| return devices.cond_cast_unet(context_layers[0](devices.cond_cast_float(context_k), multiplier=self.strength)),\ | |
| devices.cond_cast_unet(context_layers[1](devices.cond_cast_float(context_v), multiplier=self.strength)) | |
| # define forward_strength, which invokes HNModule with specified strength. | |
| # Note : we share same HN if it is called multiple time, which means you might not be able to train it via this structure. | |
| raise KeyError(f"Key {self.processor} is not found in cached Hypernetworks!") | |
| def __str__(self): | |
| return "SingularForward>" + str(self.processor) | |
| class ParallelForward(Forward): | |
| def __init__(self, sequence, name=None): | |
| self.name = "ParallelForwardHypernet" if name is None else name | |
| self.callers = {} | |
| self.weights = {} | |
| super(ParallelForward, self).__init__() | |
| # parse | |
| for keys in sequence: | |
| self.callers[keys] = Forward.parse(keys) | |
| self.weights[keys] = sequence[keys] / sum(sequence.values()) | |
| print(str(self)) | |
| def __call__(self, context, context_v=None, layer=None): | |
| ctx_k, ctx_v = torch.zeros_like(context, device=context.device), torch.zeros_like(context, | |
| device=context.device) | |
| for key in self.callers: | |
| k, v = self.callers[key](context, context_v, layer=layer) | |
| ctx_k += k * self.weights[key] | |
| ctx_v += v * self.weights[key] | |
| return ctx_k, ctx_v | |
| def __str__(self): | |
| return "ParallelForward>" + str({str(k): str(v) for (k, v) in self.callers.items()}) | |
| class SequentialForward(Forward): | |
| def __init__(self, sequence, name=None): | |
| self.name = "SequentialForwardHypernet" if name is None else name | |
| self.callers = [] | |
| super(SequentialForward, self).__init__() | |
| for keys in sequence: | |
| self.callers.append(Forward.parse(keys)) | |
| print(str(self)) | |
| def __call__(self, context, context_v=None, layer=None): | |
| if context_v is None: | |
| context_v = context | |
| for keys in self.callers: | |
| context, context_v = keys(context, context_v, layer=layer) | |
| return context, context_v | |
| def __str__(self): | |
| return "SequentialForward>" + str([str(x) for x in self.callers]) | |
| class EmptyForward(Forward): | |
| def __init__(self): | |
| super().__init__() | |
| self.name = None | |
| def __call__(self, context, context_v=None, layer=None): | |
| if context_v is None: | |
| context_v = context | |
| return context, context_v | |
| def __str__(self): | |
| return "EmptyForward" | |
| def load(filename): | |
| with open(filename, 'r') as file: | |
| return Forward.parse(file.read(), name=os.path.basename(filename)) | |