VIVEK JAYARAM
commited on
Commit
·
95aa1d5
1
Parent(s):
c63740a
KL, categorical kl, and poisson noise
Browse files- cdim/diffusion/diffusion_pipeline.py +47 -8
- cdim/discrete_kl_loss.py +34 -0
- cdim/noise.py +15 -0
- inference.py +6 -2
- noise_configs/bimodal_noise_config.yaml +2 -0
- noise_configs/poisson_noise_config.yaml +1 -1
cdim/diffusion/diffusion_pipeline.py
CHANGED
@@ -2,6 +2,17 @@ import torch
|
|
2 |
from tqdm import tqdm
|
3 |
|
4 |
from cdim.image_utils import randn_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
@torch.no_grad()
|
@@ -16,7 +27,8 @@ def run_diffusion(
|
|
16 |
K=5,
|
17 |
image_dim=256,
|
18 |
image_channels=3,
|
19 |
-
model_type="diffusers"
|
|
|
20 |
):
|
21 |
batch_size = noisy_observation.shape[0]
|
22 |
image_shape = (batch_size, image_channels, image_dim, image_dim)
|
@@ -44,13 +56,40 @@ def run_diffusion(
|
|
44 |
model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]
|
45 |
x_0 = (image - beta_prod_t_prev ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5)
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
image -=
|
55 |
|
56 |
return image
|
|
|
2 |
from tqdm import tqdm
|
3 |
|
4 |
from cdim.image_utils import randn_tensor
|
5 |
+
from cdim.discrete_kl_loss import discrete_kl_loss
|
6 |
+
|
7 |
+
def compute_kl_gaussian(residuals, sigma):
|
8 |
+
# Only 0 centered for now
|
9 |
+
if sigma == 0:
|
10 |
+
raise ValueError("Can't do KL Divergence when sigma is 0")
|
11 |
+
sample_mean = (residuals).mean()
|
12 |
+
sample_var = (((residuals - sample_mean) **2).mean())
|
13 |
+
kl_div = torch.log(sample_var**0.5 / sigma) + (sigma**2 + sample_mean**2) / (2*sample_var) - 0.5
|
14 |
+
print(f"KL Divergence {kl_div}")
|
15 |
+
return kl_div
|
16 |
|
17 |
|
18 |
@torch.no_grad()
|
|
|
27 |
K=5,
|
28 |
image_dim=256,
|
29 |
image_channels=3,
|
30 |
+
model_type="diffusers",
|
31 |
+
loss_type="l2"
|
32 |
):
|
33 |
batch_size = noisy_observation.shape[0]
|
34 |
image_shape = (batch_size, image_channels, image_dim, image_dim)
|
|
|
56 |
model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]
|
57 |
x_0 = (image - beta_prod_t_prev ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5)
|
58 |
|
59 |
+
if loss_type == "l2" and noise_function.name == "gaussian":
|
60 |
+
distance = operator(x_0) - noisy_observation
|
61 |
+
if (distance ** 2).mean() < noise_function.sigma ** 2:
|
62 |
+
break
|
63 |
+
loss = ((distance) ** 2).mean()
|
64 |
+
print(f"L2 loss {loss}")
|
65 |
+
loss.backward()
|
66 |
+
|
67 |
+
elif loss_type == "kl" and noise_function.name == "gaussian":
|
68 |
+
diff = (operator(x_0) - noisy_observation) # Residuals
|
69 |
+
kl_div = compute_kl_gaussian(diff, noise_function.sigma)
|
70 |
+
kl_div.backward()
|
71 |
+
|
72 |
+
elif loss_type == "kl" and noise_function.name == "poisson":
|
73 |
+
residuals = (operator(x_0) * noise_function.rate - noisy_observation * noise_function.rate) * 127.5 # Residuals
|
74 |
+
x_0_pixel = operator((x_0 + 1) * 127.5)
|
75 |
+
mask = x_0_pixel > 2 # Avoid numeric issues with pixel values near 0
|
76 |
+
pearson = residuals[mask] / torch.sqrt(x_0_pixel[mask] * noise_function.rate)
|
77 |
+
pearson_flat = pearson.view(-1)
|
78 |
+
kl_div = compute_kl_gaussian(pearson_flat, 1.0)
|
79 |
+
kl_div.backward()
|
80 |
+
|
81 |
+
elif loss_type == "categorical_kl" and noise_function.name == "bimodal":
|
82 |
+
diff = (operator(x_0) - noisy_observation)
|
83 |
+
indices = operator(torch.ones(image.shape).to(device))
|
84 |
+
diff = diff[indices > 0] # Don't consider masked out pixels in the distribution
|
85 |
+
empirical_distribution = noise_function.sample_noise_distribution(image).to(device).view(-1)
|
86 |
+
loss = discrete_kl_loss(diff, empirical_distribution, num_bins=15)
|
87 |
+
print(f"Categorical KL {loss}")
|
88 |
+
loss.backward()
|
89 |
+
|
90 |
+
else:
|
91 |
+
raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
|
92 |
|
93 |
+
image -= 5 / torch.linalg.norm(image.grad) * image.grad
|
94 |
|
95 |
return image
|
cdim/discrete_kl_loss.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def discrete_kl_loss(pred, target, num_bins=20, epsilon=1e-8):
|
6 |
+
# Determine range for binning
|
7 |
+
with torch.no_grad():
|
8 |
+
combined = torch.cat([pred, target])
|
9 |
+
min_val = combined.min().item()
|
10 |
+
max_val = combined.max().item()
|
11 |
+
|
12 |
+
# Create bin edges
|
13 |
+
bin_edges = torch.linspace(min_val, max_val, num_bins + 1, device=pred.device)
|
14 |
+
bin_widths = bin_edges[1:] - bin_edges[:-1]
|
15 |
+
|
16 |
+
# Compute soft histogram
|
17 |
+
def soft_histogram(x):
|
18 |
+
x_expanded = x.unsqueeze(-1)
|
19 |
+
deltas = torch.abs(x_expanded - bin_edges[:-1].unsqueeze(0))
|
20 |
+
weights = torch.clamp(1 - deltas / bin_widths, min=0, max=1)
|
21 |
+
hist = weights.sum(dim=0) / len(x)
|
22 |
+
return hist
|
23 |
+
|
24 |
+
pred_hist = soft_histogram(pred)
|
25 |
+
target_hist = soft_histogram(target)
|
26 |
+
|
27 |
+
# Add epsilon and normalize
|
28 |
+
pred_probs = (pred_hist + epsilon) / (pred_hist.sum() + num_bins * epsilon)
|
29 |
+
target_probs = (target_hist + epsilon) / (target_hist.sum() + num_bins * epsilon)
|
30 |
+
|
31 |
+
# Compute KL divergence
|
32 |
+
kl_div = F.kl_div(pred_probs.log(), target_probs, reduction='sum')
|
33 |
+
|
34 |
+
return kl_div
|
cdim/noise.py
CHANGED
@@ -58,3 +58,18 @@ class PoissonNoise(Noise):
|
|
58 |
data = data * 2.0 - 1.0
|
59 |
data = data.clamp(-1, 1)
|
60 |
return data.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
data = data * 2.0 - 1.0
|
59 |
data = data.clamp(-1, 1)
|
60 |
return data.to(device)
|
61 |
+
|
62 |
+
|
63 |
+
@register_noise(name='bimodal')
|
64 |
+
class BimodalNoise(Noise):
|
65 |
+
def __init__(self, value):
|
66 |
+
self.value = value
|
67 |
+
self.name = 'bimodal'
|
68 |
+
|
69 |
+
def __call__(self, data):
|
70 |
+
noise = self.sample_noise_distribution(data)
|
71 |
+
return data + noise.to(data.device)
|
72 |
+
|
73 |
+
def sample_noise_distribution(self, data):
|
74 |
+
return (torch.randint(low=0, high=2, size=data.shape) * 2 - 1) * self.value
|
75 |
+
|
inference.py
CHANGED
@@ -87,7 +87,8 @@ def main(args):
|
|
87 |
noisy_measurement, operator, noise_function, device,
|
88 |
num_inference_steps=args.T,
|
89 |
K=args.K,
|
90 |
-
model_type=model_type
|
|
|
91 |
print(f"total time {time.time() - t0}")
|
92 |
|
93 |
save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
|
@@ -97,11 +98,14 @@ if __name__ == '__main__':
|
|
97 |
parser.add_argument("input_image", type=str)
|
98 |
parser.add_argument("T", type=int)
|
99 |
parser.add_argument("K", type=int)
|
100 |
-
parser.add_argument("model", type=str)
|
101 |
parser.add_argument("operator_config", type=str)
|
102 |
parser.add_argument("noise_config", type=str)
|
103 |
parser.add_argument("model_config", type=str)
|
104 |
parser.add_argument("--output-dir", default=".", type=str)
|
|
|
|
|
|
|
|
|
105 |
parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
|
106 |
|
107 |
main(parser.parse_args())
|
|
|
87 |
noisy_measurement, operator, noise_function, device,
|
88 |
num_inference_steps=args.T,
|
89 |
K=args.K,
|
90 |
+
model_type=model_type,
|
91 |
+
loss_type=args.loss)
|
92 |
print(f"total time {time.time() - t0}")
|
93 |
|
94 |
save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
|
|
|
98 |
parser.add_argument("input_image", type=str)
|
99 |
parser.add_argument("T", type=int)
|
100 |
parser.add_argument("K", type=int)
|
|
|
101 |
parser.add_argument("operator_config", type=str)
|
102 |
parser.add_argument("noise_config", type=str)
|
103 |
parser.add_argument("model_config", type=str)
|
104 |
parser.add_argument("--output-dir", default=".", type=str)
|
105 |
+
parser.add_argument("--loss", type=str,
|
106 |
+
choices=['l2', 'kl', 'categorical_kl'], default='l2',
|
107 |
+
help="Algorithm to use. Options: 'l2', 'kl', 'categorical_kl'. Default is 'l2'."
|
108 |
+
)
|
109 |
parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
|
110 |
|
111 |
main(parser.parse_args())
|
noise_configs/bimodal_noise_config.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
name: "bimodal"
|
2 |
+
value: 0.75
|
noise_configs/poisson_noise_config.yaml
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
name: poisson
|
2 |
-
rate: 0.
|
|
|
1 |
name: poisson
|
2 |
+
rate: 0.05
|