VIVEK JAYARAM committed on
Commit
95aa1d5
·
1 Parent(s): c63740a

KL, categorical kl, and poisson noise

Browse files
cdim/diffusion/diffusion_pipeline.py CHANGED
@@ -2,6 +2,17 @@ import torch
2
  from tqdm import tqdm
3
 
4
  from cdim.image_utils import randn_tensor
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  @torch.no_grad()
@@ -16,7 +27,8 @@ def run_diffusion(
16
  K=5,
17
  image_dim=256,
18
  image_channels=3,
19
- model_type="diffusers"
 
20
  ):
21
  batch_size = noisy_observation.shape[0]
22
  image_shape = (batch_size, image_channels, image_dim, image_dim)
@@ -44,13 +56,40 @@ def run_diffusion(
44
  model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]
45
  x_0 = (image - beta_prod_t_prev ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5)
46
 
47
- distance = operator(x_0) - noisy_observation
48
- if (distance ** 2).mean() < noise_function.sigma ** 2:
49
- break
50
- loss = ((distance) ** 2).mean()
51
- print(loss.mean())
52
- loss.mean().backward()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- image -= 15 / torch.linalg.norm(image.grad) * image.grad
55
 
56
  return image
 
2
  from tqdm import tqdm
3
 
4
  from cdim.image_utils import randn_tensor
5
+ from cdim.discrete_kl_loss import discrete_kl_loss
6
+
7
def compute_kl_gaussian(residuals, sigma):
    """Closed-form Gaussian KL between the expected noise and the residuals.

    A Gaussian is fitted to ``residuals`` using their mean and (biased)
    variance. The returned expression is the closed form of
    ``KL( N(0, sigma^2) || N(sample_mean, sample_var) )``, i.e. the expected
    zero-mean noise distribution measured against the fitted one.

    Args:
        residuals: floating-point tensor of residual values; all elements
            are pooled into a single mean and variance.
        sigma: standard deviation of the expected zero-mean Gaussian noise.

    Returns:
        A scalar tensor holding the KL divergence. It stays attached to the
        autograd graph of ``residuals``.

    Raises:
        ValueError: if ``sigma`` is 0.
    """
    # Only 0 centered for now
    if sigma == 0:
        raise ValueError("Can't do KL Divergence when sigma is 0")

    sample_mean = residuals.mean()
    sample_var = ((residuals - sample_mean) ** 2).mean()

    # Floor the variance: perfectly constant residuals would otherwise give
    # log(0) + x/0 = -inf + inf = NaN and poison the subsequent backward().
    sample_var = sample_var.clamp_min(torch.finfo(sample_var.dtype).eps)

    kl_div = (
        torch.log(sample_var ** 0.5 / sigma)
        + (sigma ** 2 + sample_mean ** 2) / (2 * sample_var)
        - 0.5
    )
    print(f"KL Divergence {kl_div}")
    return kl_div
16
 
17
 
18
  @torch.no_grad()
 
27
  K=5,
28
  image_dim=256,
29
  image_channels=3,
30
+ model_type="diffusers",
31
+ loss_type="l2"
32
  ):
33
  batch_size = noisy_observation.shape[0]
34
  image_shape = (batch_size, image_channels, image_dim, image_dim)
 
56
  model_output = model_output.sample if model_type == "diffusers" else model_output[:, :3]
57
  x_0 = (image - beta_prod_t_prev ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5)
58
 
