File size: 17,291 Bytes
2b67076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
from typing import List, Tuple, Dict, Callable


def preparse_loras_multipliers(loras_multipliers):
    """Normalize a raw LoRA multipliers definition into a flat list of tokens.

    Args:
        loras_multipliers: either a list (string entries are stripped of
            surrounding spaces and line breaks, other entries are kept as is)
            or a string in which multipliers are separated by whitespace,
            line breaks or the '|' character.

    Returns:
        The list of multiplier tokens. For a string input, lines whose first
        non-blank character is '#' are ignored and no empty token is ever
        produced (an empty or blank definition gives an empty list).
    """
    if isinstance(loras_multipliers, list):
        return [multi.strip(" \r\n") if isinstance(multi, str) else multi for multi in loras_multipliers]

    tokens = []
    for line in loras_multipliers.replace("\r", "").split("\n"):
        # Strip before testing for '#', so that indented comment lines are skipped too
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        # split() without argument collapses runs of blanks: "1 | 2" must give two tokens,
        # not two tokens separated by empty strings
        tokens.extend(line.replace("|", " ").split())
    return tokens

def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ):
    """Build the per-step multiplier schedule of one LoRA.

    Args:
        slists_dict: dict holding the "phase1", "phase2", "phase3" and "shared"
            lists, each indexed by LoRA number.
        mult_no: index of the LoRA to expand.
        num_inference_steps: total number of denoising steps.
        model_switch_step: number of steps attributed to phase 1
            (None means num_inference_steps).
        model_switch_step2: step at which phase 3 starts
            (None means num_inference_steps).

    Returns:
        A float when the multiplier is constant, otherwise a list with one
        multiplier per step. In shared mode the phase 1 definition is stretched
        over all the steps; otherwise the three phase definitions are stretched
        over their own step ranges and concatenated.
    """
    def expand_one(slist, nb_steps):
        # Stretch `slist` over `nb_steps` steps, each value covering an equal share of steps
        if not isinstance(slist, list): slist = [slist]
        if nb_steps <= 0:
            return []
        nb_values = len(slist)
        # Integer arithmetic: accumulating a float increment drifts (0.1 added ten times
        # is 0.9999999999999999) and would select the previous value on exact boundaries
        return [slist[step * nb_values // nb_steps] for step in range(nb_steps)]

    # A missing switch step means that the corresponding phase switch never happens
    if model_switch_step is None: model_switch_step = num_inference_steps
    if model_switch_step2 is None: model_switch_step2 = num_inference_steps

    phase1 = slists_dict["phase1"][mult_no]
    phase2 = slists_dict["phase2"][mult_no]
    phase3 = slists_dict["phase3"][mult_no]
    shared = slists_dict["shared"][mult_no]
    if shared:
        if isinstance(phase1, float): return phase1
        return expand_one(phase1, num_inference_steps)

    # Same constant in every phase: no need for a per-step list
    if isinstance(phase1, float) and isinstance(phase2, float) and isinstance(phase3, float) and phase1 == phase2 and phase2 == phase3:
        return phase1
    return (
        expand_one(phase1, model_switch_step)
        + expand_one(phase2, model_switch_step2 - model_switch_step)
        + expand_one(phase3, num_inference_steps - model_switch_step2)
    )

def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, nb_phases = 2, model_switch_step = None, model_switch_step2 = None):
    """Parse a LoRA multipliers definition into per-phase multiplier lists.

    Each multiplier is either a single number, a ','-separated list of numbers
    (stretched over the steps), or a ';'-separated sequence of such definitions,
    one per denoising phase.

    Args:
        loras_multipliers: text definition, or list with one entry per LoRA.
        nb_loras: number of LoRAs; extra multipliers are ignored and missing
            ones default to 1.
        num_inference_steps: total number of denoising steps.
        merge_slist: optional dict produced by a previous call; its phase and
            shared lists are placed in front of the ones parsed here.
        nb_phases: number of phases expected when the ';' syntax is used.
        model_switch_step: step at which phase 2 starts (None means
            num_inference_steps).
        model_switch_step2: step at which phase 3 starts (None means
            num_inference_steps).

    Returns:
        A tuple (multipliers, slists_dict, error). `multipliers` holds, for each
        LoRA, its constant multiplier or the first value of its schedule.
        `slists_dict` holds the "phase1", "phase2", "phase3" and "shared" lists
        along with the two switch steps. `error` is "" on success; on failure
        the tuple is ("", "", message).
    """
    # Only a text definition can carry the '|' separator (a list has no such syntax)
    if isinstance(loras_multipliers, str) and loras_multipliers.count("|") > 1:
        return "", "", "There can be only one '|' character in Loras Multipliers Sequence"

    if model_switch_step is None:
        model_switch_step = num_inference_steps
    if model_switch_step2 is None:
        model_switch_step2 = num_inference_steps

    def is_float(element) -> bool:
        # True if `element` can be converted by float()
        if element is None:
            return False
        try:
            float(element)
        except (TypeError, ValueError):
            return False
        return True

    # Both switch steps are recorded in the same dict (building a second dict for
    # the second key would throw the first key away)
    slists_dict = {
        "model_switch_step": model_switch_step,
        "model_switch_step2": model_switch_step2,
    }
    slists_dict["phase1"] = phase1 = [1.] * nb_loras
    slists_dict["phase2"] = phase2 = [1.] * nb_loras
    slists_dict["phase3"] = phase3 = [1.] * nb_loras
    slists_dict["shared"] = shared = [False] * nb_loras

    if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0:
        list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)[:nb_loras]
        for i, mult in enumerate(list_mult_choices_list):
            current_phase = phase1
            if isinstance(mult, str):
                mult = mult.strip()
                phase_mult = mult.split(";")
                # Without ';' the same definition applies to every phase
                shared_phases = len(phase_mult) <=1
                if not shared_phases and len(phase_mult) != nb_phases :
                    return "", "", f"if the ';' syntax is used for one Lora multiplier, the multipliers for its {nb_phases} denoising phases should be specified for this multiplier"
                for phase_no, mult in enumerate(phase_mult):
                    if phase_no == 1:
                        current_phase = phase2
                    elif phase_no == 2:
                        current_phase = phase3
                    if "," in mult:
                        # Step-wise list of values for this phase
                        multlist = mult.split(",")
                        slist = []
                        for smult in multlist:
                            if not is_float(smult):
                                return "", "", f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid in Phase {phase_no+1}"
                            slist.append(float(smult))
                    else:
                        if not is_float(mult):
                            return "", "", f"Lora Multiplier no {i+1} ({mult}) is invalid"
                        slist = float(mult)
                    if shared_phases:
                        phase1[i] = phase2[i] = phase3[i] = slist
                        shared[i] = True
                    else:
                        current_phase[i] = slist
            else:
                # Numeric entry of a list definition: constant over all the phases
                phase1[i] = phase2[i] = phase3[i] = float(mult)
                shared[i] = True

    if merge_slist is not None:
        slists_dict["phase1"] = phase1 = merge_slist["phase1"] + phase1
        slists_dict["phase2"] = phase2 = merge_slist["phase2"] + phase2
        slists_dict["phase3"] = phase3 = merge_slist["phase3"] + phase3
        slists_dict["shared"] = shared = merge_slist["shared"] + shared

    loras_list_mult_choices_nums = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step, model_switch_step2 )  for i in range(len(phase1)) ]
    # Only the multiplier of the first step is reported for step-wise schedules
    loras_list_mult_choices_nums = [ slist[0] if isinstance(slist, list) else slist for slist in loras_list_mult_choices_nums ]

    return  loras_list_mult_choices_nums, slists_dict, ""

