VIVEK JAYARAM
committed on
Commit
•
d8f7287
1
Parent(s):
5b2cc7a
Initial operators working with masking
Browse files- .gitignore +2 -0
- README.md +8 -0
- cdim/image_utils.py +17 -0
- cdim/noise.py +58 -0
- cdim/noise_configs/gaussian_noise.yaml +0 -2
- cdim/operators/__init__.py +24 -0
- cdim/operators/identity_operator.py +9 -0
- cdim/operators/random_box_masker.py +56 -0
- inference.py +46 -1
- noise_configs/gaussian_noise_config.yaml +3 -0
- {cdim/noise_configs → noise_configs}/poisson_noise_config.yaml +0 -0
- noisy_measurement.png +0 -0
- operator_configs/box_inpainting_config.yaml +5 -0
- operator_configs/identity_operator_config.yaml +1 -0
- requirements.txt +2 -0
.gitignore
CHANGED
@@ -160,3 +160,5 @@ cython_debug/
|
|
160 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
#.idea/
|
|
|
|
|
|
160 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
#.idea/
|
163 |
+
|
164 |
+
*.DS_Store
|
README.md
CHANGED
@@ -1,2 +1,10 @@
|
|
1 |
# cdim
|
2 |
Constrained Diffusion Implicit Models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# cdim
|
2 |
Constrained Diffusion Implicit Models
|
3 |
+
|
4 |
+
conda create -n cdim python=3.11
|
5 |
+
|
6 |
+
conda activate cdim
|
7 |
+
|
8 |
+
pip install -r requirements.txt
|
9 |
+
|
10 |
+
pip install torch==2.4.1+cu124 torchvision==0.19.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
|
cdim/image_utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import ToPILImage
|
2 |
+
|
3 |
+
def save_to_image(tensor, filename):
|
4 |
+
"""
|
5 |
+
Saves a torch tensor to an image.
|
6 |
+
The image is assumed to be (1, 3, H, W)
|
7 |
+
with values between (-1, 1)
|
8 |
+
"""
|
9 |
+
to_save = (tensor[0] + 1) / 2
|
10 |
+
to_save = to_save.clamp(0, 1)
|
11 |
+
|
12 |
+
# Convert to PIL Image
|
13 |
+
transform = ToPILImage()
|
14 |
+
img = transform(to_save)
|
15 |
+
|
16 |
+
# Save the image
|
17 |
+
img.save(filename)
|
cdim/noise.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code based on https://github.com/DPS2022/diffusion-posterior-sampling
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
import torch
|
4 |
+
|
5 |
+
__NOISE__ = {}
|
6 |
+
|
7 |
+
|
8 |
+
def register_noise(name: str):
|
9 |
+
def wrapper(cls):
|
10 |
+
if __NOISE__.get(name, None):
|
11 |
+
raise NameError(f"Name {name} is already defined!")
|
12 |
+
__NOISE__[name] = cls
|
13 |
+
return cls
|
14 |
+
return wrapper
|
15 |
+
|
16 |
+
|
17 |
+
def get_noise(name: str, **kwargs):
|
18 |
+
if __NOISE__.get(name, None) is None:
|
19 |
+
raise NameError(f"Name {name} is not defined.")
|
20 |
+
noiser = __NOISE__[name](**kwargs)
|
21 |
+
noiser.__name__ = name
|
22 |
+
return noiser
|
23 |
+
|
24 |
+
|
25 |
+
class Noise(ABC):
|
26 |
+
def __call__(self, data):
|
27 |
+
return self.forward(data)
|
28 |
+
|
29 |
+
@abstractmethod
|
30 |
+
def __call__(self, data):
|
31 |
+
pass
|
32 |
+
|
33 |
+
@register_noise(name='gaussian')
|
34 |
+
class GaussianNoise(Noise):
|
35 |
+
def __init__(self, sigma):
|
36 |
+
self.sigma = sigma
|
37 |
+
|
38 |
+
def __call__(self, data):
|
39 |
+
# Important! We scale sigma by 2 because the config assumes images are in [0, 1]
|
40 |
+
# but actually this model uses images in [-1, 1]
|
41 |
+
return data + torch.randn_like(data, device=data.device) * self.sigma * 2
|
42 |
+
|
43 |
+
|
44 |
+
@register_noise(name='poisson')
|
45 |
+
class PoissonNoise(Noise):
|
46 |
+
def __init__(self, rate):
|
47 |
+
self.rate = rate
|
48 |
+
|
49 |
+
def __call__(self, data):
|
50 |
+
import numpy as np
|
51 |
+
data = (data + 1.0) / 2.0
|
52 |
+
data = data.clamp(0, 1)
|
53 |
+
device = data.device
|
54 |
+
data = data.detach().cpu()
|
55 |
+
data = torch.from_numpy(np.random.poisson(data * 255.0 * self.rate) / 255.0 / self.rate)
|
56 |
+
data = data * 2.0 - 1.0
|
57 |
+
data = data.clamp(-1, 1)
|
58 |
+
return data.to(device)
|
cdim/noise_configs/gaussian_noise.yaml
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
name: gaussian
|
2 |
-
sigma: 0.05
|
|
|
|
|
|
cdim/operators/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code based on https://github.com/DPS2022/diffusion-posterior-sampling
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
|
4 |
+
__OPERATOR__ = {}
|
5 |
+
|
6 |
+
|
7 |
+
def register_operator(name: str):
|
8 |
+
def wrapper(cls):
|
9 |
+
if __OPERATOR__.get(name, None):
|
10 |
+
raise NameError(f"Name {name} is already registered!")
|
11 |
+
__OPERATOR__[name] = cls
|
12 |
+
return cls
|
13 |
+
return wrapper
|
14 |
+
|
15 |
+
|
16 |
+
def get_operator(name: str, **kwargs):
|
17 |
+
if __OPERATOR__.get(name, None) is None:
|
18 |
+
raise NameError(f"Name {name} is not defined.")
|
19 |
+
return __OPERATOR__[name](**kwargs)
|
20 |
+
|
21 |
+
|
22 |
+
# Import everything to make sure they register
|
23 |
+
from .random_box_masker import RandomBoxMasker
|
24 |
+
from .identity_operator import IdentityOperator
|
cdim/operators/identity_operator.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cdim.operators import register_operator
|
2 |
+
|
3 |
+
@register_operator(name='identity')
|
4 |
+
class IdentityOperator:
|
5 |
+
def __init__(self, device):
|
6 |
+
self.device = device
|
7 |
+
|
8 |
+
def __call__(self, data):
|
9 |
+
return data
|
cdim/operators/random_box_masker.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from cdim.operators import register_operator
|
4 |
+
|
5 |
+
@register_operator(name='box_inpainting')
|
6 |
+
class RandomBoxMasker:
|
7 |
+
def __init__(self, height=256, width=256, channels=3, box_size=128, device='cpu'):
|
8 |
+
"""
|
9 |
+
Initialize the RandomBoxMasker with random box positioning.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
height (int): Height of the input tensors (default: 256)
|
13 |
+
width (int): Width of the input tensors (default: 256)
|
14 |
+
channels (int): Number of channels in the input tensors (default: 3)
|
15 |
+
box_size (int): Size of the box to mask (default: 128)
|
16 |
+
device (str): Device to create the mask on (default: 'cpu')
|
17 |
+
"""
|
18 |
+
self.height = height
|
19 |
+
self.width = width
|
20 |
+
self.channels = channels
|
21 |
+
self.box_size = min(box_size, height, width) # Ensure box_size doesn't exceed image dimensions
|
22 |
+
self.device = device
|
23 |
+
|
24 |
+
# Create a binary mask for box selection
|
25 |
+
self.mask = torch.ones((1, channels, height, width), device=device)
|
26 |
+
|
27 |
+
# Randomly calculate the top-left corner of the box
|
28 |
+
max_y = height - self.box_size
|
29 |
+
max_x = width - self.box_size
|
30 |
+
|
31 |
+
start_y = torch.randint(0, max_y + 1, (1,)).item()
|
32 |
+
start_x = torch.randint(0, max_x + 1, (1,)).item()
|
33 |
+
|
34 |
+
# Set the box area in the mask to 0
|
35 |
+
self.mask[0, :, start_y:start_y+self.box_size, start_x:start_x+self.box_size] = 0
|
36 |
+
|
37 |
+
def __call__(self, tensor):
|
38 |
+
"""
|
39 |
+
Apply the consistent box masking to the input tensor.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
tensor (torch.Tensor): Input tensor of shape (b, channels, height, width)
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
torch.Tensor: Tensor with the same shape as input, but with the box area masked out
|
46 |
+
"""
|
47 |
+
b, c, h, w = tensor.shape
|
48 |
+
assert c == self.channels and h == self.height and w == self.width, \
|
49 |
+
f"Input tensor must be of shape (b, {self.channels}, {self.height}, {self.width})"
|
50 |
+
|
51 |
+
# Move the mask to the same device as the input tensor if necessary
|
52 |
+
if tensor.device != self.mask.device:
|
53 |
+
self.mask = self.mask.to(tensor.device)
|
54 |
+
|
55 |
+
# Apply the mask to the input tensor
|
56 |
+
return tensor * self.mask
|
inference.py
CHANGED
@@ -1,12 +1,55 @@
|
|
1 |
import argparse
|
2 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
|
5 |
def main(args):
|
|
|
|
|
|
|
|
|
6 |
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
7 |
|
8 |
-
|
|
|
|
|
|
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
if __name__ == '__main__':
|
12 |
parser = argparse.ArgumentParser()
|
@@ -17,4 +60,6 @@ if __name__ == '__main__':
|
|
17 |
parser.add_argument("operator_config", type=str)
|
18 |
parser.add_argument("noise_config", type=str)
|
19 |
parser.add_argument("--output-dir", default=".", type=str)
|
|
|
|
|
20 |
main(parser.parse_args())
|
|
|
1 |
import argparse
|
2 |
import os
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from cdim.noise import get_noise
|
10 |
+
from cdim.operators import get_operator
|
11 |
+
from cdim.image_utils import save_to_image
|
12 |
+
|
13 |
+
|
14 |
+
def load_image(path):
|
15 |
+
"""
|
16 |
+
Load the image and normalize to [-1, 1]
|
17 |
+
"""
|
18 |
+
original_image = Image.open(path)
|
19 |
+
|
20 |
+
# Resize if needed
|
21 |
+
original_image = np.array(original_image.resize((256, 256), Image.BICUBIC))
|
22 |
+
original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2)
|
23 |
+
return (original_image / 127.5 - 1.0).to(torch.float)
|
24 |
+
|
25 |
+
|
26 |
+
def load_yaml(file_path: str) -> dict:
|
27 |
+
with open(file_path) as f:
|
28 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
29 |
+
return config
|
30 |
|
31 |
|
32 |
def main(args):
|
33 |
+
device_str = f"cuda" if args.cuda and torch.cuda.is_available() else 'cpu'
|
34 |
+
print(f"Using device {device_str}")
|
35 |
+
device = torch.device(device_str)
|
36 |
+
|
37 |
os.makedirs(args.output_dir, exist_ok=True)
|
38 |
+
original_image = load_image(args.input_image).to(device)
|
39 |
|
40 |
+
# Load the noise function
|
41 |
+
noise_config = load_yaml(args.noise_config)
|
42 |
+
noise_function = get_noise(**noise_config)
|
43 |
+
print(noise_function)
|
44 |
|
45 |
+
# Load the measurement function A
|
46 |
+
operator_config = load_yaml(args.operator_config)
|
47 |
+
operator_config["device"] = device
|
48 |
+
operator = get_operator(**operator_config)
|
49 |
+
print(operator)
|
50 |
+
|
51 |
+
noisy_measurement = noise_function(operator(original_image))
|
52 |
+
save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
|
53 |
|
54 |
if __name__ == '__main__':
|
55 |
parser = argparse.ArgumentParser()
|
|
|
60 |
parser.add_argument("operator_config", type=str)
|
61 |
parser.add_argument("noise_config", type=str)
|
62 |
parser.add_argument("--output-dir", default=".", type=str)
|
63 |
+
parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
|
64 |
+
|
65 |
main(parser.parse_args())
|
noise_configs/gaussian_noise_config.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
name: gaussian
|
2 |
+
sigma: 0.05
|
3 |
+
# Important! This noise is assumed to be for images in [0, 1]
|
{cdim/noise_configs → noise_configs}/poisson_noise_config.yaml
RENAMED
File without changes
|
noisy_measurement.png
ADDED
operator_configs/box_inpainting_config.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: box_inpainting
|
2 |
+
box_size: 128
|
3 |
+
height: 256
|
4 |
+
width: 256
|
5 |
+
channels: 3
|
operator_configs/identity_operator_config.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
name: identity
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
numpy==2.1.2
|
2 |
+
Pillow==11.0.0
|