|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.fft as fft | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): | 
					
						
						|  | """ | 
					
						
						|  | Apply frequency-dependent scaling to an image tensor using Fourier transforms. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | x:           Input tensor of shape (B, C, H, W) | 
					
						
						|  | scale_low:   Scaling factor for low-frequency components (default: 1.0) | 
					
						
						|  | scale_high:  Scaling factor for high-frequency components (default: 1.5) | 
					
						
						|  | freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | dtype, device = x.dtype, x.device | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x = x.to(torch.float32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x_freq = fft.fftn(x, dim=(-2, -1)) | 
					
						
						|  | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mask = torch.ones(x_freq.shape, device=device) * scale_high | 
					
						
						|  | m = mask | 
					
						
						|  | for d in range(len(x_freq.shape) - 2): | 
					
						
						|  | dim = d + 2 | 
					
						
						|  | cc = x_freq.shape[dim] // 2 | 
					
						
						|  | f_c = min(freq_cutoff, cc) | 
					
						
						|  | m = m.narrow(dim, cc - f_c, f_c * 2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | m[:] = scale_low | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x_freq = x_freq * mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) | 
					
						
						|  | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x_filtered = x_filtered.to(dtype) | 
					
						
						|  |  | 
					
						
						|  | return x_filtered | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class FreSca: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "model": ("MODEL",), | 
					
						
						|  | "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, | 
					
						
						|  | "tooltip": "Scaling factor for low-frequency components"}), | 
					
						
						|  | "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, | 
					
						
						|  | "tooltip": "Scaling factor for high-frequency components"}), | 
					
						
						|  | "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, | 
					
						
						|  | "tooltip": "Number of frequency indices around center to consider as low-frequency"}), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | RETURN_TYPES = ("MODEL",) | 
					
						
						|  | FUNCTION = "patch" | 
					
						
						|  | CATEGORY = "_for_testing" | 
					
						
						|  | DESCRIPTION = "Applies frequency-dependent scaling to the guidance" | 
					
						
						|  | def patch(self, model, scale_low, scale_high, freq_cutoff): | 
					
						
						|  | def custom_cfg_function(args): | 
					
						
						|  | conds_out = args["conds_out"] | 
					
						
						|  | if len(conds_out) <= 1 or None in args["conds"][:2]: | 
					
						
						|  | return conds_out | 
					
						
						|  | cond = conds_out[0] | 
					
						
						|  | uncond = conds_out[1] | 
					
						
						|  |  | 
					
						
						|  | guidance = cond - uncond | 
					
						
						|  | filtered_guidance = Fourier_filter( | 
					
						
						|  | guidance, | 
					
						
						|  | scale_low=scale_low, | 
					
						
						|  | scale_high=scale_high, | 
					
						
						|  | freq_cutoff=freq_cutoff, | 
					
						
						|  | ) | 
					
						
						|  | filtered_cond = filtered_guidance + uncond | 
					
						
						|  |  | 
					
						
						|  | return [filtered_cond, uncond] + conds_out[2:] | 
					
						
						|  |  | 
					
						
						|  | m = model.clone() | 
					
						
						|  | m.set_model_sampler_pre_cfg_function(custom_cfg_function) | 
					
						
						|  |  | 
					
						
						|  | return (m,) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | NODE_CLASS_MAPPINGS = { | 
					
						
						|  | "FreSca": FreSca, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | NODE_DISPLAY_NAME_MAPPINGS = { | 
					
						
						|  | "FreSca": "FreSca", | 
					
						
						|  | } | 
					
						
						|  |  |