File size: 13,585 Bytes
ef9fd1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import ast
import os.path

import torch

from modules import devices, shared
from .hnutil import find_self
from .shared import version_flag

lazy_load = False  # when this is enabled, HNs will be loaded when required.
if not hasattr(devices, 'cond_cast_unet'):
    raise RuntimeError("Cannot find cond_cast_unet attribute, please update your webui version!")


class DynamicDict(dict):  # Brief dict that dynamically unloads Hypernetworks if required.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.current = None
        self.hash = None
        self.dict = {**kwargs}

    def prepare(self, key, value):
        if lazy_load and self.current is not None and (
                key != self.current):  # or filename is identical, but somehow hash is changed?
            self.current.to('cpu')
        self.current = value
        if self.current is not None:
            self.current.to(devices.device)

    def __getitem__(self, item):
        value = self.dict[item]
        self.prepare(item, value)
        return value

    def __setitem__(self, key, value):
        if key in self.dict:
            return
        self.dict[key] = value

    def __contains__(self, item):
        return item in self.dict


available_opts = DynamicDict()  # string -> HN itself.


# Behavior definition.
# [[], [], []] -> sequential processing
# [{"A" : 0.8, "B" : 0.1}] -> parallel processing. with weighted sum in this case, A = 8/9 effect, B = 1/9 effect.
# [("A", 0.2), ("B", 0.4)] -> tuple is used to specify strength.
# [{"A", "B", "C"}] -> parallel, but having same effects (set)
# ["A", "B", []] -> sequential processing
# [{"A":0.6}, "B", "C"] -> sequential, dict with single value will be considered as strength modification.
# [["A"], {"B"}, "C"] -> singletons are equal to items without covers, nested singleton will not be parsed, because its inefficient.
# {{'Aa' : 0.2, 'Ab' : 0.8} : 0.8, 'B' : 0.1} (X) -> {"{'Aa' : 0.2, 'Ab' : 0.8}" : 0.8, 'B' : 0.1} (O), When you want complex setups in parallel, you need to cover them with "". You can use backslash too.


# Testing parsing function.

def test_parsing(string=None):
    def test(arg):
        print(arg)
        try:
            obj = str(Forward.parse(arg))
            print(obj)
        except Exception as e:
            print(e)

    if string:
        test(string)
    else:
        for strings in ["[[], [], []]", "[{\"A\" : 0.8, \"B\" : 0.1}]", '[("A", 0.2), ("B", 0.4)]', '[{"A", "B", "C"}]',
                        '[{"A":0.6}, "B", "C"]', '[["A"], {"B"}, "C"]',
                        '{"{\'Aa\' : 0.2, \'Ab\' : 0.8}" : 0.8, \'B\' : 0.1}']:
            test(strings)


