VIVEK JAYARAM committed on
Commit
d8f7287
•
1 Parent(s): 5b2cc7a

Initial operators working with masking

Browse files
.gitignore CHANGED
@@ -160,3 +160,5 @@ cython_debug/
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
  #.idea/
 
 
 
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
  #.idea/
163
+
164
+ *.DS_Store
README.md CHANGED
@@ -1,2 +1,10 @@
1
  # cdim
2
  Constrained Diffusion Implicit Models
 
 
 
 
 
 
 
 
 
1
  # cdim
2
  Constrained Diffusion Implicit Models
3
+
4
+ conda create -n cdim python=3.11
5
+
6
+ conda activate cdim
7
+
8
+ pip install -r requirements.txt
9
+
10
+ pip install torch==2.4.1+cu124 torchvision==0.19.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
cdim/image_utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.transforms import ToPILImage
2
+
3
def save_to_image(tensor, filename):
    """
    Write the first image of a batch to disk.

    The tensor is expected to be shaped (1, 3, H, W) with values in
    [-1, 1]. Only index 0 along the first dimension is saved; values are
    rescaled to [0, 1] and anything outside that range is clipped.

    Args:
        tensor: image batch to save.
        filename: destination path handed to ``PIL.Image.save``.
    """
    first_image = tensor[0]

    # Map [-1, 1] -> [0, 1], clipping whatever falls outside.
    unit_range = ((first_image + 1) / 2).clamp(0, 1)

    # Convert to a PIL image and write it out.
    ToPILImage()(unit_range).save(filename)
cdim/noise.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code based on https://github.com/DPS2022/diffusion-posterior-sampling
2
+ from abc import ABC, abstractmethod
3
+ import torch
4
+
5
# Registry mapping a noise name to the class that implements it.
__NOISE__ = {}


def register_noise(name: str):
    """
    Class decorator that records a noise class in the registry under ``name``.

    Raises:
        NameError: if a (truthy) entry is already registered under ``name``.
    """
    def wrapper(cls):
        already_registered = __NOISE__.get(name)
        if already_registered:
            raise NameError(f"Name {name} is already defined!")
        __NOISE__[name] = cls
        return cls
    return wrapper


def get_noise(name: str, **kwargs):
    """
    Instantiate the noise class registered under ``name``.

    ``kwargs`` are forwarded to the class constructor. The created instance
    gets a ``__name__`` attribute holding the registry name.

    Raises:
        NameError: if nothing is registered under ``name``.
    """
    noise_cls = __NOISE__.get(name)
    if noise_cls is None:
        raise NameError(f"Name {name} is not defined.")

    instance = noise_cls(**kwargs)
    instance.__name__ = name
    return instance
23
+
24
+
25
class Noise(ABC):
    """
    Abstract base class for noise models.

    A noise model is a callable: subclasses must implement ``__call__``,
    which takes the data and returns the noised data. The class cannot be
    instantiated until ``__call__`` is overridden.
    """

    # A second, earlier ``__call__`` that delegated to a non-existent
    # ``self.forward`` used to sit here. It was shadowed by this definition
    # and never ran, so it has been removed.
    @abstractmethod
    def __call__(self, data):
        """Return a noised version of ``data``."""
        pass
32
+
33
@register_noise(name='gaussian')
class GaussianNoise(Noise):
    """
    Additive Gaussian noise.

    Args:
        sigma: noise standard deviation, specified for images in [0, 1].
    """

    def __init__(self, sigma):
        self.sigma = sigma

    def __call__(self, data):
        """Return ``data`` plus zero-mean Gaussian noise."""
        # randn_like already samples on the same device as ``data``.
        standard_normal = torch.randn_like(data)

        # Important! Sigma is doubled because the config assumes images in
        # [0, 1], while this model works with images in [-1, 1].
        perturbation = standard_normal * self.sigma * 2
        return data + perturbation