def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None):
    """Recompute the multiplier schedule of every LoRA and activate them.

    Args:
        trans: object forwarded as first argument to `offload.activate_loras`
            (presumably the transformer model the LoRAs are attached to; confirm with callers).
        slists_dict: dict holding the "phase1", "phase2", "phase3" and "shared" lists.
        num_inference_steps: total number of denoising steps.
        phase_switch_step: step at which phase 2 starts (None means num_inference_steps).
        phase_switch_step2: step at which phase 3 starts (None means num_inference_steps).
    """
    from mmgp import offload

    # Same convention as parse_loras_multipliers: a missing switch step means "no switch".
    # Passing None down would fail as soon as a per-phase schedule has to be expanded.
    if phase_switch_step is None: phase_switch_step = num_inference_steps
    if phase_switch_step2 is None: phase_switch_step2 = num_inference_steps

    nb_loras = len(slists_dict["phase1"])
    slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(nb_loras) ]
    # LoRAs are identified by their index, passed as a string
    nos = [str(i) for i in range(nb_loras)]
    offload.activate_loras(trans, nos, slists )



def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
    """Locate the steps at which the denoising phases change.

    A phase switch happens at the first step whose timestep is lower than or
    equal to the corresponding threshold. The first switch is only searched
    when there are at least 2 phases, the second one when there are at least 3.
    A switch that never happens is reported as the total number of steps.

    Returns:
        (model_switch_step, model_switch_step2, phases_description) where the
        description is a readable summary of the step range of each phase
        (empty when there is a single phase).
    """
    nb_steps = len(timesteps)
    first_switch = None
    second_switch = None
    for step_no, timestep in enumerate(timesteps):
        if first_switch is None and guide_phases >= 2 and timestep <= switch_threshold:
            first_switch = step_no
        if second_switch is None and guide_phases >= 3 and timestep <= switch2_threshold:
            second_switch = step_no
    if first_switch is None:
        first_switch = nb_steps
    if second_switch is None:
        second_switch = nb_steps

    if guide_phases <= 1:
        return first_switch, second_switch, ""

    # Step ranges are displayed 1-based and inclusive
    pieces = ["Denoising Steps: "]
    if first_switch == 0:
        pieces.append(" Phase 1 = None")
    else:
        pieces.append(f" Phase 1 = 1:{min(first_switch, nb_steps)}")
    if first_switch < nb_steps:
        if first_switch == second_switch:
            pieces.append(", Phase 2 = None")
        else:
            pieces.append(f", Phase 2 = {first_switch + 1}:{min(second_switch, nb_steps)}")
        if guide_phases > 2 and second_switch < nb_steps:
            pieces.append(f", Phase 3 = {second_switch + 1}:{nb_steps}")
    return first_switch, second_switch, "".join(pieces)



