dikdimon committed on
Commit
3d4fdb4
·
verified ·
1 Parent(s): 37b6731

Delete __init__.py

Browse files
Files changed (1) hide show
  1. __init__.py +0 -559
__init__.py DELETED
@@ -1,559 +0,0 @@
1
- """
2
- lib_mega_freeu/unet.py β€” Math engine + A1111 th.cat patch
3
-
4
- BUGS FIXED vs sdwebui-freeU-extension/scripts/freeunet_hijack.py:
5
- BUG 1 dtype: mask = torch.ones(..., dtype=torch.bool)
6
- bool*float = NOOP, scale always 1.0
7
- Fix: torch.full(..., float(scale_high))
8
- BUG 2 quadrant: mask[..., crow-t:crow, ccol-t:ccol] (top-left only)
9
- Fix: mask[..., crow-t:crow+t, ccol-t:ccol+t] (symmetric center)
10
-
11
- Sources:
12
- sd-webui-freeu/lib_free_u/unet.py patch(), free_u_cat_hijack(),
13
- get_backbone_scale(), ratio_to_region(), filter_skip()[box],
14
- get_schedule_ratio(), is_gpu_complex_supported(), lerp()
15
- WAS FreeU_Advanced/nodes.py 9 blending modes, Fourier_filter() multiscale
16
- ComfyUI_FreeU_V2_advanced/utils.py Fourier_filter_gauss(), get_band_energy_stats()
17
- ComfyUI_FreeU_V2_advanced/FreeU_S1S2.py Adaptive Cap loop MAX_CAP_ITER=3
18
- ComfyUI_FreeU_V2_advanced/FreeU_B1B2.py channel_threshold, model_channels*4/2/1
19
- FreeU_V2_timestepadd.py step-fraction timestep gating concept
20
- nrs_kohaku_enhanced_v3_5.py _freeu_b_scale_h, _freeu_fourier_filter_gaussian,
21
- hf_boost param, on_cpu_devices dict
22
- """
23
- import dataclasses
24
- import functools
25
- import logging
26
- import math
27
- import pathlib
28
- import sys
29
- from typing import Dict, List, Optional, Tuple, Union
30
-
31
- import torch
32
- from lib_mega_freeu import global_state
33
-
34
- # ── GPU complex support (sd-webui-freeu exact) ────────────────────────────────
35
# Cached probe result: does this process's GPU backend support complex FFT?
# None until the first non-CPU tensor is probed.
_gpu_complex_support: Optional[bool] = None

def is_gpu_complex_supported(x: torch.Tensor) -> bool:
    """Return True when torch.fft complex ops can run on x's device.

    CPU always supports complex math.  For accelerator devices the answer
    is determined once and cached module-wide: MPS and DirectML lack
    complex support, and a trial fftn catches any other backend that
    raises at runtime.
    """
    global _gpu_complex_support
    if x.is_cpu:
        return True
    if _gpu_complex_support is not None:
        return _gpu_complex_support
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    try:
        import torch_directml
    except ImportError:
        has_dml = False
    else:
        has_dml = torch_directml.is_available()
    _gpu_complex_support = not (has_mps or has_dml)
    if _gpu_complex_support:
        # Backend claims support: confirm with a throwaway FFT.
        try:
            torch.fft.fftn(x.float(), dim=(-2, -1))
        except RuntimeError:
            _gpu_complex_support = False
    return _gpu_complex_support