59
+ if loss_type == "l2" and noise_function.name == "gaussian":
60
+ distance = operator(x_0) - noisy_observation
61
+ if (distance ** 2).mean() < noise_function.sigma ** 2:
62
+ break
63
+ loss = ((distance) ** 2).mean()
64
+ print(f"L2 loss {loss}")
65
+ loss.backward()
66
+
67
+ elif loss_type == "kl" and noise_function.name == "gaussian":
68
+ diff = (operator(x_0) - noisy_observation) # Residuals
69
+ kl_div = compute_kl_gaussian(diff, noise_function.sigma)
70
+ kl_div.backward()
71
+
72
+ elif loss_type == "kl" and noise_function.name == "poisson":
73
+ residuals = (operator(x_0) * noise_function.rate - noisy_observation * noise_function.rate) * 127.5 # Residuals
74
+ x_0_pixel = operator((x_0 + 1) * 127.5)
75
+ mask = x_0_pixel > 2 # Avoid numeric issues with pixel values near 0
76
+ pearson = residuals[mask] / torch.sqrt(x_0_pixel[mask] * noise_function.rate)
77
+ pearson_flat = pearson.view(-1)
78
+ kl_div = compute_kl_gaussian(pearson_flat, 1.0)
79
+ kl_div.backward()
80
+
81
+ elif loss_type == "categorical_kl" and noise_function.name == "bimodal":
82
+ diff = (operator(x_0) - noisy_observation)
83
+ indices = operator(torch.ones(image.shape).to(device))
84
+ diff = diff[indices > 0] # Don't consider masked out pixels in the distribution
85
+ empirical_distribution = noise_function.sample_noise_distribution(image).to(device).view(-1)
86
+ loss = discrete_kl_loss(diff, empirical_distribution, num_bins=15)
87
+ print(f"Categorical KL {loss}")
88
+ loss.backward()
89
+
90
+ else:
91
+ raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
92
 
93
+ image -= 5 / torch.linalg.norm(image.grad) * image.grad
94
 
95
  return image
cdim/discrete_kl_loss.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
def discrete_kl_loss(pred, target, num_bins=20, epsilon=1e-8):
    """Differentiable KL divergence between soft histograms of two samples.

    Both inputs are binned over a shared range (the min/max of their
    concatenation) with a triangular kernel, smoothed with ``epsilon``,
    normalised to probabilities and compared with ``F.kl_div``. With
    ``input=log(pred_probs)`` and ``target=target_probs`` that call evaluates
    ``sum(target_probs * (log(target_probs) - log(pred_probs)))``.

    Args:
        pred: tensor of predicted samples. Intended to be flat (1-D): it is
            concatenated with ``target`` along dim 0 and reduced over dim 0.
        target: tensor of reference samples, same layout as ``pred``.
        num_bins: number of histogram bins.
        epsilon: additive smoothing applied to each bin before normalising,
            so that the logarithm never sees an exact zero.

    Returns:
        Scalar tensor with the summed KL divergence. Gradients flow to
        ``pred`` through the soft bin weights; the bin range itself is
        computed under ``torch.no_grad()``.
    """
    # Determine range for binning
    with torch.no_grad():
        combined = torch.cat([pred, target])
        min_val = combined.min().item()
        max_val = combined.max().item()

    if max_val <= min_val:
        # Degenerate range (every value identical): zero-width bins would make
        # the soft histogram 0/0 = NaN. Widen the range so all mass lands in
        # the first bin and two identical inputs yield a loss of 0.
        max_val = min_val + max(1.0, abs(min_val))

    # Create bin edges
    bin_edges = torch.linspace(min_val, max_val, num_bins + 1, device=pred.device)
    bin_widths = bin_edges[1:] - bin_edges[:-1]

    # Compute soft histogram
    def soft_histogram(x):
        # Triangular weight of every sample w.r.t. every bin's LEFT edge:
        # 1 at the edge, falling linearly to 0 one bin-width away.
        # NOTE(review): because the kernel is anchored on left edges rather
        # than bin centres, a sample exactly at max_val is one full width from
        # the last left edge and contributes (numerically about) zero weight.
        # Confirm this is intended.
        x_expanded = x.unsqueeze(-1)
        deltas = torch.abs(x_expanded - bin_edges[:-1].unsqueeze(0))
        weights = torch.clamp(1 - deltas / bin_widths, min=0, max=1)
        hist = weights.sum(dim=0) / len(x)
        return hist

    pred_hist = soft_histogram(pred)
    target_hist = soft_histogram(target)

    # Add epsilon and normalize
    pred_probs = (pred_hist + epsilon) / (pred_hist.sum() + num_bins * epsilon)
    target_probs = (target_hist + epsilon) / (target_hist.sum() + num_bins * epsilon)

    # Compute KL divergence
    kl_div = F.kl_div(pred_probs.log(), target_probs, reduction='sum')

    return kl_div
