import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
from dataloader import get_dataloaders
import numpy as np

def hybrid_generation():
    """Hybrid approach: Use model as super-denoiser rather than pure generator"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load model
    checkpoint = torch.load('model_final.pth', map_location=device, weights_only=True)
    config = Config()
    
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()
    
    # Load real training data for smart initialization
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:8].to(device)
    
    print("=== HYBRID GENERATION APPROACH ===")
    
    with torch.no_grad():
        # Method 1: Smart noise initialization
        print("\n--- Method 1: Smart Noise Initialization ---")
        
        # Initialize with noise that has similar statistics to training data
        smart_noise = torch.randn(4, 3, 64, 64, device=device)
        smart_noise = smart_noise * real_images.std().item()  # Match training data std
        smart_noise = smart_noise + real_images.mean().item()  # Match training data mean
        smart_noise = torch.clamp(smart_noise, -1, 1)
        
        print(f"Smart noise stats: mean={smart_noise.mean():.3f}, std={smart_noise.std():.3f}")
        
        # Apply progressive denoising
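        # Each pass applies a damped form of the standard DDPM x0 estimate,
        #   x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t),
        # scaling the predicted noise by 0.7 so each step removes only part of it.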
        timesteps = [150, 120, 90, 70, 50, 35, 25, 15, 8, 3, 1]
        x = smart_noise.clone()
        
        for t_val in timesteps:
            t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.7) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)
        
        # Save result
        smart_display = torch.clamp((x + 1) / 2, 0, 1)
        smart_grid = make_grid(smart_display, nrow=2, normalize=False)
        save_image(smart_grid, "smart_noise_generation.png")
        print(f"Smart noise result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to smart_noise_generation.png")
        
        # Method 2: Blended real images + denoising
        print("\n--- Method 2: Blended Real Images ---")
        
        # Create new combinations by blending random real images
        indices = torch.randint(0, len(real_images), (4, 3))  # Pick 3 random images for each output
        weights = torch.rand(4, 3, device=device)
        weights = weights / weights.sum(dim=1, keepdim=True)  # Normalize weights
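        # Each output is a convex combination of 3 real images (weights sum to 1),
        # so the blended pixels stay inside the training range before noise is added.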
        
        blended = torch.zeros(4, 3, 64, 64, device=device)
        for i in range(4):
            for j in range(3):
                blended[i] += weights[i, j] * real_images[indices[i, j]]
        
        # Add light noise so the blend has residual noise for the model to clean up
        noise = torch.randn_like(blended) * 0.15
        blended = blended + noise
        blended = torch.clamp(blended, -1, 1)
        
        # Light denoising to clean up
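        # Fewer steps and a smaller correction factor (0.5 vs. 0.7 above), since
        # the blended inputs already sit close to the data manifold.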
        light_timesteps = [80, 60, 40, 25, 12, 5, 1]
        x = blended.clone()
        
        for t_val in light_timesteps:
            t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.5) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)
        
        # Save result
        blended_display = torch.clamp((x + 1) / 2, 0, 1)
        blended_grid = make_grid(blended_display, nrow=2, normalize=False)
        save_image(blended_grid, "blended_generation.png")
        print(f"Blended result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to blended_generation.png")
        
        # Method 3: Frequency-domain initialization
        print("\n--- Method 3: Frequency-Domain Initialization ---")
        
        # Start with structured noise in frequency domain, then convert to spatial
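        # (A sparse grid of DCT coefficients inverts via idctn to a smooth,
        # image-like spatial pattern, unlike unstructured Gaussian noise.)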
        from scipy.fft import idctn  # scipy.fftpack is legacy; only idctn is used
        
        freq_images = torch.zeros(4, 3, 64, 64, device=device)
        
        for i in range(4):
            for c in range(3):
                # Create structured frequency pattern
                freq_pattern = np.zeros((64, 64))
                
                # Add some low-frequency components (overall shape/color)
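                # Amplitude decays as 1/(1 + u + v), a rough stand-in for the
                # ~1/f power spectrum of natural images.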
                for u in range(0, 8):
                    for v in range(0, 8):
                        freq_pattern[u, v] = np.random.randn() * (1.0 / (1 + u + v))
                
                # Add some mid-frequency components (textures)
                for u in range(8, 20):
                    for v in range(8, 20):
                        freq_pattern[u, v] = np.random.randn() * 0.1
                
                # Convert to spatial domain
                spatial = idctn(freq_pattern, norm='ortho')
                freq_images[i, c] = torch.from_numpy(spatial).float()
        
        # Zero-center and match the std of the training data
        # (freq_images was already allocated on `device` above)
        freq_images = freq_images - freq_images.mean()
        freq_images = freq_images / freq_images.std() * real_images.std()
        freq_images = torch.clamp(freq_images, -1, 1)
        
        # Apply denoising
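        # Same damped x0-estimate update as above, with an intermediate 0.6 factor.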
        freq_timesteps = [100, 75, 55, 40, 28, 18, 10, 4, 1]
        x = freq_images.clone()
        
        for t_val in freq_timesteps:
            t_tensor = torch.full((4,), t_val, device=device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)
            
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * 0.6) / np.sqrt(alpha_bar_t)
            x = torch.clamp(x, -1, 1)
        
        # Save result
        freq_display = torch.clamp((x + 1) / 2, 0, 1)
        freq_grid = make_grid(freq_display, nrow=2, normalize=False)
        save_image(freq_grid, "frequency_generation.png")
        print(f"Frequency result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
        print("Saved to frequency_generation.png")
        
        print("\n=== RESULTS ===")
        print("Generated files:")
        print("- smart_noise_generation.png (noise matching training stats)")
        print("- blended_generation.png (combinations of real images)")
        print("- frequency_generation.png (frequency-domain initialization)")
        print("\nYour model works as a super-denoiser!")
        print("It can clean up any reasonable starting point to look more image-like.")

if __name__ == "__main__":
    hybrid_generation()