ysharma HF staff commited on
Commit
d432951
1 Parent(s): de33098
Files changed (1) hide show
  1. ddim.py +203 -0
ddim.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class DDIMSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27
+ alphas_cumprod = self.model.alphas_cumprod
28
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30
+
31
+ self.register_buffer('betas', to_torch(self.model.betas))
32
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34
+
35
+ # calculations for diffusion q(x_t | x_{t-1}) and others
36
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41
+
42
+ # ddim sampling parameters
43
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44
+ ddim_timesteps=self.ddim_timesteps,
45
+ eta=ddim_eta,verbose=verbose)
46
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
47
+ self.register_buffer('ddim_alphas', ddim_alphas)
48
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54
+
55
+ @torch.no_grad()
56
+ def sample(self,
57
+ S,
58
+ batch_size,
59
+ shape,
60
+ conditioning=None,
61
+ callback=None,
62
+ normals_sequence=None,
63
+ img_callback=None,
64
+ quantize_x0=False,
65
+ eta=0.,
66
+ mask=None,
67
+ x0=None,
68
+ temperature=1.,
69
+ noise_dropout=0.,
70
+ score_corrector=None,
71
+ corrector_kwargs=None,
72
+ verbose=True,
73
+ x_T=None,
74
+ log_every_t=100,
75
+ unconditional_guidance_scale=1.,
76
+ unconditional_conditioning=None,
77
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
83
+ if cbs != batch_size:
84
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
85
+ else:
86
+ if conditioning.shape[0] != batch_size:
87
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
88
+
89
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
90
+ # sampling
91
+ C, H, W = shape
92
+ size = (batch_size, C, H, W)
93
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
94
+
95
+ samples, intermediates = self.ddim_sampling(conditioning, size,
96
+ callback=callback,
97
+ img_callback=img_callback,
98
+ quantize_denoised=quantize_x0,
99
+ mask=mask, x0=x0,
100
+ ddim_use_original_steps=False,
101
+ noise_dropout=noise_dropout,
102
+ temperature=temperature,
103
+ score_corrector=score_corrector,
104
+ corrector_kwargs=corrector_kwargs,
105
+ x_T=x_T,
106
+ log_every_t=log_every_t,
107
+ unconditional_guidance_scale=unconditional_guidance_scale,
108
+ unconditional_conditioning=unconditional_conditioning,
109
+ )
110
+ return samples, intermediates
111
+
112
+ @torch.no_grad()
113
+ def ddim_sampling(self, cond, shape,
114
+ x_T=None, ddim_use_original_steps=False,
115
+ callback=None, timesteps=None, quantize_denoised=False,
116
+ mask=None, x0=None, img_callback=None, log_every_t=100,
117
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
118
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
119
+ device = self.model.betas.device
120
+ b = shape[0]
121
+ if x_T is None:
122
+ img = torch.randn(shape, device=device)
123
+ else:
124
+ img = x_T
125
+
126
+ if timesteps is None:
127
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
128
+ elif timesteps is not None and not ddim_use_original_steps:
129
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
130
+ timesteps = self.ddim_timesteps[:subset_end]
131
+
132
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
133
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
134
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
135
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
136
+
137
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
138
+
139
+ for i, step in enumerate(iterator):
140
+ index = total_steps - i - 1
141
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
142
+
143
+ if mask is not None:
144
+ assert x0 is not None
145
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
146
+ img = img_orig * mask + (1. - mask) * img
147
+
148
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
149
+ quantize_denoised=quantize_denoised, temperature=temperature,
150
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
151
+ corrector_kwargs=corrector_kwargs,
152
+ unconditional_guidance_scale=unconditional_guidance_scale,
153
+ unconditional_conditioning=unconditional_conditioning)
154
+ img, pred_x0 = outs
155
+ if callback: callback(i)
156
+ if img_callback: img_callback(pred_x0, i)
157
+
158
+ if index % log_every_t == 0 or index == total_steps - 1:
159
+ intermediates['x_inter'].append(img)
160
+ intermediates['pred_x0'].append(pred_x0)
161
+
162
+ return img, intermediates
163
+
164
+ @torch.no_grad()
165
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
166
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
167
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
168
+ b, *_, device = *x.shape, x.device
169
+
170
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
171
+ e_t = self.model.apply_model(x, t, c)
172
+ else:
173
+ x_in = torch.cat([x] * 2)
174
+ t_in = torch.cat([t] * 2)
175
+ c_in = torch.cat([unconditional_conditioning, c])
176
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
177
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
178
+
179
+ if score_corrector is not None:
180
+ assert self.model.parameterization == "eps"
181
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
182
+
183
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
184
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
185
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
186
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
187
+ # select parameters corresponding to the currently considered timestep
188
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
189
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
190
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
191
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
192
+
193
+ # current prediction for x_0
194
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
195
+ if quantize_denoised:
196
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
197
+ # direction pointing to x_t
198
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
199
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
200
+ if noise_dropout > 0.:
201
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
203
+ return x_prev, pred_x0