raoulduke420's picture
Upload folder using huggingface_hub
ef9fd1f
import torch
import ldm.models.diffusion.ddpm
from modules import shared
class Scheduler:
""" Proportional Noise Step Scheduler"""
def __init__(self, cycle_step=128, repeat=True):
self.disabled = True
self.cycle_step = int(cycle_step)
self.repeat = repeat
self.run_assertion()
def __call__(self, value, step):
if self.disabled:
return value
if self.repeat:
step %= self.cycle_step
return max(1, int(value * step / self.cycle_step))
else:
return value if step >= self.cycle_step else max(1, int(value * step / self.cycle_step))
def run_assertion(self):
assert type(self.cycle_step) is int
assert type(self.repeat) is bool
assert not self.repeat or self.cycle_step > 0
def set(self, cycle_step=-1, repeat=-1, disabled=True):
self.disabled = disabled
if cycle_step >= 0:
self.cycle_step = int(cycle_step)
if repeat != -1:
self.repeat = repeat
self.run_assertion()
training_scheduler = Scheduler(cycle_step=-1, repeat=False)
def get_current(value, step=None):
if step is None:
if hasattr(shared, 'accessible_hypernetwork'):
hypernetwork = shared.accessible_hypernetwork
else:
return value
if hasattr(hypernetwork, 'step') and hypernetwork.training and hypernetwork.step is not None:
return training_scheduler(value, hypernetwork.step)
return value
return max(1, training_scheduler(value, step))
def set_scheduler(cycle_step, repeat, enabled=False):
global training_scheduler
training_scheduler.set(cycle_step, repeat, not enabled)
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, get_current(self.num_timesteps), (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
ldm.models.diffusion.ddpm.LatentDiffusion.forward = forward