from typing import List, Tuple, Dict, Callable

# Characters allowed inside a multiplier token: digits plus the punctuation '.', ',', ';' and ':'.
# _spans() treats every maximal run of these characters found outside a '#' comment as one token.
_ALWD = set(":;,.0123456789")

# ---------------- core parsing helpers ----------------

def _find_bar(s: str) -> int:
    com = False
    for i, ch in enumerate(s):
        if ch in ('\n', '\r'):
            com = False
        elif ch == '#':
            com = True
        elif ch == '|' and not com:
            return i
    return -1

def _spans(text: str) -> List[Tuple[int, int]]:
    """Return the (start, end) positions of every multiplier token of `text`.

    A token is a maximal run of characters belonging to `_ALWD`. Anything
    between a '#' and the next line break is a comment and holds no token.
    `end` is exclusive.
    """
    found = []
    start = None  # start of the token being scanned, None when outside a token
    in_comment = False
    for pos, char in enumerate(text):
        line_break = char == '\n' or char == '\r'
        if line_break:
            in_comment = False
        elif char == '#':
            in_comment = True
        if not (line_break or in_comment) and char in _ALWD:
            if start is None:
                start = pos
        elif start is not None:
            # Any other character closes the current token
            found.append((start, pos))
            start = None
    if start is not None:
        found.append((start, len(text)))
    return found

def _choose_sep(text: str, spans: List[Tuple[int, int]]) -> str:
    if len(spans) >= 2:
        a, b = spans[-2][1], spans[-1][0]
        return '\n' if ('\n' in text[a:b] or '\r' in text[a:b]) else ' '
    return '\n' if ('\n' in text or '\r' in text) else ' '

def _ends_in_comment_line(text: str) -> bool:
    ln = text.rfind('\n')
    seg = text[ln + 1:] if ln != -1 else text
    return '#' in seg

def _append_tokens(text: str, k: int, sep: str) -> str:
    if k <= 0: return text
    t = text
    if _ends_in_comment_line(t) and (not t.endswith('\n')): t += '\n'
    parts = []
    if t and not t[-1].isspace(): parts.append(sep)
    parts.append('1')
    for _ in range(k - 1):
        parts.append(sep); parts.append('1')
    return t + ''.join(parts)

def _erase_span_and_one_sep(text: str, st: int, en: int) -> str:
    n = len(text)
    r = en
    while r < n and text[r] in (' ', '\t'): r += 1
    if r > en: return text[:st] + text[r:]
    l = st
    while l > 0 and text[l-1] in (' ', '\t'): l -= 1
    if l < st: return text[:l] + text[en:]
    return text[:st] + text[en:]