class Forward:
    def __init__(self, **kwargs):
        self.name = "defaultForward" if 'name' not in kwargs else kwargs['name']
        pass

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def set_multiplier(self, *args, **kwargs):
        pass

    def extra_name(self):
        if version_flag:
            return ""
        found = find_self(self)
        if found is not None:
            return f" <hypernet:{found}:1.0>"
        return f" <hypernet:{self.name}:1.0>"

    @staticmethod
    def parse(arg, name=None):
        arg = Forward.unpack(arg)
        arg = Forward.eval(arg)
        if Forward.isSingleTon(arg):
            return SingularForward(*Forward.parseSingleTon(arg))
        elif Forward.isParallel(arg):
            return ParallelForward(Forward.parseParallel(arg), name=name)
        elif Forward.isSequential(arg):
            return SequentialForward(Forward.parseSequential(arg), name=name)
        raise ValueError(f"Cannot parse {arg} into sequences!")

    @staticmethod
    def unpack(arg):  # stop using ({({{((a))}})}) please
        if len(arg) == 1 and type(arg) in (set, list, tuple):
            return Forward.unpack(list(arg)[0])
        if len(arg) == 1 and type(arg) is dict:
            key = list(arg.keys())[0]
            if arg[key] == 1:
                return Forward.unpack(key)
        return arg

    @staticmethod
    def eval(arg):  # from "{something}", parse as etc form.
        if arg is None:
            raise ValueError("None cannot be evaluated!")
        try:
            newarg = ast.literal_eval(arg)
            if type(arg) is str and arg.startswith(("{", "[", "(")) and newarg is not None:
                if not newarg:
                    raise RuntimeError(f"Cannot eval false object {arg}!")
                return newarg
        except ValueError:
            return arg
        return arg

    @staticmethod
    def isSingleTon(
            arg):  # Very strict. This applies strength to HN, which cannot happen in combined networks. Only weighting is allowed in complex process.
        if type(arg) is str and not arg.startswith(('[', '(', '{')):  # Strict. only accept str
            return True
        elif type(
                arg) is dict:  # Strict. only accept {str : int/float} - Strength modification can only happen for str.
            return len(arg) == 1 and all(type(value) in (int, float) for value in arg.values()) and all(
                type(k) is str for k in arg)
        elif type(arg) in (list, set):
            return len(arg) == 1 and all(type(x) is str for x in arg)
        elif type(arg) is tuple:
            return len(arg) == 2 and type(arg[0]) is str and type(arg[1]) in (int, float)
        return False

    @staticmethod
    def parseSingleTon(sequence):  # accepts sequence, returns str, float pair. This is Strict.
        if type(sequence) in (list, dict, set):
            assert len(sequence) == 1, f"SingularForward only accepts singletons, but given {sequence}!"
            key = list(sequence)[0]
            if type(sequence) is dict:
                assert type(key) is str, f"Strength modification only accepts single Hypernetwork, but given {key}!"
                return key, sequence[key]
            else:
                key = list(key)[0]
                return key, 1
        elif type(sequence) is tuple:
            assert len(sequence) == 2, f"Tuple with non-couple {sequence} encountered in SingularForward!"
            assert type(
                sequence[0]) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence[0]}!"
            assert type(sequence[1]) in (int, float), f"Strength tuple only accepts Numbers, but given {sequence[1]}!"
            return sequence[0], sequence[1]
        else:
            assert type(
                sequence) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence}!"
            return sequence, 1

    @staticmethod
    def isParallel(
            arg):  # Parallel, or Sequential processing is not strict, it can have {"String covered sequence or just HN String" : weight, ...
        if type(arg) in (dict, set) and len(arg) > 1:
            if type(arg) is set:
                return all(type(key) is str for key in
                           arg), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}"
            else:
                arg: dict
                return all(type(key) is str for key in
                           arg.keys()), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}"
        else:
            return False

    @staticmethod
    def parseParallel(sequence):  # accepts sequence, returns {"Name or sequence" : weight...}
        assert len(sequence) > 1, f"Length of sequence {sequence} was not enough for parallel!"
        if type(sequence) is set:  # only allows hashable types. otherwise it should be supplied as string cover
            assert all(type(key) in (str, tuple) for key in
                       sequence), f"All keys should be Hypernetwork Name/Sequence for Set but given :{sequence}"
            return {key: 1 / len(sequence) for key in sequence}
        elif type(sequence) is dict:
            assert all(type(key) in (str, tuple) for key in
                       sequence.keys()), f"All keys should be Hypernetwork Name/Sequence for Dict but given :{sequence}"
            assert all(type(value) in (int, float) for value in
                       sequence.values()), f"All values should be int/float for Dict but given :{sequence}"
            return sequence
        else:
            raise ValueError(f"Cannot parse parallel sequence {sequence}!")

    @staticmethod
    def isSequential(arg):
        if type(arg) is list and len(arg) > 0:
            return True
        return False

    @staticmethod
    def parseSequential(sequence):  # accepts sequence, only checks if its list, then returns sequence.
        if type(sequence) is list and len(sequence) > 0:
            return sequence
        else:
            raise ValueError(f"Cannot parse non-list sequence {sequence}!")

    def shorthash(self):
        return '0000000000'

from .hypernetwork import Hypernetwork