55
-
56
# Per-device demotion flags: once an FFT fails at runtime on a device, that
# device maps to True here and all later FFT work for it runs on the CPU.
_on_cpu_devices: Dict = {}
57
-
58
- # ── Blending modes (WAS nodes.py exact) ───────────────────────────────────────
59
- def _normalize(t):
60
- mn, mx = t.min(), t.max()
61
- return (t - mn) / (mx - mn + 1e-8)
62
-
63
- def _hslerp(a, b, t):
64
- nc = a.size(1)
65
- iv = torch.zeros(1, nc, 1, 1, device=a.device, dtype=a.dtype)
66
- iv[0, 0, 0, 0] = 1.0
67
- result = (1 - t) * a + t * b
68
- if t < 0.5:
69
- result += (torch.norm(b - a, dim=1, keepdim=True) / 6) * iv
70
- else:
71
- result -= (torch.norm(b - a, dim=1, keepdim=True) / 6) * iv
72
- return result
73
-
74
- def _stable_slerp(a, b, t, eps=1e-6):
75
- an = a / torch.linalg.norm(a, dim=1, keepdim=True).clamp_min(eps)
76
- bn = b / torch.linalg.norm(b, dim=1, keepdim=True).clamp_min(eps)
77
- dot = (an * bn).sum(dim=1, keepdim=True).clamp(-1.0 + eps, 1.0 - eps)
78
- theta = torch.acos(dot)
79
- sin_t = torch.sin(theta).clamp_min(eps)
80
- s0 = torch.sin((1.0 - t) * theta) / sin_t
81
- s1 = torch.sin(t * theta) / sin_t
82
- slerp_out = s0 * a + s1 * b
83
- lerp_out = (1.0 - t) * a + t * b
84
- use_lerp = (theta < 1e-3).squeeze(1)
85
- return torch.where(use_lerp.unsqueeze(1), lerp_out, slerp_out)
86
-
87
# Registry of WAS FreeU_Advanced blending modes.  Each entry maps a mode
# name to f(a, b, t) -> tensor, where t is the blend amount.
BLENDING_MODES = {
    # NOTE(review): despite the name, "bislerp" here is lerp followed by
    # min/max normalization (kept as-is to match WAS behavior).
    "bislerp": lambda a, b, t: _normalize((1 - t) * a + t * b),
    "colorize": lambda a, b, t: a + (b - a) * t,
    "cosine interp": lambda a, b, t: (
        a + b - (a - b) * torch.cos(
            torch.tensor(math.pi, device=a.device, dtype=a.dtype) * t)) / 2,
    # Smoothstep-weighted interpolation.
    "cuberp": lambda a, b, t: a + (b - a) * (3 * t**2 - 2 * t**3),
    "hslerp": _hslerp,
    "stable_slerp": _stable_slerp,
    # Additive injection: b is added on top of a, not interpolated.
    "inject": lambda a, b, t: a + b * t,
    "lerp": lambda a, b, t: (1 - t) * a + t * b,
    "linear dodge": lambda a, b, t: _normalize(a + b * t),
}
100
-
101
def lerp(a, b, r):
    """Linear interpolation: returns a at r=0 and b at r=1."""
    return r * b + (1 - r) * a
103
-
104
- # ── Backbone scaling ──────────────────────────────────────────────────────────
105
def get_backbone_scale(h: torch.Tensor, backbone_factor: float, version: str):
    """Return the backbone multiplier for hidden states h.

    version "1": the constant FreeU factor.
    Otherwise (V2): per-position adaptive scale 1 + (factor-1)*m where m is
    the min/max-normalized channel mean of h (FreeU_B1B2 / kohaku pattern),
    so flat regions receive less amplification.  Shape (B, 1, H, W).
    """
    if version == "1":
        return backbone_factor
    chan_mean = h.mean(1, keepdim=True)
    batch = chan_mean.shape[0]
    flat = chan_mean.view(batch, -1)
    hi = flat.max(dim=-1, keepdim=True)[0]
    lo = flat.min(dim=-1, keepdim=True)[0]
    span = (hi - lo).clamp_min(1e-6)  # guard constant feature maps
    norm_mean = (chan_mean - lo.unsqueeze(2).unsqueeze(3)) / span.unsqueeze(2).unsqueeze(3)
    return 1 + (backbone_factor - 1) * norm_mean
117
-
118
def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
    """Map a (width, offset) ratio pair onto channel indices of an n-wide axis.

    Returns (start, end, inverted).  inverted=True means the selection
    wraps around the end, i.e. channels OUTSIDE [start, end) are chosen.
    Mirrors sd-webui-freeu's ratio_to_region.
    """
    if width < 0:
        # A negative width selects to the left of the offset.
        offset += width
        width = -width
    width = min(width, 1.0)
    if offset < 0:
        offset = 1 + offset - int(offset)
    offset = math.fmod(offset, 1.0)
    if width + offset <= 1:
        return round(offset * n), round((width + offset) * n), False
    # Region spills past the end: return the complement and flag inversion.
    return round((width + offset - 1) * n), round(offset * n), True
