drscotthawley commited on
Commit
0dc3eb6
1 Parent(s): b887586

adding needed files

Browse files
Files changed (2) hide show
  1. pom/v_diffusion.py +168 -0
  2. sample.py +2 -2
pom/v_diffusion.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # v-diffusion codes for DDPM inpainting. May not be compatible with k-diffusion.
2
+
3
+ # @SuspectT's inpainting codes, Feb 25 2024
4
+ # shared w/ me over Discord:
5
+ # "that's the v-diffusion inpainting with ddpm
6
+ # optimal settings were around 100 steps for the scheduler
7
+ # (ts refering to timesteps here) and resamples was 4"
8
+
9
+ import torch
10
+ from torch import nn
11
+ from typing import Callable
12
+ from tqdm import trange
13
+ import math
14
+ import sys
15
+
16
+ # from kcrowson/v-diffusion-pytorch
17
+ def t_to_alpha_sigma(t):
18
+ """Returns the scaling factors for the clean image and for the noise, given
19
+ a timestep."""
20
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
21
+
22
+
23
+
24
+ #class DDPM(SamplerBase):
25
+ class DDPM():
26
+
27
+ def __init__(self, model_fn: Callable = None):
28
+ super().__init__()
29
+
30
+ def _step(
31
+ self, model_fn: Callable, x_t: torch.Tensor, step: int,
32
+ t_now: torch.Tensor, t_next: torch.Tensor,
33
+ callback: Callable, model_args, **sampler_args ) -> torch.Tensor:
34
+
35
+ alpha_now, sigma_now = t_to_alpha_sigma(t_now) # Get alpha / sigma for current timestep.
36
+ alpha_next, sigma_next = t_to_alpha_sigma(t_next) # Get alpha / sigma for next timestep.
37
+
38
+ v_t = model_fn(x_t, t_now.expand(x_t.shape[0]), **model_args) # Expand t to match batch_size which corresponds to x_t.shape[0]
39
+
40
+ eps_t = x_t * sigma_now + v_t * alpha_now
41
+ pred_t = x_t * alpha_now - v_t * sigma_now
42
+
43
+ if callback is not None:
44
+ callback({'step': step, 'x': x_t, 't': t_now, 'pred': pred_t, 'eps': eps_t})
45
+
46
+ return (pred_t * alpha_next + eps_t * sigma_next)
47
+
48
+ def _sample( self, model_fn: Callable, x_t: torch.Tensor, ts: torch.Tensor,
49
+ callback: Callable, model_args, **sampler_args ) -> torch.Tensor:
50
+
51
+ print("Using DDPM Sampler.")
52
+ steps = ts.size(0)
53
+
54
+ use_tqdm = sampler_args.get('use_tqdm')
55
+ use_range = trange if (use_tqdm if (use_tqdm != None) else False) else range
56
+
57
+ for step in use_range(steps - 1):
58
+ x_t = self._step( model_fn, x_t, step, ts[step], ts[step + 1],
59
+ lambda kwargs: callback(**dict(kwargs, steps=steps)) if(callback != None) else None,
60
+ model_args )
61
+
62
+ return x_t
63
+
64
+
65
+ def _inpaint(self,
66
+ model_fn: Callable, audio_source: torch.Tensor, mask: torch.Tensor,
67
+ ts: torch.Tensor, resamples: int, callback: Callable, model_args, **sampler_args
68
+ ) -> torch.Tensor:
69
+ steps = ts.size(0)
70
+ batch_size = audio_source.size(0)
71
+ alphas, sigmas = t_to_alpha_sigma(ts)
72
+
73
+ # SHH: rescale audio_source to zero mean and unit variance
74
+ audio_source = (audio_source - audio_source.mean()) / audio_source.std()
75
+
76
+ x_t = audio_source
77
+
78
+ use_tqdm = sampler_args.get('use_tqdm')
79
+ use_range = trange if (use_tqdm if (use_tqdm != None) else False) else range
80
+
81
+ for step in use_range(steps - 1):
82
+ print("step, audio_source.min, audio_source.max, alphas[step], sigmas[step] = ", step, audio_source.min(), audio_source.max(), alphas[step], sigmas[step])
83
+ audio_source_noised = audio_source * alphas[step] + torch.randn_like(audio_source) * sigmas[step]
84
+ print("step, audio_source_noised.min, audio_source_noised.max = ", step, audio_source_noised.min(), audio_source_noised.max())
85
+ sigma_dt = torch.sqrt(sigmas[step] ** 2 - sigmas[step + 1] ** 2)
86
+
87
+ for re in range(resamples):
88
+
89
+ #x_t = audio_source_noised * mask + x_t * ~mask
90
+ x_t = audio_source_noised * mask + x_t * (1.0-mask)
91
+
92
+ # from ImageTransformerDenoiserModelV2:
93
+ # def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None):
94
+ #v_t = model_fn(x_t, ts[step].expand(batch_size), **model_args)
95
+ print("step, re, x_t.min, x_t.max , sigmas[step]= ", step, re, x_t.min(), x_t.max(), sigmas[step])
96
+ v_t = model_fn(x_t, sigmas[step].expand(batch_size), aug_cond=None, class_cond=None, mapping_cond=None)
97
+ print("step, re, v_t.min, v_t.max = ", step, re, v_t.min(), v_t.max())
98
+ if v_t.isnan().any():
99
+ print("v_t has NaNs.")
100
+ sys.exit(0)
101
+
102
+ eps_t = x_t * sigmas[step] + v_t * alphas[step]
103
+ pred_t = x_t * alphas[step] - v_t * sigmas[step]
104
+
105
+ if callback is not None:
106
+ callback({'steps': steps, 'step': step, 'x': x_t, 't': ts[step], 'pred': pred_t, 'eps': eps_t, 'res': re})
107
+
108
+ if(re < resamples - 1):
109
+ x_t = pred_t * alphas[step] + eps_t * sigmas[step + 1] + sigma_dt * torch.randn_like(x_t)
110
+ else:
111
+ x_t = pred_t * alphas[step + 1] + eps_t * sigmas[step + 1]
112
+
113
+ print("step, re, v_t.min, v_t.max, x_t.min, x_t.max = ", step, re, v_t.min(), v_t.max(), x_t.min(), x_t.max())
114
+
115
+ #sys.exit(0)
116
+
117
+ return (audio_source * mask + x_t * (1.0-mask))
118
+
119
+
120
+ def alpha_sigma_to_t(alpha, sigma):
121
+ """Returns a timestep, given the scaling factors for the clean image and for
122
+ the noise."""
123
+ return torch.atan2(sigma, alpha) / math.pi * 2
124
+
125
+ def log_snr_to_alpha_sigma(log_snr):
126
+ """Returns the scaling factors for the clean image and for the noise, given
127
+ the log SNR for a timestep."""
128
+ return log_snr.sigmoid().sqrt(), log_snr.neg().sigmoid().sqrt()
129
+
130
+ def get_ddpm_schedule(ddpm_t):
131
+ """Returns timesteps for the noise schedule from the DDPM paper."""
132
+ log_snr = -torch.special.expm1(1e-4 + 10 * ddpm_t**2).log()
133
+ alpha, sigma = log_snr_to_alpha_sigma(log_snr)
134
+ return alpha_sigma_to_t(alpha, sigma)
135
+
136
+
137
+ #class LogSchedule(SchedulerBase):
138
+ class LogSchedule():
139
+ def __init__(self, device:torch.device = None):
140
+ super().__init__(device)
141
+
142
+ def create(self, steps: int, first: float = 1, last: float = 0, device: torch.device = None, scheduler_args = {'min_log_snr': -10, 'max_log_snr': 10}) -> torch.Tensor:
143
+ ramp = torch.linspace(first, last, steps, device = device if (device != None) else self.device)
144
+ min_log_snr = scheduler_args.get('min_log_snr')
145
+ max_log_snr = scheduler_args.get('max_log_snr')
146
+ return self.get_log_schedule(
147
+ ramp,
148
+ min_log_snr if min_log_snr!=None else -10,
149
+ max_log_snr if max_log_snr!=None else 10,
150
+ )
151
+
152
+ def get_log_schedule(self, t, min_log_snr=-10, max_log_snr=10):
153
+ log_snr = t * (min_log_snr - max_log_snr) + max_log_snr
154
+ alpha = log_snr.sigmoid().sqrt()
155
+ sigma = log_snr.neg().sigmoid().sqrt()
156
+ return torch.atan2(sigma, alpha) / math.pi * 2 # this returns a timestep?
157
+
158
+
159
+ #class CrashSchedule(SchedulerBase):
160
+ class CrashSchedule():
161
+ def __init__(self, device:torch.device = None):
162
+ super().__init__(device)
163
+
164
+ def create(self, steps: int, first: float = 1, last: float = 0, device: torch.device = None, scheduler_args = None) -> torch.Tensor:
165
+ ramp = torch.linspace(first, last, steps, device = device if (device != None) else self.device)
166
+ sigma = torch.sin(ramp * math.pi / 2) ** 2
167
+ alpha = (1 - sigma**2) ** 0.5
168
+ return torch.atan2(sigma, alpha) / math.pi * 2 # this returns a timestep?
sample.py CHANGED
@@ -16,9 +16,9 @@ from torchvision import transforms
16
 
17
  import k_diffusion as K
18
 
19
- from control_toys.v_diffusion import DDPM, LogSchedule, CrashSchedule
20
  #CHORD_BORDER = 8 # chord border size in pixels
21
- from control_toys.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
22
 
23
 
24
  # ---- my mangled sampler that includes repaint
 
16
 
17
  import k_diffusion as K
18
 
19
+ from pom.v_diffusion import DDPM, LogSchedule, CrashSchedule
20
  #CHORD_BORDER = 8 # chord border size in pixels
21
+ from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
22
 
23
 
24
  # ---- my mangled sampler that includes repaint