Spaces:
Sleeping
Sleeping
add dps
Browse files
diffusion-posterior-sampling/Dockerfile
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.3.1-devel-ubuntu20.04
|
2 |
+
|
3 |
+
ENV TZ=Asiz/Seoul
|
4 |
+
ENV TERM=xterm-256color
|
5 |
+
|
6 |
+
RUN ln -fs /usr/share/zoneinfo/Asia/Seoul /etc/localtime
|
7 |
+
|
8 |
+
#### 0. Install python and pip
|
9 |
+
RUN apt-get -y update && apt-get install -y git wget curl
|
10 |
+
RUN apt-get update
|
11 |
+
RUN apt-get upgrade python3 -y
|
12 |
+
RUN apt-get install python3-pip -y
|
13 |
+
RUN alias python='python3'
|
14 |
+
|
15 |
+
#### 1. Install Pytorch
|
16 |
+
RUN pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
17 |
+
|
18 |
+
#### 2. Install other dependencies
|
19 |
+
WORKDIR /usr/app
|
20 |
+
COPY . ./
|
21 |
+
RUN pip install -r ./requirements.txt
|
22 |
+
|
23 |
+
#### 3. Clone external codes
|
24 |
+
RUN git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse
|
25 |
+
RUN git clone https://github.com/LeviBorodenko/motionblur motionblur
|
26 |
+
|
27 |
+
#### 4. change user
|
28 |
+
RUN useradd docker_user -u 1000 -m
|
diffusion-posterior-sampling/README.md
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diffusion Posterior Sampling for General Noisy Inverse Problems (ICLR 2023 spotlight)
|
2 |
+
|
3 |
+
![result-gif1](./figures/motion_blur.gif)
|
4 |
+
![result-git2](./figures/super_resolution.gif)
|
5 |
+
<!-- See more results in the [project-page](https://jeongsol-kim.github.io/dps-project-page) -->
|
6 |
+
|
7 |
+
## Abstract
|
8 |
+
In this work, we extend diffusion solvers to efficiently handle general noisy (non)linear inverse problems via the approximation of the posterior sampling. Interestingly, the resulting posterior sampling scheme is a blended version of the diffusion sampling with the manifold constrained gradient without strict measurement consistency projection step, yielding more desirable generative path in noisy settings compared to the previous studies.
|
9 |
+
|
10 |
+
![cover-img](./figures/cover.jpg)
|
11 |
+
|
12 |
+
|
13 |
+
## Prerequisites
|
14 |
+
- python 3.8
|
15 |
+
|
16 |
+
- pytorch 1.11.0
|
17 |
+
|
18 |
+
- CUDA 11.3.1
|
19 |
+
|
20 |
+
- nvidia-docker (if you use GPU in docker container)
|
21 |
+
|
22 |
+
It is okay to use lower version of CUDA with proper pytorch version.
|
23 |
+
|
24 |
+
Ex) CUDA 10.2 with pytorch 1.7.0
|
25 |
+
|
26 |
+
<br />
|
27 |
+
|
28 |
+
## Getting started
|
29 |
+
|
30 |
+
### 1) Clone the repository
|
31 |
+
|
32 |
+
```
|
33 |
+
git clone https://github.com/DPS2022/diffusion-posterior-sampling
|
34 |
+
|
35 |
+
cd diffusion-posterior-sampling
|
36 |
+
```
|
37 |
+
|
38 |
+
<br />
|
39 |
+
|
40 |
+
### 2) Download pretrained checkpoint
|
41 |
+
From the [link](https://drive.google.com/drive/folders/1jElnRoFv7b31fG0v6pTSQkelbSX3xGZh?usp=sharing), download the checkpoint "ffhq_10m.pt" and paste it to ./models/
|
42 |
+
```
|
43 |
+
mkdir models
|
44 |
+
mv {DOWNLOAD_DIR}/ffqh_10m.pt ./models/
|
45 |
+
```
|
46 |
+
{DOWNLOAD_DIR} is the directory that you downloaded checkpoint to.
|
47 |
+
|
48 |
+
:speaker: Checkpoint for imagenet is uploaded.
|
49 |
+
|
50 |
+
<br />
|
51 |
+
|
52 |
+
|
53 |
+
### 3) Set environment
|
54 |
+
### [Option 1] Local environment setting
|
55 |
+
|
56 |
+
We use the external codes for motion-blurring and non-linear deblurring.
|
57 |
+
|
58 |
+
```
|
59 |
+
git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse
|
60 |
+
|
61 |
+
git clone https://github.com/LeviBorodenko/motionblur motionblur
|
62 |
+
```
|
63 |
+
|
64 |
+
Install dependencies
|
65 |
+
|
66 |
+
```
|
67 |
+
conda create -n DPS python=3.8
|
68 |
+
|
69 |
+
conda activate DPS
|
70 |
+
|
71 |
+
pip install -r requirements.txt
|
72 |
+
|
73 |
+
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
74 |
+
```
|
75 |
+
|
76 |
+
<br />
|
77 |
+
|
78 |
+
### [Option 2] Build Docker image
|
79 |
+
|
80 |
+
Install docker engine, GPU driver and proper cuda before running the following commands.
|
81 |
+
|
82 |
+
Dockerfile already contains command to clone external codes. You don't have to clone them again.
|
83 |
+
|
84 |
+
--gpus=all is required to use local GPU device (Docker >= 19.03)
|
85 |
+
|
86 |
+
```
|
87 |
+
docker build -t dps-docker:latest .
|
88 |
+
|
89 |
+
docker run -it --rm --gpus=all dps-docker
|
90 |
+
```
|
91 |
+
|
92 |
+
<br />
|
93 |
+
|
94 |
+
### 4) Inference
|
95 |
+
|
96 |
+
```
|
97 |
+
python3 sample_condition.py \
|
98 |
+
--model_config=configs/model_config.yaml \
|
99 |
+
--diffusion_config=configs/diffusion_config.yaml \
|
100 |
+
--task_config={TASK-CONFIG};
|
101 |
+
```
|
102 |
+
|
103 |
+
|
104 |
+
:speaker: For imagenet, use configs/imagenet_model_config.yaml
|
105 |
+
|
106 |
+
<br />
|
107 |
+
|
108 |
+
## Possible task configurations
|
109 |
+
|
110 |
+
```
|
111 |
+
# Linear inverse problems
|
112 |
+
- configs/super_resolution_config.yaml
|
113 |
+
- configs/gaussian_deblur_config.yaml
|
114 |
+
- configs/motion_deblur_config.yaml
|
115 |
+
- configs/inpainting_config.yaml
|
116 |
+
|
117 |
+
# Non-linear inverse problems
|
118 |
+
- configs/nonlinear_deblur_config.yaml
|
119 |
+
- configs/phase_retrieval_config.yaml
|
120 |
+
```
|
121 |
+
|
122 |
+
### Structure of task configurations
|
123 |
+
You need to write your data directory at data.root. Default is ./data/samples which contains three sample images from FFHQ validation set.
|
124 |
+
|
125 |
+
```
|
126 |
+
conditioning:
|
127 |
+
method: # check candidates in guided_diffusion/condition_methods.py
|
128 |
+
params:
|
129 |
+
scale: 0.5
|
130 |
+
|
131 |
+
data:
|
132 |
+
name: ffhq
|
133 |
+
root: ./data/samples/
|
134 |
+
|
135 |
+
measurement:
|
136 |
+
operator:
|
137 |
+
name: # check candidates in guided_diffusion/measurements.py
|
138 |
+
|
139 |
+
noise:
|
140 |
+
name: # gaussian or poisson
|
141 |
+
sigma: # if you use name: gaussian, set this.
|
142 |
+
(rate:) # if you use name: poisson, set this.
|
143 |
+
```
|
144 |
+
|
145 |
+
## Citation
|
146 |
+
If you find our work interesting, please consider citing
|
147 |
+
|
148 |
+
```
|
149 |
+
@inproceedings{
|
150 |
+
chung2023diffusion,
|
151 |
+
title={Diffusion Posterior Sampling for General Noisy Inverse Problems},
|
152 |
+
author={Hyungjin Chung and Jeongsol Kim and Michael Thompson Mccann and Marc Louis Klasky and Jong Chul Ye},
|
153 |
+
booktitle={The Eleventh International Conference on Learning Representations },
|
154 |
+
year={2023},
|
155 |
+
url={https://openreview.net/forum?id=OnD9zGAGT0k}
|
156 |
+
}
|
157 |
+
```
|
158 |
+
|
diffusion-posterior-sampling/requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
certifi==2022.9.14
|
2 |
+
charset-normalizer==2.1.1
|
3 |
+
contourpy==1.0.5
|
4 |
+
cycler==0.11.0
|
5 |
+
fonttools==4.37.2
|
6 |
+
idna==3.4
|
7 |
+
kiwisolver==1.4.4
|
8 |
+
matplotlib==3.6.0
|
9 |
+
numpy==1.23.3
|
10 |
+
packaging==21.3
|
11 |
+
Pillow==9.2.0
|
12 |
+
pyparsing==3.0.9
|
13 |
+
python-dateutil==2.8.2
|
14 |
+
PyYAML==6.0
|
15 |
+
requests==2.28.1
|
16 |
+
scipy==1.9.1
|
17 |
+
six==1.16.0
|
18 |
+
tqdm==4.64.1
|
19 |
+
typing-extensions==4.3.0
|
20 |
+
urllib3==1.26.12
|
diffusion-posterior-sampling/sample_condition.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from guided_diffusion.condition_methods import get_conditioning_method
|
11 |
+
from guided_diffusion.measurements import get_noise, get_operator
|
12 |
+
from guided_diffusion.unet import create_model
|
13 |
+
from guided_diffusion.gaussian_diffusion import create_sampler
|
14 |
+
from data.dataloader import get_dataset, get_dataloader
|
15 |
+
from util.img_utils import clear_color, mask_generator
|
16 |
+
from util.logger import get_logger
|
17 |
+
|
18 |
+
|
19 |
+
def load_yaml(file_path: str) -> dict:
|
20 |
+
with open(file_path) as f:
|
21 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
22 |
+
return config
|
23 |
+
|
24 |
+
|
25 |
+
def main():
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument('--model_config', type=str)
|
28 |
+
parser.add_argument('--diffusion_config', type=str)
|
29 |
+
parser.add_argument('--task_config', type=str)
|
30 |
+
parser.add_argument('--gpu', type=int, default=0)
|
31 |
+
parser.add_argument('--save_dir', type=str, default='./results')
|
32 |
+
args = parser.parse_args()
|
33 |
+
|
34 |
+
# logger
|
35 |
+
logger = get_logger()
|
36 |
+
|
37 |
+
# Device setting
|
38 |
+
device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
|
39 |
+
logger.info(f"Device set to {device_str}.")
|
40 |
+
device = torch.device(device_str)
|
41 |
+
|
42 |
+
# Load configurations
|
43 |
+
model_config = load_yaml(args.model_config)
|
44 |
+
diffusion_config = load_yaml(args.diffusion_config)
|
45 |
+
task_config = load_yaml(args.task_config)
|
46 |
+
|
47 |
+
#assert model_config['learn_sigma'] == diffusion_config['learn_sigma'], \
|
48 |
+
#"learn_sigma must be the same for model and diffusion configuartion."
|
49 |
+
|
50 |
+
# Load model
|
51 |
+
model = create_model(**model_config)
|
52 |
+
model = model.to(device)
|
53 |
+
model.eval()
|
54 |
+
|
55 |
+
# Prepare Operator and noise
|
56 |
+
measure_config = task_config['measurement']
|
57 |
+
operator = get_operator(device=device, **measure_config['operator'])
|
58 |
+
noiser = get_noise(**measure_config['noise'])
|
59 |
+
logger.info(f"Operation: {measure_config['operator']['name']} / Noise: {measure_config['noise']['name']}")
|
60 |
+
|
61 |
+
# Prepare conditioning method
|
62 |
+
cond_config = task_config['conditioning']
|
63 |
+
cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'])
|
64 |
+
measurement_cond_fn = cond_method.conditioning
|
65 |
+
logger.info(f"Conditioning method : {task_config['conditioning']['method']}")
|
66 |
+
|
67 |
+
# Load diffusion sampler
|
68 |
+
sampler = create_sampler(**diffusion_config)
|
69 |
+
sample_fn = partial(sampler.p_sample_loop, model=model, measurement_cond_fn=measurement_cond_fn)
|
70 |
+
|
71 |
+
# Working directory
|
72 |
+
out_path = os.path.join(args.save_dir, measure_config['operator']['name'])
|
73 |
+
os.makedirs(out_path, exist_ok=True)
|
74 |
+
for img_dir in ['input', 'recon', 'progress', 'label']:
|
75 |
+
os.makedirs(os.path.join(out_path, img_dir), exist_ok=True)
|
76 |
+
|
77 |
+
# Prepare dataloader
|
78 |
+
data_config = task_config['data']
|
79 |
+
transform = transforms.Compose([transforms.ToTensor(),
|
80 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
81 |
+
dataset = get_dataset(**data_config, transforms=transform)
|
82 |
+
loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)
|
83 |
+
|
84 |
+
# Exception) In case of inpainting, we need to generate a mask
|
85 |
+
if measure_config['operator']['name'] == 'inpainting':
|
86 |
+
mask_gen = mask_generator(
|
87 |
+
**measure_config['mask_opt']
|
88 |
+
)
|
89 |
+
|
90 |
+
# Do Inference
|
91 |
+
for i, ref_img in enumerate(loader):
|
92 |
+
logger.info(f"Inference for image {i}")
|
93 |
+
fname = str(i).zfill(5) + '.png'
|
94 |
+
ref_img = ref_img.to(device)
|
95 |
+
|
96 |
+
# Exception) In case of inpainging,
|
97 |
+
if measure_config['operator'] ['name'] == 'inpainting':
|
98 |
+
mask = mask_gen(ref_img)
|
99 |
+
mask = mask[:, 0, :, :].unsqueeze(dim=0)
|
100 |
+
measurement_cond_fn = partial(cond_method.conditioning, mask=mask)
|
101 |
+
sample_fn = partial(sample_fn, measurement_cond_fn=measurement_cond_fn)
|
102 |
+
|
103 |
+
# Forward measurement model (Ax + n)
|
104 |
+
y = operator.forward(ref_img, mask=mask)
|
105 |
+
y_n = noiser(y)
|
106 |
+
|
107 |
+
else:
|
108 |
+
# Forward measurement model (Ax + n)
|
109 |
+
y = operator.forward(ref_img)
|
110 |
+
y_n = noiser(y)
|
111 |
+
|
112 |
+
# Sampling
|
113 |
+
x_start = torch.randn(ref_img.shape, device=device).requires_grad_()
|
114 |
+
sample = sample_fn(x_start=x_start, measurement=y_n, record=True, save_root=out_path)
|
115 |
+
|
116 |
+
plt.imsave(os.path.join(out_path, 'input', fname), clear_color(y_n))
|
117 |
+
plt.imsave(os.path.join(out_path, 'label', fname), clear_color(ref_img))
|
118 |
+
plt.imsave(os.path.join(out_path, 'recon', fname), clear_color(sample))
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
main()
|