# v-diffusion code for DDPM inpainting. May not be compatible with k-diffusion.

# @SuspectT's inpainting code, Feb 25 2024,
# shared w/ me over Discord:
# "that's the v-diffusion inpainting with ddpm
# optimal settings were around 100 steps for the scheduler
# (ts referring to timesteps here) and resamples was 4"
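# (A hedged end-to-end usage sketch with these settings appears at the bottom of this file.)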

import torch
from typing import Callable
from tqdm import trange
import math
import sys

# from crowsonkb/v-diffusion-pytorch
def t_to_alpha_sigma(t):
    """Returns the scaling factors for the clean image and for the noise, given
    a timestep."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
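# Note (my annotation): cos^2 + sin^2 = 1, so alpha**2 + sigma**2 == 1 at every t;
# t=0 is pure signal (alpha=1, sigma=0) and t=1 is pure noise (alpha=0, sigma=1).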



#class DDPM(SamplerBase):
class DDPM():

    def __init__(self, model_fn: Callable = None):
        super().__init__()
        self.model_fn = model_fn  # stored for reference; the methods below take model_fn explicitly

    def _step(
        self, model_fn: Callable, x_t: torch.Tensor, step: int,
        t_now: torch.Tensor, t_next: torch.Tensor,
        callback: Callable, model_args, **sampler_args) -> torch.Tensor:

        alpha_now, sigma_now = t_to_alpha_sigma(t_now)    # alpha/sigma at the current timestep
        alpha_next, sigma_next = t_to_alpha_sigma(t_next) # alpha/sigma at the next timestep

        # Expand t to the batch size (x_t.shape[0]) before calling the model.
        v_t = model_fn(x_t, t_now.expand(x_t.shape[0]), **model_args)

        # Convert the v prediction into noise (eps) and denoised (pred) estimates.
        eps_t = x_t * sigma_now + v_t * alpha_now
        pred_t = x_t * alpha_now - v_t * sigma_now

        if callback is not None:
            callback({'step': step, 'x': x_t, 't': t_now, 'pred': pred_t, 'eps': eps_t})

        # Recombine at the next timestep's alpha/sigma (a deterministic DDIM-style update).
        return pred_t * alpha_next + eps_t * sigma_next
    
    def _sample(self, model_fn: Callable, x_t: torch.Tensor, ts: torch.Tensor,
        callback: Callable, model_args, **sampler_args) -> torch.Tensor:

        print("Using DDPM Sampler.")
        steps = ts.size(0)

        use_range = trange if sampler_args.get('use_tqdm') else range

        # Wrap the user callback so it also receives the total step count.
        step_callback = (lambda kwargs: callback(**dict(kwargs, steps=steps))) if callback is not None else None

        for step in use_range(steps - 1):
            x_t = self._step(model_fn, x_t, step, ts[step], ts[step + 1],
                step_callback, model_args)

        return x_t
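
    # Hedged usage sketch (illustrative names, not from the original file):
    #   ts = CrashSchedule(torch.device('cpu')).create(steps=100)  # timesteps 1 -> 0
    #   x = torch.randn(1, 2, 2**16)                               # start from pure noise at t=1
    #   audio = DDPM()._sample(model_fn, x, ts, callback=None, model_args={}, use_tqdm=True)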
    
      
    def _inpaint(self,
        model_fn: Callable, audio_source: torch.Tensor, mask: torch.Tensor,
        ts: torch.Tensor, resamples: int, callback: Callable, model_args, **sampler_args
        ) -> torch.Tensor:
        steps = ts.size(0)
        batch_size = audio_source.size(0)
        alphas, sigmas = t_to_alpha_sigma(ts)

        # SHH: rescale audio_source to zero mean and unit variance
        audio_source = (audio_source - audio_source.mean()) / audio_source.std()

        x_t = audio_source

        use_range = trange if sampler_args.get('use_tqdm') else range

        for step in use_range(steps - 1):
            print(f"step={step}, audio_source range=[{audio_source.min()}, {audio_source.max()}], alpha={alphas[step]}, sigma={sigmas[step]}")
            # Noise the known (masked-in) region to the current timestep's noise level.
            audio_source_noised = audio_source * alphas[step] + torch.randn_like(audio_source) * sigmas[step]
            print(f"step={step}, audio_source_noised range=[{audio_source_noised.min()}, {audio_source_noised.max()}]")
            # Std of the fresh noise needed to renoise from level step+1 back to level step.
            sigma_dt = torch.sqrt(sigmas[step] ** 2 - sigmas[step + 1] ** 2)

            for re in range(resamples):

                # Paste the noised source into the known region (mask==1) and keep the
                # generated content elsewhere; (1.0 - mask) instead of ~mask so float
                # masks work too.
                x_t = audio_source_noised * mask + x_t * (1.0 - mask)

                # NOTE: unlike _step, this conditions the model on sigma rather than t,
                # to match ImageTransformerDenoiserModelV2:
                #     def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None):
                # This mismatch may be the source of the k-diffusion incompatibility
                # flagged at the top of the file.
                #v_t = model_fn(x_t, ts[step].expand(batch_size), **model_args)
                print(f"step={step}, re={re}, x_t range=[{x_t.min()}, {x_t.max()}], sigma={sigmas[step]}")
                v_t = model_fn(x_t, sigmas[step].expand(batch_size), aug_cond=None, class_cond=None, mapping_cond=None)
                print(f"step={step}, re={re}, v_t range=[{v_t.min()}, {v_t.max()}]")
                if v_t.isnan().any():
                    sys.exit("v_t has NaNs.")  # exit with a nonzero status on failure

                # Convert the v prediction into noise (eps) and denoised (pred) estimates.
                eps_t = x_t * sigmas[step] + v_t * alphas[step]
                pred_t = x_t * alphas[step] - v_t * sigmas[step]

                if callback is not None:
                    callback({'steps': steps, 'step': step, 'x': x_t, 't': ts[step], 'pred': pred_t, 'eps': eps_t, 'res': re})

                if re < resamples - 1:
                    # Resample: renoise back toward level `step` (RePaint-style), so the total
                    # noise variance is sigmas[step+1]**2 + sigma_dt**2 == sigmas[step]**2.
                    x_t = pred_t * alphas[step] + eps_t * sigmas[step + 1] + sigma_dt * torch.randn_like(x_t)
                else:
                    # Final pass at this step: move to the next timestep deterministically.
                    x_t = pred_t * alphas[step + 1] + eps_t * sigmas[step + 1]

                print(f"step={step}, re={re}, v_t range=[{v_t.min()}, {v_t.max()}], x_t range=[{x_t.min()}, {x_t.max()}]")

        # Paste the clean source back into the known region one last time.
        return audio_source * mask + x_t * (1.0 - mask)


def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2  

def log_snr_to_alpha_sigma(log_snr):
    """Returns the scaling factors for the clean image and for the noise, given
    the log SNR for a timestep."""
    return log_snr.sigmoid().sqrt(), log_snr.neg().sigmoid().sqrt()

def get_ddpm_schedule(ddpm_t):
    """Returns timesteps for the noise schedule from the DDPM paper."""
    log_snr = -torch.special.expm1(1e-4 + 10 * ddpm_t**2).log()
    alpha, sigma = log_snr_to_alpha_sigma(log_snr)
    return alpha_sigma_to_t(alpha, sigma)
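
# Minimal sketch (the helper name `make_ddpm_ts` is mine, not from the original file):
# build a descending ts tensor for _sample / _inpaint from the DDPM schedule above.
def make_ddpm_ts(steps: int, device: torch.device = None) -> torch.Tensor:
    ramp = torch.linspace(1, 0, steps, device=device)  # 1 -> 0, matching the samplers' loop direction
    return get_ddpm_schedule(ramp)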


#class LogSchedule(SchedulerBase):
class LogSchedule():
    def __init__(self, device: torch.device = None):
        # SchedulerBase is commented out above, so store the device directly;
        # super().__init__(device) would fail against object.__init__.
        self.device = device

    def create(self, steps: int, first: float = 1, last: float = 0, device: torch.device = None, scheduler_args=None) -> torch.Tensor:
        scheduler_args = scheduler_args if scheduler_args is not None else {}
        ramp = torch.linspace(first, last, steps, device=device if device is not None else self.device)
        return self.get_log_schedule(
            ramp,
            scheduler_args.get('min_log_snr', -10),
            scheduler_args.get('max_log_snr', 10),
        )

    def get_log_schedule(self, t, min_log_snr=-10, max_log_snr=10):
        log_snr = t * (min_log_snr - max_log_snr) + max_log_snr
        alpha = log_snr.sigmoid().sqrt()
        sigma = log_snr.neg().sigmoid().sqrt()
        # Yes, this returns a timestep: it is alpha_sigma_to_t applied to (alpha, sigma).
        return torch.atan2(sigma, alpha) / math.pi * 2


#class CrashSchedule(SchedulerBase):
class CrashSchedule():
    def __init__(self, device: torch.device = None):
        # Same as LogSchedule: store the device directly since SchedulerBase is absent.
        self.device = device

    def create(self, steps: int, first: float = 1, last: float = 0, device: torch.device = None, scheduler_args=None) -> torch.Tensor:
        ramp = torch.linspace(first, last, steps, device=device if device is not None else self.device)
        sigma = torch.sin(ramp * math.pi / 2) ** 2
        alpha = (1 - sigma**2) ** 0.5
        # Yes, this returns a timestep: it is alpha_sigma_to_t applied to (alpha, sigma).
        return torch.atan2(sigma, alpha) / math.pi * 2
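

# ---------------------------------------------------------------------------
# Hedged end-to-end usage sketch (mine, not from the original author), wiring
# the pieces above together with the settings quoted at the top of the file
# (~100 scheduler steps, resamples=4). `dummy_model_fn` is a hypothetical
# stand-in for a trained v-objective model.
if __name__ == "__main__":
    def dummy_model_fn(x, sigma, **kwargs):
        # Placeholder: a real model would predict v given (x, sigma/t).
        return torch.zeros_like(x)

    device = torch.device('cpu')
    ts = CrashSchedule(device).create(steps=100)     # timesteps 1 -> 0
    audio = torch.randn(1, 2, 2**16, device=device)  # hypothetical stereo clip
    mask = torch.zeros_like(audio)
    mask[..., : audio.shape[-1] // 2] = 1.0          # mask==1 marks the region to keep
    out = DDPM()._inpaint(dummy_model_fn, audio, mask, ts,
                          resamples=4, callback=None, model_args={})
    print("inpainted output shape:", out.shape)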