def _trim_last_tokens(text: str, spans: List[Tuple[int, int]], drop: int) -> str:
    if drop <= 0: return text
    new_text = text
    for st, en in reversed(spans[-drop:]):
        new_text = _erase_span_and_one_sep(new_text, st, en)
    while new_text and new_text[-1] in (' ', '\t'):
        new_text = new_text[:-1]
    return new_text

def _enforce_count(text: str, target: int) -> str:
    """Make `text` hold exactly `target` tokens.

    Extra tokens are removed from the end; missing ones are appended as '1'
    using the separator chosen by `_choose_sep`. The text is returned unchanged
    when the count already matches.
    """
    spans = _spans(text)
    missing = target - len(spans)
    if missing == 0:
        return text
    if missing < 0:
        return _trim_last_tokens(text, spans, -missing)
    return _append_tokens(text, missing, _choose_sep(text, spans))

def _strip_bars_outside_comments(s: str) -> str:
    com, out = False, []
    for ch in s:
        if ch in ('\n', '\r'): com = False; out.append(ch)
        elif ch == '#':        com = True;  out.append(ch)
        elif ch == '|' and not com: continue
        else: out.append(ch)
    return ''.join(out)

def _replace_tokens(text: str, repl: Dict[int, str]) -> str:
    if not repl: return text
    sp = _spans(text)
    for idx in sorted(repl.keys(), reverse=True):
        if 0 <= idx < len(sp):
            st, en = sp[idx]
            text = text[:st] + repl[idx] + text[en:]
    return text

def _drop_tokens_by_indices(text: str, idxs: List[int]) -> str:
    if not idxs: return text
    out = text
    for idx in sorted(set(idxs), reverse=True):
        sp = _spans(out)  # recompute spans after each deletion
        if 0 <= idx < len(sp):
            st, en = sp[idx]
            out = _erase_span_and_one_sep(out, st, en)
    return out

# ---------------- identity for dedupe ----------------

def _default_path_key(p: str) -> str:
    s = p.strip().replace('\\', '/')
    while '//' in s: s = s.replace('//', '/')
    if len(s) > 1 and s.endswith('/'): s = s[:-1]
    return s

# ---------------- new-set splitter (FIX) ----------------

def _select_new_side(
    loras_new: List[str],
    mult_new: str,
    mode: str,  # "merge before" | "merge after"
) -> Tuple[List[str], str]:
    """
    Keep only the part of the new settings that matters for `mode`.

    `mult_new` is cut at its first '|' located outside a comment. The loras are
    attributed to the left side, then to the right side, according to the number
    of tokens found on each side; loras left without a token are added at the
    end of the returned side. Without any '|', everything is returned.
    The returned multipliers text holds no '|' outside comments.
    """
    bar_pos = _find_bar(mult_new)
    if bar_pos == -1:
        return loras_new, _strip_bars_outside_comments(mult_new)

    left_txt = mult_new[:bar_pos]
    right_txt = mult_new[bar_pos + 1:]
    nb_left_tokens = len(_spans(left_txt))
    nb_right_tokens = len(_spans(right_txt))
    total = len(loras_new)

    # Loras are served to the left side first, the right side gets what remains
    n_before = min(nb_left_tokens, total)
    n_after = min(nb_right_tokens, total - n_before)
    # Loras beyond both token counts (empty when every lora has a token)
    leftovers = loras_new[n_before + n_after:]

    if mode == "merge before":
        chosen_loras = loras_new[:n_before] + leftovers
        chosen_text = left_txt
    else:
        chosen_loras = loras_new[n_before:n_before + n_after] + leftovers
        chosen_text = right_txt

    return chosen_loras, _strip_bars_outside_comments(chosen_text)

# ---------------- public API ----------------

