File size: 6,990 Bytes
8abfb97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
from dataloader import get_dataloaders
import numpy as np
def _denoise(model, noise_scheduler, x, timesteps, noise_scale):
    """Run the heuristic "super-denoiser" loop shared by every generation method.

    At each timestep the model's noise prediction, damped by ``noise_scale``,
    is subtracted from ``x``. The result is rescaled by ``1 / sqrt(alpha_bar_t)``
    and clamped to [-1, 1]. This is a deliberate heuristic rather than a
    standard DDPM/DDIM sampler: no fresh noise is injected between steps.

    Args:
        model: noise-prediction network, called as ``model(x, t)``.
        noise_scheduler: object exposing an indexable ``alpha_bars``.
        x: starting image batch (batch dimension first).
        timesteps: integer timesteps to visit, in the given order.
        noise_scale: fraction of the predicted noise removed at each step.

    Returns:
        The denoised batch, clamped to [-1, 1].
    """
    batch_size = x.shape[0]
    for t_val in timesteps:
        # One timestep per sample, sized from the batch instead of hard-coded.
        t_tensor = torch.full((batch_size,), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * noise_scale) / np.sqrt(alpha_bar_t)
        x = torch.clamp(x, -1, 1)
    return x


def _save_grid(x, path, label, nrow=2):
    """Map ``x`` from [-1, 1] to [0, 1], save it as an image grid, and print stats."""
    display = torch.clamp((x + 1) / 2, 0, 1)
    grid = make_grid(display, nrow=nrow, normalize=False)
    save_image(grid, path)
    print(f"{label} result: range=[{x.min():.3f}, {x.max():.3f}], std={x.std():.3f}")
    print(f"Saved to {path}")


def _frequency_init(num_samples, channels, height, width, device):
    """Build random images from structured DCT coefficients.

    Two bands are filled with Gaussian coefficients:
    - a low-frequency band ``[0:8, 0:8]``, with amplitude decaying as
      ``1 / (1 + u + v)`` (overall shape/color);
    - a mid-frequency band ``[8:20, 8:20]``, with amplitude 0.1 (textures).

    Both bands are clipped to the image size. All planes are then converted to
    the spatial domain with a single orthonormal inverse DCT over the last
    two axes.

    Returns:
        A float32 tensor of shape ``(num_samples, channels, height, width)``
        on ``device``.
    """
    # scipy.fft supersedes the legacy scipy.fftpack module.
    from scipy.fft import idctn

    coeffs = np.zeros((num_samples, channels, height, width))

    low_u, low_v = min(8, height), min(8, width)
    decay = 1.0 / (1 + np.arange(low_u)[:, None] + np.arange(low_v)[None, :])
    coeffs[..., :low_u, :low_v] = np.random.randn(num_samples, channels, low_u, low_v) * decay

    # Slicing clips automatically for small images (this may yield an empty band).
    mid_shape = coeffs[..., 8:20, 8:20].shape
    coeffs[..., 8:20, 8:20] = np.random.randn(*mid_shape) * 0.1

    spatial = idctn(coeffs, axes=(-2, -1), norm='ortho')
    return torch.from_numpy(spatial).float().to(device)


def hybrid_generation(checkpoint_path='model_final.pth', num_samples=4):
    """Hybrid approach: use the model as a super-denoiser rather than a pure generator.

    Loads the state dict at ``checkpoint_path`` into a ``SmoothDiffusionUNet``.
    It takes up to 8 images from the first training batch and uses them as
    reference statistics and source material. It then runs three
    initialization strategies, each followed by ``_denoise``:

    1. Gaussian noise matched to the training images' mean and std
       (``smart_noise_generation.png``).
    2. Random convex blends of 3 real images plus light noise
       (``blended_generation.png``).
    3. Structured DCT-domain patterns (``frequency_generation.png``).

    Args:
        checkpoint_path: path of the saved model state dict.
        num_samples: number of images generated by each method.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    # Load real training data for smart initialization.
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:8].to(device)

    # Generate at the data's own resolution instead of a hard-coded 3x64x64.
    channels, height, width = real_images.shape[1:]
    real_mean = real_images.mean().item()
    real_std = real_images.std().item()

    print("=== HYBRID GENERATION APPROACH ===")
    with torch.no_grad():
        # Method 1: noise with the same first two moments as the training data.
        print("\n--- Method 1: Smart Noise Initialization ---")
        smart_noise = torch.randn(num_samples, channels, height, width, device=device)
        smart_noise = torch.clamp(smart_noise * real_std + real_mean, -1, 1)
        print(f"Smart noise stats: mean={smart_noise.mean():.3f}, std={smart_noise.std():.3f}")

        x = _denoise(model, noise_scheduler, smart_noise,
                     [150, 120, 90, 70, 50, 35, 25, 15, 8, 3, 1], noise_scale=0.7)
        _save_grid(x, "smart_noise_generation.png", "Smart noise")

        # Method 2: random convex combinations of real images, plus light noise.
        print("\n--- Method 2: Blended Real Images ---")
        indices = torch.randint(0, len(real_images), (num_samples, 3)).to(device)
        weights = torch.rand(num_samples, 3, device=device)
        weights = weights / weights.sum(dim=1, keepdim=True)  # each row sums to 1
        # real_images[indices] -> (num_samples, 3, C, H, W); weighted sum over the 3 picks.
        blended = (weights[:, :, None, None, None] * real_images[indices]).sum(dim=1)
        blended = torch.clamp(blended + torch.randn_like(blended) * 0.15, -1, 1)

        x = _denoise(model, noise_scheduler, blended,
                     [80, 60, 40, 25, 12, 5, 1], noise_scale=0.5)
        _save_grid(x, "blended_generation.png", "Blended")

        # Method 3: structured DCT patterns rescaled to the training std.
        print("\n--- Method 3: Frequency-Domain Initialization ---")
        freq_images = _frequency_init(num_samples, channels, height, width, device)
        freq_images = freq_images - freq_images.mean()
        freq_images = freq_images / freq_images.std() * real_std
        freq_images = torch.clamp(freq_images, -1, 1)

        x = _denoise(model, noise_scheduler, freq_images,
                     [100, 75, 55, 40, 28, 18, 10, 4, 1], noise_scale=0.6)
        _save_grid(x, "frequency_generation.png", "Frequency")

    print("\n=== RESULTS ===")
    print("Generated files:")
    print("- smart_noise_generation.png (noise matching training stats)")
    print("- blended_generation.png (combinations of real images)")
    print("- frequency_generation.png (frequency-domain initialization)")
    print("\nYour model works as a super-denoiser!")
    print("It can clean up any reasonable starting point to look more image-like.")


if __name__ == "__main__":
    hybrid_generation()
|