42
+
43
+
44
@register_noise(name='poisson')
class PoissonNoise(Noise):
    """
    Poisson noise applied to images stored in [-1, 1].

    The image is mapped to [0, 1], scaled to ``255 * rate`` expected counts,
    sampled with ``numpy.random.poisson``, and mapped back to [-1, 1].

    Args:
        rate: multiplier on the expected counts.
    """

    def __init__(self, rate):
        self.rate = rate

    def __call__(self, data):
        """
        Return a Poisson-noised copy of ``data``.

        The result is placed on the same device and has the same dtype as
        ``data``, and is clamped to [-1, 1].
        """
        import numpy as np

        # Remember where the input lived: sampling happens in NumPy on the CPU.
        device, dtype = data.device, data.dtype

        # [-1, 1] -> [0, 1]
        unit_range = ((data + 1.0) / 2.0).clamp(0, 1)
        unit_range = unit_range.detach().cpu()

        expected_counts = (unit_range * 255.0 * self.rate).numpy()
        sampled = np.random.poisson(expected_counts) / 255.0 / self.rate

        # [0, 1] -> [-1, 1]
        noisy = torch.from_numpy(sampled) * 2.0 - 1.0
        noisy = noisy.clamp(-1, 1)

        # The NumPy division yields float64; cast back so the caller gets
        # the dtype it passed in instead of a silently promoted tensor.
        return noisy.to(device=device, dtype=dtype)
cdim/noise_configs/gaussian_noise.yaml DELETED
@@ -1,2 +0,0 @@
1
- name: gaussian
2
- sigma: 0.05
 
 
 
cdim/operators/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code based on https://github.com/DPS2022/diffusion-posterior-sampling
2
+ from abc import ABC, abstractmethod
3
+
4
# Registry mapping an operator name to the class that implements it.
__OPERATOR__ = {}


def register_operator(name: str):
    """
    Class decorator that records an operator class in the registry under ``name``.

    Raises:
        NameError: if a (truthy) entry is already registered under ``name``.
    """
    def wrapper(cls):
        already_registered = __OPERATOR__.get(name)
        if already_registered:
            raise NameError(f"Name {name} is already registered!")
        __OPERATOR__[name] = cls
        return cls
    return wrapper


def get_operator(name: str, **kwargs):
    """
    Instantiate the operator class registered under ``name``.

    ``kwargs`` are forwarded to the class constructor.

    Raises:
        NameError: if nothing is registered under ``name``.
    """
    operator_cls = __OPERATOR__.get(name)
    if operator_cls is None:
        raise NameError(f"Name {name} is not defined.")
    return operator_cls(**kwargs)
20
+
21
+
22
+ # Import everything to make sure they register
23
+ from .random_box_masker import RandomBoxMasker
24
+ from .identity_operator import IdentityOperator
cdim/operators/identity_operator.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from cdim.operators import register_operator
2
+
3
@register_operator(name='identity')
class IdentityOperator:
    """
    Measurement operator that leaves its input untouched.

    Args:
        device: stored on the instance; not used by ``__call__``.
    """

    def __init__(self, device):
        self.device = device

    def __call__(self, data):
        """Return ``data`` unchanged (the very same object)."""
        return data
cdim/operators/random_box_masker.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from cdim.operators import register_operator
4
+
5
@register_operator(name='box_inpainting')
class RandomBoxMasker:
    """
    Inpainting operator that zeroes out one square box of the image.

    The box position is drawn at random once, at construction time, and the
    same mask is then applied on every call.
    """

    def __init__(self, height=256, width=256, channels=3, box_size=128, device='cpu'):
        """
        Build the mask with a randomly positioned box.

        Args:
            height (int): Height of the input tensors (default: 256)
            width (int): Width of the input tensors (default: 256)
            channels (int): Number of channels in the input tensors (default: 3)
            box_size (int): Side length of the masked box (default: 128);
                capped at the smaller of ``height`` and ``width``
            device (str): Device to create the mask on (default: 'cpu')
        """
        self.height = height
        self.width = width
        self.channels = channels
        # The box can never be larger than the image itself.
        self.box_size = min(box_size, height, width)
        self.device = device

        # Pick the top-left corner so the whole box stays inside the image.
        # The row is drawn first, then the column.
        top = torch.randint(0, height - self.box_size + 1, (1,)).item()
        left = torch.randint(0, width - self.box_size + 1, (1,)).item()
        bottom = top + self.box_size
        right = left + self.box_size

        # Binary mask: 1 keeps a pixel, 0 removes it (inside the box).
        keep = torch.ones((1, channels, height, width), device=device)
        keep[0, :, top:bottom, left:right] = 0
        self.mask = keep

    def __call__(self, tensor):
        """
        Apply the stored box mask to ``tensor``.

        Args:
            tensor (torch.Tensor): Input tensor of shape (b, channels, height, width)

        Returns:
            torch.Tensor: Tensor with the same shape as the input, with the box area set to 0
        """
        _, c, h, w = tensor.shape
        assert (c, h, w) == (self.channels, self.height, self.width), \
            f"Input tensor must be of shape (b, {self.channels}, {self.height}, {self.width})"

        # Keep the mask on the input's device; the moved mask is cached on self.
        if self.mask.device != tensor.device:
            self.mask = self.mask.to(tensor.device)

        # The (1, C, H, W) mask is multiplied against the (b, C, H, W) input.
        return tensor * self.mask
