import ast
import os.path

import torch

from modules import devices, shared

from .hnutil import find_self
from .shared import version_flag

lazy_load = False  # when this is enabled, HNs will be loaded when required.

if not hasattr(devices, 'cond_cast_unet'):
    raise RuntimeError("Cannot find cond_cast_unet attribute, please update your webui version!")


class DynamicDict(dict):  # Small dict that dynamically offloads Hypernetworks when required.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.current = None  # the Hypernetwork currently resident on the compute device
        self.hash = None
        self.dict = {**kwargs}

    def prepare(self, key, value):
        # With lazy_load enabled, move the previously used HN back to the CPU
        # before bringing the requested one onto the compute device.
        # Compare by object rather than key: the filename can stay identical
        # while the underlying object (hash) changes.
        if lazy_load and self.current is not None and value is not self.current:
            self.current.to('cpu')
        self.current = value
        if self.current is not None:
            self.current.to(devices.device)

    def __getitem__(self, item):
        value = self.dict[item]
        self.prepare(item, value)
        return value

    def __setitem__(self, key, value):
        if key in self.dict:  # first insertion wins; a cached HN is never replaced
            return
        self.dict[key] = value

    def __contains__(self, item):
        return item in self.dict
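
# Hypothetical access pattern with lazy_load enabled: only the most recently
# used Hypernetwork stays on the compute device.
#   available_opts['a'] = hn_a  # cached once; later insertions under 'a' are ignored
#   _ = available_opts['a']     # hn_a is moved to devices.device
#   _ = available_opts['b']     # hn_a is moved back to 'cpu', hn_b takes its place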
available_opts = DynamicDict()  # maps Hypernetwork name -> the loaded Hypernetwork itself.

# Behavior definition for sequence specs:
# [[], [], []] -> sequential processing.
# [{"A" : 0.8, "B" : 0.1}] -> parallel processing with weighted sum; here A has 8/9 effect, B has 1/9 effect.
# [("A", 0.2), ("B", 0.4)] -> a tuple is used to specify strength.
# [{"A", "B", "C"}] -> parallel, with equal effect for every member (set).
# ["A", "B", []] -> sequential processing.
# [{"A":0.6}, "B", "C"] -> sequential; a dict with a single entry is treated as a strength modification.
# [["A"], {"B"}, "C"] -> singletons are equivalent to bare items; nested singletons are not parsed, because it's inefficient.
# {{'Aa' : 0.2, 'Ab' : 0.8} : 0.8, 'B' : 0.1} (X) -> {"{'Aa' : 0.2, 'Ab' : 0.8}" : 0.8, 'B' : 0.1} (O)
#   When you want complex setups in parallel, cover them with quotes. Backslash escapes work too.
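
# Worked example with hypothetical network names:
#   '[{"styleA": 0.8, "styleB": 0.2}, "detailer"]'
# parses into a SequentialForward whose first stage is a ParallelForward
# blending styleA and styleB (weights 0.8/0.2), and whose second stage applies
# detailer at full strength to the blended result.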


# Test the parsing function.
def test_parsing(string=None):
    def test(arg):
        print(arg)
        try:
            obj = str(Forward.parse(arg))
            print(obj)
        except Exception as e:
            print(e)

    if string:
        test(string)
    else:
        for strings in ["[[], [], []]", '[{"A" : 0.8, "B" : 0.1}]', '[("A", 0.2), ("B", 0.4)]',
                        '[{"A", "B", "C"}]', '[{"A":0.6}, "B", "C"]', '[["A"], {"B"}, "C"]',
                        '{"{\'Aa\' : 0.2, \'Ab\' : 0.8}" : 0.8, \'B\' : 0.1}']:
            test(strings)


class Forward:
    def __init__(self, **kwargs):
        self.name = kwargs.get('name', "defaultForward")

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def set_multiplier(self, *args, **kwargs):
        pass  # subclasses may scale their strength here; the base class ignores it

    def extra_name(self):
        if version_flag:
            return ""
        found = find_self(self)
        if found is not None:
            return f" <hypernet:{found}:1.0>"
        return f" <hypernet:{self.name}:1.0>"

    @staticmethod
    def parse(arg, name=None):
        arg = Forward.unpack(arg)
        arg = Forward.eval(arg)
        if Forward.isSingleTon(arg):
            return SingularForward(*Forward.parseSingleTon(arg))
        elif Forward.isParallel(arg):
            return ParallelForward(Forward.parseParallel(arg), name=name)
        elif Forward.isSequential(arg):
            return SequentialForward(Forward.parseSequential(arg), name=name)
        raise ValueError(f"Cannot parse {arg} into sequences!")

    @staticmethod
    def unpack(arg):  # stop using ({({{((a))}})}) please
        if len(arg) == 1 and type(arg) in (set, list, tuple):
            return Forward.unpack(list(arg)[0])
        if len(arg) == 1 and type(arg) is dict:
            key = list(arg.keys())[0]
            if arg[key] == 1:  # a weight of exactly 1 carries no information, so unwrap
                return Forward.unpack(key)
        return arg
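
    # Example: unpack([({"A": 1},)]) -> "A". One-element containers and
    # weight-1 dicts are peeled off recursively until the bare spec remains.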

    @staticmethod
    def eval(arg):  # from a string such as "{...}", recover the corresponding Python literal.
        if arg is None:
            raise ValueError("None cannot be evaluated!")
        try:
            newarg = ast.literal_eval(arg)
            if type(arg) is str and arg.startswith(("{", "[", "(")) and newarg is not None:
                if not newarg:
                    raise RuntimeError(f"Cannot eval falsy object {arg}!")
                return newarg
        except (ValueError, SyntaxError):  # not a literal: treat it as a plain Hypernetwork name
            return arg
        return arg
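
    # Example: eval('{"A": 0.5, "B": 0.5}') -> {"A": 0.5, "B": 0.5}, while
    # eval('my_hypernet') fails literal evaluation and is returned unchanged
    # as a name.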

    @staticmethod
    def isSingleTon(arg):
        # Very strict: this applies strength to a single HN, which cannot happen in
        # combined networks. Only weighting is allowed in complex processing.
        if type(arg) is str and not arg.startswith(('[', '(', '{')):  # strict: only accept str
            return True
        elif type(arg) is dict:  # strict: only {str: int/float}; strength modification only applies to str
            return len(arg) == 1 and all(type(value) in (int, float) for value in arg.values()) and all(
                type(k) is str for k in arg)
        elif type(arg) in (list, set):
            return len(arg) == 1 and all(type(x) is str for x in arg)
        elif type(arg) is tuple:
            return len(arg) == 2 and type(arg[0]) is str and type(arg[1]) in (int, float)
        return False

    @staticmethod
    def parseSingleTon(sequence):  # accepts a singleton spec, returns a strict (name, strength) pair.
        if type(sequence) in (list, dict, set):
            assert len(sequence) == 1, f"SingularForward only accepts singletons, but given {sequence}!"
            key = list(sequence)[0]
            if type(sequence) is dict:
                assert type(key) is str, f"Strength modification only accepts single Hypernetwork, but given {key}!"
                return key, sequence[key]
            else:
                return key, 1  # key is already the bare name here
        elif type(sequence) is tuple:
            assert len(sequence) == 2, f"Tuple with non-couple {sequence} encountered in SingularForward!"
            assert type(sequence[0]) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence[0]}!"
            assert type(sequence[1]) in (int, float), f"Strength tuple only accepts Numbers, but given {sequence[1]}!"
            return sequence[0], sequence[1]
        else:
            assert type(sequence) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence}!"
            return sequence, 1
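
    # Examples: parseSingleTon("A") -> ("A", 1); parseSingleTon({"A": 0.6}) -> ("A", 0.6);
    # parseSingleTon(("A", 0.2)) -> ("A", 0.2).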

    @staticmethod
    def isParallel(arg):
        # Parallel (and sequential) parsing is not strict: a spec may be
        # {"string-covered sequence or bare HN name": weight, ...}.
        if type(arg) in (dict, set) and len(arg) > 1:
            keys = arg if type(arg) is set else arg.keys()
            # every key must be a Hypernetwork name or a string-covered sequence
            return all(type(key) is str for key in keys)
        return False

    @staticmethod
    def parseParallel(sequence):  # accepts a spec, returns {"name or sequence": weight, ...}
        assert len(sequence) > 1, f"Length of sequence {sequence} was not enough for parallel!"
        if type(sequence) is set:  # sets only allow hashable members; anything else must be supplied as a quoted string
            assert all(type(key) in (str, tuple) for key in sequence), \
                f"All keys should be Hypernetwork Name/Sequence for Set but given :{sequence}"
            return {key: 1 / len(sequence) for key in sequence}
        elif type(sequence) is dict:
            assert all(type(key) in (str, tuple) for key in sequence.keys()), \
                f"All keys should be Hypernetwork Name/Sequence for Dict but given :{sequence}"
            assert all(type(value) in (int, float) for value in sequence.values()), \
                f"All values should be int/float for Dict but given :{sequence}"
            return sequence
        else:
            raise ValueError(f"Cannot parse parallel sequence {sequence}!")
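
    # Example: parseParallel({"A", "B"}) -> {"A": 0.5, "B": 0.5}; dict specs pass
    # through unchanged and are normalized later by ParallelForward.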

    @staticmethod
    def isSequential(arg):
        return type(arg) is list and len(arg) > 0

    @staticmethod
    def parseSequential(sequence):  # accepts a spec, only checks that it is a non-empty list, then returns it.
        if type(sequence) is list and len(sequence) > 0:
            return sequence
        raise ValueError(f"Cannot parse non-list sequence {sequence}!")

    def shorthash(self):
        return '0000000000'  # placeholder hash for webui bookkeeping


from .hypernetwork import Hypernetwork  # imported here, below Forward; importing at the top can be circular


def find_non_hash_key(target):
    # webui lists Hypernetworks as "name(hash)"; match either the bare name or the full key.
    closest = [x for x in shared.hypernetworks if x.rsplit('(', 1)[0] == target or x == target]
    if closest:
        return shared.hypernetworks[closest[0]]
    raise KeyError(f"{target} is not found in Hypernetworks!")


class SingularForward(Forward):
    def __init__(self, processor, strength):
        assert processor != 'defaultForward', "Cannot use name defaultForward!"
        super(SingularForward, self).__init__()
        self.name = processor
        self.processor = processor
        self.strength = strength
        # Parse. We expect singletons or a (name, strength) pair here.
        hn = Hypernetwork()
        try:
            hn.load(find_non_hash_key(self.processor))
        except Exception:
            global lazy_load
            lazy_load = True
            print("Encountered CUDA Memory Error, will unload HNs, speed might go down severely!")
            hn.load(find_non_hash_key(self.processor))
        available_opts[self.processor] = hn
        # assert self.processor in available_opts, f"Hypernetwork named {processor} is not ready!"
        assert 0 <= self.strength <= 1, "Strength must be between 0 and 1!"
        print(f"SingularForward <{self.name}, {self.strength}>")

    def __call__(self, context_k, context_v=None, layer=None):
        if self.processor in available_opts:
            # Hypernetwork layers are keyed by the context embedding width.
            context_layers = available_opts[self.processor].layers.get(context_k.shape[2], None)
            if context_v is None:
                context_v = context_k
            if context_layers is None:
                return context_k, context_v
            # if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'):
            #     layer.hyper_k, layer.hyper_v = context_layers[0], context_layers[1]
            return (devices.cond_cast_unet(context_layers[0](devices.cond_cast_float(context_k), multiplier=self.strength)),
                    devices.cond_cast_unet(context_layers[1](devices.cond_cast_float(context_v), multiplier=self.strength)))
        # This invokes the HN modules with the specified strength.
        # Note: the same HN object is shared if it is referenced multiple times, which
        # means you might not be able to train it via this structure.
        raise KeyError(f"Key {self.processor} is not found in cached Hypernetworks!")

    def __str__(self):
        return "SingularForward>" + str(self.processor)


class ParallelForward(Forward):
    def __init__(self, sequence, name=None):
        super(ParallelForward, self).__init__()
        self.name = "ParallelForwardHypernet" if name is None else name
        self.callers = {}
        self.weights = {}
        # Parse each key into its own Forward and normalize the weights so they sum to 1.
        total = sum(sequence.values())
        for keys in sequence:
            self.callers[keys] = Forward.parse(keys)
            self.weights[keys] = sequence[keys] / total
        print(str(self))

    def __call__(self, context, context_v=None, layer=None):
        ctx_k = torch.zeros_like(context, device=context.device)
        ctx_v = torch.zeros_like(context, device=context.device)
        for key in self.callers:
            k, v = self.callers[key](context, context_v, layer=layer)
            ctx_k += k * self.weights[key]
            ctx_v += v * self.weights[key]
        return ctx_k, ctx_v

    def __str__(self):
        return "ParallelForward>" + str({str(k): str(v) for (k, v) in self.callers.items()})


class SequentialForward(Forward):
    def __init__(self, sequence, name=None):
        super(SequentialForward, self).__init__()
        self.name = "SequentialForwardHypernet" if name is None else name
        self.callers = []
        for keys in sequence:
            self.callers.append(Forward.parse(keys))
        print(str(self))

    def __call__(self, context, context_v=None, layer=None):
        if context_v is None:
            context_v = context
        # Chain the sub-forwards: each one consumes the previous outputs.
        for keys in self.callers:
            context, context_v = keys(context, context_v, layer=layer)
        return context, context_v

    def __str__(self):
        return "SequentialForward>" + str([str(x) for x in self.callers])


class EmptyForward(Forward):
    def __init__(self):
        super().__init__()
        self.name = None

    def __call__(self, context, context_v=None, layer=None):
        if context_v is None:
            context_v = context
        return context, context_v

    def __str__(self):
        return "EmptyForward"


def load(filename):
    with open(filename, 'r') as file:
        return Forward.parse(file.read(), name=os.path.basename(filename))
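
# Minimal usage sketch (hypothetical path and network names): a spec file is a
# plain-text file whose entire content is one sequence spec, e.g.
#   {"styleA": 0.8, "styleB": 0.2}
# forward = load("path/to/my_mix.txt")  # prints the parsed structure
# k, v = forward(context)               # rewrites the cross-attention contexts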