130
-
131
- # ── Box FFT (BUGS FIXED symmetric center + float dtype) ─────────────────────
132
def filter_skip_box(x: torch.Tensor, cutoff: float,
                    scale: float, scale_high: float = 1.0) -> torch.Tensor:
    """Box-mask Fourier filter for the skip connection.

    A centered box of half-extent floor(cutoff * H/2) x floor(cutoff * W/2)
    in the shifted spectrum is multiplied by `scale`; all remaining
    frequencies by `scale_high`.  Two fixes vs sdwebui-freeU-extension:
    a float mask (the bool mask there made the scale a no-op) and a
    symmetric center box (was top-left quadrant only).  Runs the FFT on
    CPU when the device lacks complex support.
    """
    if scale == 1.0 and scale_high == 1.0:
        return x  # identity filter, skip the FFT round-trip
    fft_dev = x.device if is_gpu_complex_supported(x) else torch.device("cpu")
    freq = torch.fft.fftn(x.to(fft_dev, dtype=torch.float32), dim=(-2, -1))
    freq = torch.fft.fftshift(freq, dim=(-2, -1))
    B, C, H, W = freq.shape
    cy, cx = H // 2, W // 2
    half_h = max(1, math.floor(cy * cutoff)) if cutoff > 0 else 1
    half_w = max(1, math.floor(cx * cutoff)) if cutoff > 0 else 1
    mask = torch.full((B, C, H, W), float(scale_high), device=fft_dev)
    mask[..., cy - half_h:cy + half_h, cx - half_w:cx + half_w] = scale
    freq = freq * mask
    freq = torch.fft.ifftshift(freq, dim=(-2, -1))
    return torch.fft.ifftn(freq, dim=(-2, -1)).real.to(device=x.device, dtype=x.dtype)