cdim/noise.py CHANGED
@@ -58,3 +58,18 @@ class PoissonNoise(Noise):
58
  data = data * 2.0 - 1.0
59
  data = data.clamp(-1, 1)
60
  return data.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  data = data * 2.0 - 1.0
59
  data = data.clamp(-1, 1)
60
  return data.to(device)
61
+
62
+
63
@register_noise(name='bimodal')
class BimodalNoise(Noise):
    """Additive two-point noise: each element is shifted by +value or -value."""

    def __init__(self, value):
        # `value` is the magnitude of the shift; `name` is the registry tag.
        self.name = 'bimodal'
        self.value = value

    def sample_noise_distribution(self, data):
        """Draw a noise tensor shaped like ``data`` with entries of +/- value.

        No device is passed to ``torch.randint``, so the sample is created on
        torch's default device; callers move it where they need it.
        """
        # randint over [0, 2) yields 0 or 1; map that to -1 / +1, then scale.
        coin_flips = torch.randint(low=0, high=2, size=data.shape)
        signs = coin_flips * 2 - 1
        return signs * self.value

    def __call__(self, data):
        """Return ``data`` with freshly sampled bimodal noise added."""
        perturbation = self.sample_noise_distribution(data).to(data.device)
        return data + perturbation
75
+
inference.py CHANGED
@@ -87,7 +87,8 @@ def main(args):
87
  noisy_measurement, operator, noise_function, device,
88
  num_inference_steps=args.T,
89
  K=args.K,
90
- model_type=model_type)
 
91
  print(f"total time {time.time() - t0}")
92
 
93
  save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
@@ -97,11 +98,14 @@ if __name__ == '__main__':
97
  parser.add_argument("input_image", type=str)
98
  parser.add_argument("T", type=int)
99
  parser.add_argument("K", type=int)
100
- parser.add_argument("model", type=str)
101
  parser.add_argument("operator_config", type=str)
102
  parser.add_argument("noise_config", type=str)
103
  parser.add_argument("model_config", type=str)
104
  parser.add_argument("--output-dir", default=".", type=str)
 
 
 
 
105
  parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
106
 
107
  main(parser.parse_args())
 
87
  noisy_measurement, operator, noise_function, device,
88
  num_inference_steps=args.T,
89
  K=args.K,
90
+ model_type=model_type,
91
+ loss_type=args.loss)
92
  print(f"total time {time.time() - t0}")
93
 
94
  save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
 
98
  parser.add_argument("input_image", type=str)
99
  parser.add_argument("T", type=int)
100
  parser.add_argument("K", type=int)
 
101
  parser.add_argument("operator_config", type=str)
102
  parser.add_argument("noise_config", type=str)
103
  parser.add_argument("model_config", type=str)
104
  parser.add_argument("--output-dir", default=".", type=str)
105
+ parser.add_argument("--loss", type=str,
106
+ choices=['l2', 'kl', 'categorical_kl'], default='l2',
107
+ help="Algorithm to use. Options: 'l2', 'kl', 'categorical_kl'. Default is 'l2'."
108
+ )
109
  parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
110
 
111
  main(parser.parse_args())
noise_configs/bimodal_noise_config.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ name: "bimodal"
2
+ value: 0.75
noise_configs/poisson_noise_config.yaml CHANGED
@@ -1,2 +1,2 @@
1
  name: poisson
2
- rate: 0.1
 
1
  name: poisson
2
+ rate: 0.05