def find_non_hash_key(target):
    closest = [x for x in shared.hypernetworks if x.rsplit('(', 1)[0] == target or x == target]
    if closest:
        return shared.hypernetworks[closest[0]]
    raise KeyError(f"{target} is not found in Hypernetworks!")


class SingularForward(Forward):

    def __init__(self, processor, strength):
        assert processor != 'defaultForward', "Cannot use name defaultForward!"
        super(SingularForward, self).__init__()
        self.name = processor
        self.processor = processor
        self.strength = strength
        # parse. We expect parsing Singletons or (k,v) pair here, which is HN Name and Strength.
        hn = Hypernetwork()
        try:
            hn.load(find_non_hash_key(self.processor))
        except:
            global lazy_load
            lazy_load = True
            print("Encountered CUDA Memory Error, will unload HNs, speed might go down severely!")
            hn.load(find_non_hash_key(self.processor))
        available_opts[self.processor] = hn
        # assert self.processor in available_opts, f"Hypernetwork named {processor} is not ready!"
        assert 0 <= self.strength <= 1, "Strength must be between 0 and 1!"
        print(f"SingularForward <{self.name}, {self.strength}>")

    def __call__(self, context_k, context_v=None, layer=None):
        if self.processor in available_opts:
            context_layers = available_opts[self.processor].layers.get(context_k.shape[2], None)
            if context_v is None:
                context_v = context_k
            if context_layers is None:
                return context_k, context_v
            #if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'):
            #    layer.hyper_k = context_layers[0], layer.hyper_v = context_layers[1]
            return devices.cond_cast_unet(context_layers[0](devices.cond_cast_float(context_k), multiplier=self.strength)),\
                   devices.cond_cast_unet(context_layers[1](devices.cond_cast_float(context_v), multiplier=self.strength))
            # define forward_strength, which invokes HNModule with specified strength.
        # Note : we share same HN if it is called multiple time, which means you might not be able to train it via this structure.
        raise KeyError(f"Key {self.processor} is not found in cached Hypernetworks!")

    def __str__(self):
        return "SingularForward>" + str(self.processor)


class ParallelForward(Forward):

    def __init__(self, sequence, name=None):
        self.name = "ParallelForwardHypernet" if name is None else name
        self.callers = {}
        self.weights = {}
        super(ParallelForward, self).__init__()
        # parse
        for keys in sequence:
            self.callers[keys] = Forward.parse(keys)
            self.weights[keys] = sequence[keys] / sum(sequence.values())
        print(str(self))

    def __call__(self, context, context_v=None, layer=None):
        ctx_k, ctx_v = torch.zeros_like(context, device=context.device), torch.zeros_like(context,
                                                                                          device=context.device)
        for key in self.callers:
            k, v = self.callers[key](context, context_v, layer=layer)
            ctx_k += k * self.weights[key]
            ctx_v += v * self.weights[key]
        return ctx_k, ctx_v

    def __str__(self):
        return "ParallelForward>" + str({str(k): str(v) for (k, v) in self.callers.items()})


class SequentialForward(Forward):
    def __init__(self, sequence, name=None):
        self.name = "SequentialForwardHypernet" if name is None else name
        self.callers = []
        super(SequentialForward, self).__init__()
        for keys in sequence:
            self.callers.append(Forward.parse(keys))
        print(str(self))

    def __call__(self, context, context_v=None, layer=None):
        if context_v is None:
            context_v = context
        for keys in self.callers:
            context, context_v = keys(context, context_v, layer=layer)
        return context, context_v

    def __str__(self):
        return "SequentialForward>" + str([str(x) for x in self.callers])


class EmptyForward(Forward):
    def __init__(self):
        super().__init__()
        self.name = None

    def __call__(self, context, context_v=None, layer=None):
        if context_v is None:
            context_v = context
        return context, context_v

    def __str__(self):
        return "EmptyForward"


def load(filename):
    with open(filename, 'r') as file:
        return Forward.parse(file.read(), name=os.path.basename(filename))