inference.py CHANGED
@@ -1,12 +1,55 @@
1
  import argparse
2
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  def main(args):
 
 
 
 
6
  os.makedirs(args.output_dir, exist_ok=True)
 
7
 
8
-
 
 
 
9
 
 
 
 
 
 
 
 
 
10
 
11
  if __name__ == '__main__':
12
  parser = argparse.ArgumentParser()
@@ -17,4 +60,6 @@ if __name__ == '__main__':
17
  parser.add_argument("operator_config", type=str)
18
  parser.add_argument("noise_config", type=str)
19
  parser.add_argument("--output-dir", default=".", type=str)
 
 
20
  main(parser.parse_args())
 
1
  import argparse
2
  import os
3
+ import yaml
4
+
5
+ from PIL import Image
6
+ import numpy as np
7
+ import torch
8
+
9
+ from cdim.noise import get_noise
10
+ from cdim.operators import get_operator
11
+ from cdim.image_utils import save_to_image
12
+
13
+
14
+ def load_image(path):
15
+ """
16
+ Load the image and normalize to [-1, 1]
17
+ """
18
+ original_image = Image.open(path)
19
+
20
+ # Resize if needed
21
+ original_image = np.array(original_image.resize((256, 256), Image.BICUBIC))
22
+ original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
23
+ return (original_image / 127.5 - 1.0).to(torch.float)
24
+
25
+
26
def load_yaml(file_path: str) -> dict:
    """Parse the YAML file at ``file_path`` and return its contents."""
    # NOTE(review): yaml.FullLoader can build more than plain data types.
    # If config files may come from untrusted sources, consider
    # yaml.safe_load instead.
    with open(file_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
30
 
31
 
32
def main(args):
    """
    Produce a noisy measurement of the input image and save it.

    Reads ``args.cuda``, ``args.output_dir``, ``args.input_image``,
    ``args.noise_config`` and ``args.operator_config``. The result is
    written to ``noisy_measurement.png`` inside the output directory.
    """
    # Use the GPU only when it is both requested and available.
    use_cuda = args.cuda and torch.cuda.is_available()
    device_str = "cuda" if use_cuda else 'cpu'
    print(f"Using device {device_str}")
    device = torch.device(device_str)

    os.makedirs(args.output_dir, exist_ok=True)
    original_image = load_image(args.input_image).to(device)

    # Load the noise function
    noise_function = get_noise(**load_yaml(args.noise_config))
    print(noise_function)

    # Load the measurement function A; the device is injected into its config
    operator_config = load_yaml(args.operator_config)
    operator_config["device"] = device
    operator = get_operator(**operator_config)
    print(operator)

    # Measure first, then add noise to the measurement.
    measurement = operator(original_image)
    noisy_measurement = noise_function(measurement)

    output_path = os.path.join(args.output_dir, "noisy_measurement.png")
    save_to_image(noisy_measurement, output_path)
53
 
54
  if __name__ == '__main__':
55
  parser = argparse.ArgumentParser()
 
60
  parser.add_argument("operator_config", type=str)
61
  parser.add_argument("noise_config", type=str)
62
  parser.add_argument("--output-dir", default=".", type=str)
63
+ parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
64
+
65
  main(parser.parse_args())
noise_configs/gaussian_noise_config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: gaussian
2
+ sigma: 0.05
3
+ # Important! This noise is assumed to be for images in [0, 1]
{cdim/noise_configs → noise_configs}/poisson_noise_config.yaml RENAMED
File without changes
noisy_measurement.png ADDED
operator_configs/box_inpainting_config.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ name: box_inpainting
2
+ box_size: 128
3
+ height: 256
4
+ width: 256
5
+ channels: 3
operator_configs/identity_operator_config.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ name: identity
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ numpy==2.1.2
2
+ Pillow==11.0.0