def merge_loras_settings(
    loras_old: List[str],
    mult_old: str,
    loras_new: List[str],
    mult_new: str,
    mode: str = "merge before",
    path_key: Callable[[str], str] = _default_path_key,
) -> Tuple[List[str], str]:
    """
    Merge new lora settings into old ones, keeping the formatting and comments of the multipliers text.

    The multipliers text is made of a BEFORE side and an AFTER side separated by '|'.
    "merge after" keeps the old BEFORE side and replaces the AFTER side by the new one;
    "merge before" keeps the old AFTER side and replaces the BEFORE side.
    A new lora that already exists on the kept side (same `path_key`) is not added again:
    its new multiplier overwrites the kept one instead.

    Returns the merged list of loras and the merged multipliers text.
    """
    assert mode in ("merge before", "merge after")

    # Split the old multipliers at the bar and align both sides with loras_old
    old_bar = _find_bar(mult_old)
    orig_had_bar = old_bar != -1
    if orig_had_bar:
        before_old, after_old = mult_old[:old_bar], mult_old[old_bar + 1:]
    else:
        before_old, after_old = "", mult_old

    total_old = len(loras_old)
    nb_before_old = len(_spans(before_old))
    if nb_before_old > total_old:
        # More BEFORE multipliers than loras: every old lora is a BEFORE lora
        keep_b = total_old
        before_old_aligned = _enforce_count(before_old, keep_b)
        after_old_aligned = _enforce_count(after_old, 0)
    else:
        keep_b = nb_before_old
        before_old_aligned = before_old
        after_old_aligned = _enforce_count(after_old, total_old - keep_b)

    # Relevant side of the new settings, with one token per new lora
    loras_new_sel, mult_new_sel = _select_new_side(loras_new, mult_new, mode)
    mult_new_aligned = _enforce_count(mult_new_sel, len(loras_new_sel))
    new_tokens = [mult_new_aligned[a:b] for a, b in _spans(mult_new_aligned)]

    # Side of the old settings that survives the merge
    merging_after = mode == "merge after"
    if merging_after:
        preserved_loras, preserved_text = loras_old[:keep_b], before_old_aligned
    else:
        preserved_loras, preserved_text = loras_old[keep_b:], after_old_aligned
    nb_preserved_tokens = len(_spans(preserved_text))

    # First position of each preserved lora, by identity key
    first_pos: Dict[str, int] = {}
    for pos, lora in enumerate(preserved_loras):
        first_pos.setdefault(path_key(lora), pos)

    # New loras already preserved: update the preserved multiplier, drop the new entry
    updates: Dict[int, str] = {}
    duplicates: List[int] = []
    for i, lora in enumerate(loras_new_sel):
        j = first_pos.get(path_key(lora))
        if j is None or j >= nb_preserved_tokens:
            continue
        updates[j] = new_tokens[i] if i < len(new_tokens) else "1"
        duplicates.append(i)

    preserved_out = _replace_tokens(preserved_text, updates)
    replaced_out = _drop_tokens_by_indices(mult_new_aligned, duplicates)
    dropped = set(duplicates)
    loras_keep = [lora for i, lora in enumerate(loras_new_sel) if i not in dropped]

    if merging_after:
        before_text, after_text = preserved_out, replaced_out
        loras_out = preserved_loras + loras_keep
    else:
        before_text, after_text = replaced_out, preserved_out
        loras_out = loras_keep + preserved_loras

    # Compose, preserving explicit "before-only" bar when appropriate
    has_before = bool(_spans(before_text))
    has_after = bool(_spans(after_text))
    if has_before and has_after:
        return loras_out, f"{before_text}|{after_text}"
    if has_before:
        keep_bar = mode == 'merge before' or orig_had_bar
        return loras_out, before_text + ('|' if keep_bar else '')
    return loras_out, after_text

# ---------------- extractor ----------------

def extract_loras_side(
    loras: List[str],
    mult: str,
    which: str = "before",
) -> Tuple[List[str], str]:
    """
    Return the loras and the multipliers text of one side ("before" or "after") of the '|'.

    The number of BEFORE loras is the number of tokens found before the bar, capped
    by the number of loras; the remaining loras belong to the AFTER side. Without a
    bar, everything is on the AFTER side. The returned text holds one token per
    returned lora.
    """
    assert which in ("before", "after")

    bar_pos = _find_bar(mult)
    if bar_pos == -1:
        before_txt, after_txt = "", mult
    else:
        before_txt, after_txt = mult[:bar_pos], mult[bar_pos + 1:]

    total = len(loras)
    keep_b = min(len(_spans(before_txt)), total)
    keep_a = total - keep_b

    if which == "after":
        return loras[keep_b:keep_b + keep_a], _enforce_count(after_txt, keep_a)
    return loras[:keep_b], _enforce_count(before_txt, keep_b)