155
-
156
- # ── Box + WAS multiscale overlay (WAS nodes.py Fourier_filter exact) ─────────
157
def filter_skip_box_multiscale(x: torch.Tensor, cutoff: float, scale: float,
                               scales_preset: Optional[list],
                               strength: float = 1.0,
                               scale_high: float = 1.0) -> torch.Tensor:
    """
    WAS FreeU_Advanced/nodes.py Fourier_filter(x, threshold, scale, scales, strength).
    threshold = cutoff: float ratio [0-1] or int pixels (WAS uses int default=1).
    scales: None, list of (radius_px, val) single-scale, or list of lists multi-scale.
    """
    if scale == 1.0 and scale_high == 1.0 and scales_preset is None:
        return x  # identity filter: nothing to apply
    # Fall back to CPU FFT on devices without complex support.
    fft_dev = x.device if is_gpu_complex_supported(x) else torch.device("cpu")
    x_freq = torch.fft.fftn(x.to(fft_dev, dtype=torch.float32), dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
    B, C, H, W = x_freq.shape
    crow, ccol = H // 2, W // 2
    # cutoff in (0, 1] as a float is a ratio of the half-spectrum;
    # anything else is interpreted as an absolute pixel radius.
    if isinstance(cutoff, float) and 0 < cutoff <= 1.0:
        tr = max(1, math.floor(crow * cutoff)); tc = max(1, math.floor(ccol * cutoff))
    else:
        t = max(1, int(cutoff)) if cutoff > 0 else 1; tr = tc = t
    # Base mask: low-frequency box scaled by `scale`, 1 elsewhere.
    mask = torch.ones((B, C, H, W), device=fft_dev)
    mask[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = scale
    if scale_high != 1.0:
        # High-band layer: scale_high everywhere except the low-band box.
        hfm = torch.full((B, C, H, W), float(scale_high), device=fft_dev)
        hfm[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = 1.0
        mask = mask * hfm
    if scales_preset:
        if isinstance(scales_preset[0], tuple):
            # WAS single-scale mode
            for scale_threshold, scale_value in scales_preset:
                sv = scale_value * strength
                sm = torch.ones((B, C, H, W), device=fft_dev)
                st = max(1, int(scale_threshold))
                sm[..., crow - st:crow + st, ccol - st:ccol + st] = sv
                # Each preset box is blended toward the running mask by `strength`.
                mask = mask + (sm - mask) * strength
        else:
            # WAS multi-scale mode
            for scale_params in scales_preset:
                if isinstance(scale_params, list):
                    for scale_threshold, scale_value in scale_params:
                        sv = scale_value * strength
                        sm = torch.ones((B, C, H, W), device=fft_dev)
                        st = max(1, int(scale_threshold))
                        sm[..., crow - st:crow + st, ccol - st:ccol + st] = sv
                        mask = mask + (sm - mask) * strength
    x_freq = x_freq * mask
    x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
    return torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(device=x.device, dtype=x.dtype)
205
-
206
- # ── Gaussian FFT (ComfyUI utils.py exact) ────────────────────────────────────
207
def fourier_filter_gauss(x: torch.Tensor, radius_ratio: float,
                         scale: float, hf_boost: float = 1.0) -> torch.Tensor:
    """Gaussian-weighted Fourier filter (ComfyUI Fourier_filter_gauss).

    Builds a centered Gaussian g with sigma^2 = R^2/2 where
    R = max(1, int(min(H, W) * radius_ratio)) pixels, then applies the
    frequency mask scale*g + hf_boost*(1-g): `scale` acts on low
    frequencies, `hf_boost` on high ones, with a smooth rolloff between.
    Matches kohaku's _freeu_fourier_filter_gaussian.
    """
    freq = torch.fft.fftshift(torch.fft.fftn(x.float(), dim=(-2, -1)), dim=(-2, -1))
    _, _, H, W = freq.shape
    R = max(1, int(min(H, W) * radius_ratio))
    sigma_sq = max(1e-6, (R * R) / 2.0)
    rows = torch.arange(H, device=x.device, dtype=torch.float32) - H // 2
    cols = torch.arange(W, device=x.device, dtype=torch.float32) - W // 2
    yy, xx = torch.meshgrid(rows, cols, indexing="ij")
    gauss = torch.exp(-(yy ** 2 + xx ** 2) / sigma_sq)
    weight = (scale * gauss + hf_boost * (1.0 - gauss)).view(1, 1, H, W)
    freq = torch.fft.ifftshift(freq * weight, dim=(-2, -1))
    return torch.fft.ifftn(freq, dim=(-2, -1)).real.to(x.dtype)
231
-
232
- # ── Band energy stats (ComfyUI utils.py exact) ────────────────────────────────
233
def get_band_energy_stats(x: torch.Tensor, R: int) -> Tuple[float, float, float]:
    """Return (LF mean energy, HF mean energy, LF coverage %) of x's spectrum.

    The low band is the centered disc of radius R pixels in the shifted
    2-D spectrum; everything outside the disc is the high band.  Mirrors
    ComfyUI_FreeU_V2_advanced/utils.py get_band_energy_stats.
    """
    spec = torch.fft.fftshift(torch.fft.fftn(x.float(), dim=(-2, -1)), dim=(-2, -1))
    _, _, H, W = spec.shape
    rows = torch.arange(H, device=x.device, dtype=torch.float32) - H // 2
    cols = torch.arange(W, device=x.device, dtype=torch.float32) - W // 2
    yy, xx = torch.meshgrid(rows, cols, indexing="ij")
    low = (yy ** 2 + xx ** 2) <= (R * R)
    energy = spec.real ** 2 + spec.imag ** 2
    # A 2-D boolean mask indexes the trailing two dims directly.
    lf = energy[:, :, low].mean().item() if low.any() else 0.0
    hf = energy[:, :, ~low].mean().item() if (~low).any() else 0.0
    coverage = low.sum().item() / (H * W) * 100.0
    return lf, hf, coverage
249
-
250
- # ── Adaptive Cap Gaussian (FreeU_S1S2.py MAX_CAP_ITER=3 exact) ───────────────
251
def filter_skip_gaussian_adaptive(hsp: torch.Tensor,
                                  si: "global_state.StageInfo",
                                  verbose: bool = False) -> torch.Tensor:
    """
    ComfyUI_FreeU_V2_advanced/FreeU_S1S2.py exact algorithm:
    1. Compute LF/HF ratio before.
    2. Apply Gaussian filter.
    3. If enable_adaptive_cap and drop > cap_threshold: loop up to MAX_CAP_ITER=3.
       adaptive mode: eff_factor = cap_factor * (cap_threshold / drop)
       fixed mode:    eff_factor = cap_factor
       capped_s = 1 - eff_factor*(1-s_scale)   [interpolate FROM ORIGINAL]
       capped_s = max(capped_s, current_s*(1+1e-4))
       Re-apply from original_hsp with capped_s.
    hf_boost combined = max(si.hf_boost, si.skip_high_end_factor) [kohaku pattern]
    """
    s_scale = si.skip_factor
    radius_r = si.fft_radius_ratio
    # kohaku pattern: a single high-frequency boost, the stronger of the two knobs.
    hf_boost = max(si.hf_boost, si.skip_high_end_factor)
    orig_dev = hsp.device
    H, W = hsp.shape[-2:]
    R_eff = max(1, int(min(H, W) * radius_r))

    # CRITICAL ORDER: init cpu-fallback flag and helpers BEFORE any FFT call.
    # A device already demoted in _on_cpu_devices stays on CPU for good.
    use_cpu = _on_cpu_devices.get(orig_dev, not is_gpu_complex_supported(hsp))
    if use_cpu:
        _on_cpu_devices[orig_dev] = True

    def _tod(t):  # to FFT-safe device
        return t.cpu() if use_cpu else t

    def _fromd(t):  # back to original device
        return t.to(orig_dev) if use_cpu else t

    def _energy(t):
        # Band stats on the FFT-safe device so they never trip a complex error.
        return get_band_energy_stats(_tod(t), R_eff)

    def _filt(inp, sc):
        # Gaussian filter with one-shot CPU retry on first failure.
        nonlocal use_cpu
        try:
            out = fourier_filter_gauss(_tod(inp), radius_r, sc, hf_boost)
            return _fromd(out)
        except Exception:
            if not use_cpu:
                logging.warning(f"[MegaFreeU] {orig_dev} -> CPU fallback for FFT")
                _on_cpu_devices[orig_dev] = True
                use_cpu = True
                return fourier_filter_gauss(inp.cpu(), radius_r, sc, hf_boost).to(orig_dev)
            # Already on CPU and still failing: return the input untouched.
            return inp

    # Pre-filter energy (now safe on all devices)
    lf_b, hf_b, cover = _energy(hsp)
    ratio_b = lf_b / hf_b if hf_b > 1e-6 else float("inf")
    if verbose:
        logging.info(f"[MegaFreeU] Gauss {H}x{W} R={R_eff}px cov={cover:.1f}% "
                     f"LF={lf_b:.3e} HF={hf_b:.3e} ratio_b={ratio_b:.4f}")

    hsp_filt = _filt(hsp, s_scale)
    if not si.enable_adaptive_cap:
        return hsp_filt

    MAX_CAP_ITER = 3
    original_hsp = hsp
    current_s = s_scale
    lf_a, hf_a, _ = _energy(hsp_filt)
    ratio_a = lf_a / hf_a if hf_a > 1e-6 else float("inf")
    # drop = relative loss of the LF/HF ratio caused by the filter.
    drop = 1.0 - (ratio_a / ratio_b) if ratio_b > 1e-6 else 0.0
    orig_drop = drop
    iters = 0
    hsp_cur = hsp_filt

    while (si.enable_adaptive_cap
           and drop > si.cap_threshold
           and current_s < 0.999
           and iters < MAX_CAP_ITER):

        if iters == 0:
            logging.warning(f"[MegaFreeU] Over-attenuation: drop={drop*100:.1f}% > "
                            f"{si.cap_threshold*100:.1f}% s={s_scale:.4f}")

        eff_f = si.cap_factor
        if si.adaptive_cap_mode == "adaptive":
            # Scale the cap by how far past the threshold we are.
            eff_f = si.cap_factor * (si.cap_threshold / max(drop, 1e-8))

        capped_s = 1.0 - eff_f * (1.0 - s_scale)  # interpolate from ORIGINAL s
        capped_s = max(capped_s, current_s * (1.0 + 1e-4))  # only ever relax
        if abs(capped_s - current_s) < 1e-4:
            if verbose: logging.info(" Cap converged.")
            break

        if verbose:
            logging.info(f" Cap iter {iters+1}: s {current_s:.4f}->{capped_s:.4f} eff={eff_f:.4f}")

        try:
            hsp_new = _filt(original_hsp, capped_s)
        except Exception as e:
            logging.error(f"[MegaFreeU] cap re-apply error: {e}")
            hsp_cur = original_hsp  # restore to original on error (ComfyUI FreeU_S1S2.py pattern)
            break

        hsp_cur = hsp_new
        lf_a, hf_a, _ = _energy(hsp_cur)
        ratio_a = lf_a / hf_a if hf_a > 1e-6 else float("inf")
        drop = 1.0 - (ratio_a / ratio_b) if ratio_b > 1e-6 else 0.0
        current_s = capped_s
        iters += 1

    if iters > 0 or verbose:
        logging.info(f"[MegaFreeU] Cap done: {orig_drop*100:.1f}%->{drop*100:.1f}% "
                     f"({iters} iters s_final={current_s:.4f})")
    return hsp_cur
361
-
362
- # ── Schedule (sd-webui-freeu exact) ──────────────────────────────────────────
363
def get_schedule_ratio() -> float:
    """Return the FreeU effect strength in [0, 1] for the current sampling step.

    Mixes two schedules by st.transition_smoothness:
    - flat:   1.0 inside the [start, stop) step window, 0.0 outside;
    - smooth: linear ramp 0->1 approaching `start`, then 1->0 from
              `start` down to `stop`.
    start/stop are derived from start_ratio/stop_ratio via _to_step
    (floats are fractions of the run, ints are absolute step indices).
    """
    from modules import shared  # deferred: only importable inside the WebUI
    st = global_state.instance
    steps = shared.state.sampling_steps or 20  # sane default when unset
    cur = global_state.current_sampling_step
    start = _to_step(st.start_ratio, steps)
    stop = _to_step(st.stop_ratio, steps)
    if start == stop:
        smooth = 0.0  # degenerate window: smooth schedule contributes nothing
    elif cur < start:
        smooth = min(1.0, max(0.0, cur / (start + 1e-8)))  # ramp up toward start
    else:
        # Ramp down: 1 at cur==start, 0 at cur==stop (start-stop is negative).
        smooth = min(1.0, max(0.0, 1 + (cur - start) / (start - stop + 1e-8)))
    flat = 1.0 if start <= cur < stop else 0.0
    return lerp(flat, smooth, st.transition_smoothness)
378
-
379
def get_stage_bsratio(b_start: float, b_end: float) -> float:
    """Independent B/S timestep range gate (FreeU_V2_timestepadd concept -> step fraction).

    Returns 1.0 when the current step's progress fraction lies inside
    [b_start, b_end] (inclusive), else 0.0.
    """
    from modules import shared  # deferred: WebUI-only dependency
    steps = max(shared.state.sampling_steps or 20, 1)
    cur = global_state.current_sampling_step
    # Progress in [0, 1]; a single-step run counts as progress 0.
    pct = cur / (steps - 1) if steps > 1 else 0.0
    return 1.0 if b_start <= pct <= b_end else 0.0
386
-
387
- def _to_step(v, steps):
388
- return int(v * steps) if isinstance(v, float) else int(v)
389
-
390
- # ── Stage auto-detection (FreeU_B1B2.py + kohaku exact) ──────────────────────
391
# (hi, mid, lo) channel counts of the three UNet output stages; refreshed
# from the loaded model by detect_model_channels(), SD1/2 defaults otherwise.
_stage_channels: Tuple[int, int, int] = (1280, 640, 320)

def detect_model_channels():
    """Refresh _stage_channels from the loaded checkpoint's model_channels."""
    global _stage_channels
    try:
        from modules import shared
        base = int(shared.sd_model.model.diffusion_model.model_channels)
    except Exception:
        # No model loaded (or non-WebUI context): keep SD1/2 defaults.
        _stage_channels = (1280, 640, 320)
    else:
        _stage_channels = (base * 4, base * 2, base * 1)

def get_stage_index(dims: int, channel_threshold: int = 96) -> Optional[int]:
    """Match a channel count to a stage by proximity (FreeU_B1B2 pattern).

    Returns the first stage whose channel count is within
    channel_threshold of dims, or None when nothing is close enough.
    """
    hits = (i for i, target in enumerate(_stage_channels)
            if abs(dims - target) <= channel_threshold)
    return next(hits, None)
408
-
409
- # ── Override scales parser (WAS nodes.py format exact) ───────────────────────
410
def parse_override_scales(text: str) -> Optional[List]:
    """Parse 'radius_px, scale' lines into [(int, float), ...] or None.

    Blank lines and lines starting with '#', '!' or '//' are comments;
    malformed lines are skipped silently (WAS nodes.py override format).
    Returns None when no valid pair was found.
    """
    if not text or not text.strip():
        return None
    pairs = []
    for raw in text.strip().splitlines():
        entry = raw.strip()
        if not entry or entry.startswith(("#", "!", "//")):
            continue
        fields = entry.split(",")
        if len(fields) != 2:
            continue
        try:
            pairs.append((int(fields[0].strip()), float(fields[1].strip())))
        except ValueError:
            continue
    return pairs or None
425
-
426
class _VerboseRef:
    # Shared mutable verbosity flag: the UI layer writes .value, and the
    # th.cat hijack reads it on every call, so a plain bool would not do.
    value: bool = False
# Module-level singleton handed out to other modules.
verbose_ref = _VerboseRef()
429
-
430
- # ── Core th.cat hijack (sd-webui-freeu exact + extended) ─────────────────────
431
def free_u_cat_hijack(hs, *args, original_function, **kwargs):
    """
    Intercepts torch.cat([h, h_skip], dim=1) in UNet output_blocks.
    Signature: kwargs=={"dim":1} and len(hs)==2 (sd-webui-freeu exact check).

    Why th.cat over alternatives:
    - sdwebui-freeU-extension CondFunc(UNetModel.forward): rewrites full forward,
      incompatible with other extensions, plus 2 bugs in fourier mask.
    - kohaku register_forward_hook: output already concatenated,
      can't cleanly separate h from h_skip for independent filtering.
    - th.cat hijack: intercepts exactly [h, h_skip] before concatenation. CORRECT.
    """
    st = global_state.instance
    if not st.enable:
        return original_function(hs, *args, **kwargs)

    sched = get_schedule_ratio()
    if sched == 0:
        # Outside the scheduled window: pass through untouched.
        return original_function(hs, *args, **kwargs)

    # Only intercept the exact cat([h, h_skip], dim=1) call shape;
    # any other torch.cat in the model goes through unchanged.
    try:
        h, h_skip = hs
        if list(kwargs.keys()) != ["dim"] or kwargs.get("dim", -1) != 1:
            return original_function(hs, *args, **kwargs)
    except (ValueError, TypeError):
        return original_function(hs, *args, **kwargs)

    dims = int(h.shape[1])
    stage_idx = get_stage_index(dims, st.channel_threshold)
    if stage_idx is None:
        # Channel count matches no configured stage: leave the cat alone.
        return original_function(hs, *args, **kwargs)

    si = st.stage_infos[stage_idx]
    version = st.version
    verbose = verbose_ref.value

    # ── BACKBONE ─────────────────────────────────────────────────────────────
    b_gate = get_stage_bsratio(si.b_start_ratio, si.b_end_ratio)
    eff_b = sched * b_gate  # schedule strength gated by the B step window

    if eff_b > 0.0 and abs(si.backbone_factor - 1.0) > 1e-6:
        try:
            rbegin, rend, rinv = ratio_to_region(si.backbone_width, si.backbone_offset, dims)
            ch_idx = torch.arange(dims, device=h.device)
            mask = (rbegin <= ch_idx) & (ch_idx <= rend)
            if rinv: mask = ~mask
            mask = mask.reshape(1, -1, 1, 1).to(h.dtype)

            eff_factor = float(lerp(1.0, si.backbone_factor, eff_b))
            scale = get_backbone_scale(h, eff_factor, version)
            # h_scaled_full: full h with mask region scaled, rest unchanged
            # This matches original: h *= mask*scale + (1-mask)
            h_scaled_full = h * (mask * scale + (1.0 - mask))

            bmode = si.backbone_blend_mode
            if bmode in BLENDING_MODES and abs(si.backbone_blend - 1.0) > 1e-6:
                # Blend on FULL tensors so modes like slerp/hslerp see proper norms.
                # Then restore unmasked channels to original h.
                h_blended = BLENDING_MODES[bmode](h, h_scaled_full, si.backbone_blend)
                h = h * (1.0 - mask) + h_blended * mask
            else:
                h = h_scaled_full
        except Exception as e:
            # Best-effort: a backbone failure must never abort sampling.
            logging.warning(f"[MegaFreeU] B-scaling stage {stage_idx}: {e}")

    # ── SKIP / FOURIER ────────────────────────────────────────────────────────
    s_gate = get_stage_bsratio(si.s_start_ratio, si.s_end_ratio)
    eff_s = sched * s_gate  # schedule strength gated by the S step window

    if eff_s > 0.0 and (abs(si.skip_factor - 1.0) > 1e-6
                        or abs(si.hf_boost - 1.0) > 1e-6
                        or abs(si.skip_high_end_factor - 1.0) > 1e-6):
        try:
            # Interpolate all skip knobs toward identity by the gated strength.
            s_scale = float(lerp(1.0, si.skip_factor, eff_s))
            s_high = float(lerp(1.0, si.skip_high_end_factor, eff_s))

            if si.fft_type == "gaussian":
                hf_b = float(lerp(1.0, si.hf_boost, eff_s))
                si_eff = dataclasses.replace(si, skip_factor=s_scale, skip_high_end_factor=s_high, hf_boost=hf_b)
                h_skip = filter_skip_gaussian_adaptive(h_skip, si_eff, verbose)
            else:
                # Box filter path: user override scales take priority over presets.
                override = parse_override_scales(st.override_scales)
                ms_preset = override or global_state.MSCALES.get(st.multiscale_mode)
                if ms_preset is not None:
                    h_skip = filter_skip_box_multiscale(
                        h_skip, si.skip_cutoff, s_scale, ms_preset,
                        st.multiscale_strength, s_high)
                else:
                    h_skip = filter_skip_box(h_skip, si.skip_cutoff, s_scale, s_high)
        except Exception as e:
            logging.warning(f"[MegaFreeU] skip filter stage {stage_idx}: {e}")

    return original_function([h, h_skip], *args, **kwargs)
524
-
525
- # ── Patch (sd-webui-freeu exact + ControlNet) ─────────────────────────────────
526
_patched = False  # guard against double-patch on hot-reload

def patch():
    """Install the FreeU th.cat hijack into A1111 (and ControlNet when present).

    Wraps modules.sd_hijack_unet.th.cat with free_u_cat_hijack via
    functools.partial, guarded both by the module flag and by inspecting
    the existing wrapper's function name (survives a module reload where
    the flag resets but th.cat is still wrapped).
    """
    global _patched
    try:
        from modules.sd_hijack_unet import th
    except ImportError:
        print("[MegaFreeU] sd_hijack_unet not available", file=sys.stderr); return

    if _patched or (hasattr(th.cat, "func") and getattr(th.cat.func, "__name__", "") == "free_u_cat_hijack"):
        return  # already patched (by name; handles module reload)
    th.cat = functools.partial(free_u_cat_hijack, original_function=th.cat)
    _patched = True

    # Also wrap ControlNet's private th.cat copy so the hijack applies
    # when ControlNet re-runs UNet blocks through its own hook module.
    cn_status = "enabled"
    try:
        from modules import scripts
        # Candidate install locations: builtin extension dir and user extensions dir.
        cn_paths = [
            str(pathlib.Path(scripts.basedir()).parent.parent / "extensions-builtin" / "sd-webui-controlnet"),
            str(pathlib.Path(scripts.basedir()).parent / "sd-webui-controlnet"),
        ]
        # Temporarily prepend so `import scripts.hook` resolves to ControlNet's copy.
        sys.path[0:0] = cn_paths
        try:
            import scripts.hook as cn_hook
            cn_hook.th.cat = functools.partial(free_u_cat_hijack, original_function=cn_hook.th.cat)
        except ImportError:
            cn_status = "disabled"  # ControlNet not installed
        finally:
            # Always restore sys.path, even when the import failed.
            for p in cn_paths:
                if p in sys.path: sys.path.remove(p)
    except Exception:
        cn_status = "error"

    print(f"[MegaFreeU] th.cat patched ControlNet: *{cn_status}*")