File size: 3,362 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy as np
from scipy.fftpack import dctn, idctn

class FrequencyAwareNoise:
    """Forward-diffusion noiser whose noise is shaped per patch in the DCT domain.

    ``config`` must provide ``beta_start``, ``beta_end`` and ``T`` (a linear
    beta schedule of ``T`` steps) and ``patch_size`` (side length of the
    square DCT patches).
    """

    def __init__(self, config):
        self.config = config
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

        # NumPy views of the schedule, kept for callers that use them directly.
        self.betas_np = self.betas.numpy()
        self.alphas_np = self.alphas.numpy()
        self.alpha_bars_np = self.alpha_bars.numpy()

    @staticmethod
    def _freq_weights(h, w):
        """Return the (h, w) gain ``0.1 + 0.9 * (u + v) / (h + w - 2)`` per DCT coefficient."""
        # A 1x1 patch has only the DC coefficient; clamping the denominator gives
        # it the DC gain (0.1) instead of raising ZeroDivisionError.
        max_freq = max(h + w - 2, 1)
        return 0.1 + 0.9 * np.add.outer(np.arange(h), np.arange(w)) / max_freq

    def _alpha_bar_for(self, t, batch_size):
        """Look up alpha_bar for timesteps ``t`` as a (batch_size, 1, 1, 1) CPU tensor.

        ``t`` may be a Python int, a 0-d tensor or a 1-d tensor. If its length
        differs from ``batch_size`` the first timestep is used for the whole
        batch.

        Raises:
            ValueError: if ``t`` contains no timesteps.
        """
        t_idx = torch.as_tensor(t).detach().to(device="cpu", dtype=torch.long).reshape(-1)
        if t_idx.numel() == 0:
            raise ValueError("t must contain at least one timestep")
        alpha_bar = self.alpha_bars[t_idx]
        if alpha_bar.shape[0] != batch_size:
            # NOTE(review): kept from the original implementation -- a t whose
            # length mismatches the batch silently reuses t[0]; consider raising.
            alpha_bar = alpha_bar[:1].expand(batch_size)
        return alpha_bar.reshape(batch_size, 1, 1, 1)

    def apply_noise(self, x0, t, noise=None):
        """Add noise in frequency space (patch-wise DCT).

        For each ``patch_size`` x ``patch_size`` patch (edge patches may be
        smaller), a standard-normal draw in the DCT domain is scaled by
        ``_freq_weights`` and mapped back to pixel space with an orthonormal
        inverse DCT, giving ``noise_spatial``. The noisy sample is
        ``sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise_spatial``. Because
        the orthonormal DCT is linear and invertible, this equals mixing the DCT
        coefficients of ``x0`` and the noise and then inverting.

        Args:
            x0: clean batch indexed as (B, C, H, W); floating point assumed.
            t: timestep index/indices into the schedule (int, 0-d or 1-d tensor).
            noise: optional unweighted standard-normal DCT-domain draw with the
                same shape as ``x0`` (tensor or array). When omitted, it is drawn
                with ``np.random.randn`` one patch at a time in row-major patch
                order.

        Returns:
            ``(xt, noise_spatial)``, both with ``x0``'s dtype and device.
            ``noise_spatial`` is the frequency-weighted noise in pixel space,
            i.e. the target the model is trained to predict.

        Raises:
            ValueError: if ``noise`` does not match ``x0``'s shape or ``t`` is empty.
        """
        B, C, H, W = x0.shape
        p = self.config.patch_size

        eps = None
        if noise is not None:
            if isinstance(noise, torch.Tensor):
                noise = noise.detach().cpu().numpy()
            eps = np.asarray(noise, dtype=np.float64)
            if eps.shape != (B, C, H, W):
                raise ValueError(
                    f"noise shape {eps.shape} does not match x0 shape {(B, C, H, W)}"
                )

        # Shape the noise patch by patch in the DCT domain (float64 on CPU for scipy).
        noise_np = np.empty((B, C, H, W), dtype=np.float64)
        for i in range(0, H, p):
            for j in range(0, W, p):
                h, w = min(p, H - i), min(p, W - j)
                if eps is None:
                    patch_eps = np.random.randn(B, C, h, w)
                else:
                    patch_eps = eps[:, :, i:i + h, j:j + w]
                noise_np[:, :, i:i + h, j:j + w] = idctn(
                    patch_eps * self._freq_weights(h, w), axes=(2, 3), norm='ortho'
                )

        noise_spatial = torch.from_numpy(noise_np).to(device=x0.device, dtype=x0.dtype)
        alpha_bar = self._alpha_bar_for(t, B).to(device=x0.device, dtype=x0.dtype)
        # Mixing in pixel space keeps x0 on its device; equivalent to the DCT-domain
        # mix by linearity of the orthonormal DCT.
        xt = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise_spatial
        return xt, noise_spatial

    def debug_noise_stats(self, x0, t):
        """Run ``apply_noise`` once and print input/noise/noisy ranges and the noise std.

        Returns:
            The ``(xt, noise)`` pair produced by ``apply_noise``.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise