# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from collections import OrderedDict
import torch
from tqdm import trange
from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
NOISE_SCHEDULERS)
from scepter.modules.utils.config import Config, dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS


@DIFFUSIONS.register_class()
class ACEDiffusion(object):
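    """Diffusion wrapper used by ACE.

    Pairs a training noise scheduler with an (optionally separate) sampler
    scheduler, supports 'eps', 'x0' and 'v' prediction targets,
    classifier-free guidance with optional dynamic scaling and rescale,
    and min-SNR-gamma loss weighting.
    """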
para_dict = {
'NOISE_SCHEDULER': {},
'SAMPLER_SCHEDULER': {},
'MIN_SNR_GAMMA': {
'value': None,
'description': 'The minimum SNR gamma value for the loss function.'
},
'PREDICTION_TYPE': {
'value': 'eps',
'description':
'The type of prediction to use for the loss function.'
}
}

    def __init__(self, cfg, logger=None):
super(ACEDiffusion, self).__init__()
self.logger = logger
self.cfg = cfg
self.init_params()

    def init_params(self):
self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
logger=self.logger)
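        # The sampler may run on its own schedule; fall back to the training
        # noise schedule when SAMPLER_SCHEDULER is not configured.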
self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
logger=self.logger)
self.num_timesteps = self.noise_scheduler.num_timesteps
if self.cfg.have('WORK_DIR') and we.rank == 0:
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
'noise_schedule.png')
with FS.put_to(schedule_visualization) as local_path:
self.noise_scheduler.plot_noise_sampling_map(local_path)
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
'sampler_schedule.png')
with FS.put_to(schedule_visualization) as local_path:
self.sampler_scheduler.plot_noise_sampling_map(local_path)

    def sample(self,
noise,
model,
model_kwargs={},
steps=20,
sampler=None,
use_dynamic_cfg=False,
guide_scale=None,
guide_rescale=None,
show_progress=False,
return_intermediate=None,
intermediate_callback=None,
**kwargs):
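        """Iteratively denoise ``noise`` with the given sampler.

        ``model_kwargs`` is a single dict of extra model inputs when
        sampling without guidance, or a pair of dicts (conditional,
        unconditional) when ``guide_scale`` is set. Returns the final
        ``x_0``, plus the list of intermediates when ``return_intermediate``
        is 'x0' or 'xt'.
        """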
assert isinstance(steps, (int, torch.LongTensor))
assert return_intermediate in (None, 'x0', 'xt')
assert isinstance(sampler, (str, dict, Config))
intermediates = []

        def callback_fn(x_t, t, sigma=None, alpha=None):
timestamp = t
t = t.repeat(len(x_t)).round().long().to(x_t.device)
sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
if guide_scale is None or guide_scale == 1.0:
out = model(x=x_t, t=t, **model_kwargs)
else:
if use_dynamic_cfg:
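                    # Dynamic CFG: ramp the guidance scale along the
                    # sampling trajectory with a cosine curve instead of
                    # keeping it constant.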
guidance_scale = 1 + guide_scale * (
(1 - math.cos(math.pi * (
(steps - timestamp.item()) / steps)**5.0)) / 2)
else:
guidance_scale = guide_scale
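                # Classifier-free guidance: model_kwargs is expected to be
                # a pair of kwarg dicts, [0] for the conditional branch and
                # [1] for the unconditional branch.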
y_out = model(x=x_t, t=t, **model_kwargs[0])
u_out = model(x=x_t, t=t, **model_kwargs[1])
out = u_out + guidance_scale * (y_out - u_out)
if guide_rescale is not None and guide_rescale > 0.0:
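                    # Guidance rescale (Lin et al., 2023): pull the std of
                    # the guided output back toward the conditional branch
                    # to counter over-saturation at large guidance scales.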
ratio = (
y_out.flatten(1).std(dim=1) /
(out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
(y_out.ndim - 1))
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
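            # Convert the raw prediction to a denoised x0 estimate, using
            # x_t = alpha * x_0 + sigma * eps:
            #   'x0'  -> the model predicts x_0 directly
            #   'eps' -> x_0 = (x_t - sigma * eps) / alpha
            #   'v'   -> with v = alpha * eps - sigma * x_0,
            #            x_0 = alpha * x_t - sigma * v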
if self.prediction_type == 'x0':
x0 = out
elif self.prediction_type == 'eps':
x0 = (x_t - sigma * out) / alpha
elif self.prediction_type == 'v':
x0 = alpha * x_t - sigma * out
else:
raise NotImplementedError(
f'prediction_type {self.prediction_type} not implemented')
return x0

        sampler_ins = self.get_sampler(sampler)
# this is ignored for schnell
sampler_output = sampler_ins.preprare_sampler(
noise,
steps=steps,
prediction_type=self.prediction_type,
scheduler_ins=self.sampler_scheduler,
callback_fn=callback_fn)
        pbar = trange(steps, disable=not show_progress)
        for _ in pbar:
            pbar.set_description(sampler_output.msg)
            sampler_output = sampler_ins.step(sampler_output)
            if return_intermediate == 'x0':
                intermediates.append(sampler_output.x_0)
            elif return_intermediate == 'xt':
                intermediates.append(sampler_output.x_t)
            if intermediate_callback is not None and intermediates:
                intermediate_callback(intermediates[-1])
        if return_intermediate is not None:
            return sampler_output.x_0, intermediates
        return sampler_output.x_0

    def loss(self,
x_0,
model,
model_kwargs={},
reduction='mean',
noise=None,
**kwargs):
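        """Compute the diffusion training loss for a clean batch ``x_0``.

        Noise is added via the noise scheduler, the model prediction is
        compared (MSE) against the target implied by
        ``self.prediction_type``, and the per-sample loss is optionally
        re-weighted by min-SNR-gamma.
        """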
# use noise scheduler to add noise
if noise is None:
noise = torch.randn_like(x_0)
schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
        x_t, t = schedule_output.x_t, schedule_output.t
        sigma, alpha = schedule_output.sigma, schedule_output.alpha
out = model(x=x_t, t=t, **model_kwargs)
# mse loss
target = {
'eps': noise,
'x0': x_0,
'v': alpha * noise - sigma * x_0
}[self.prediction_type]
loss = (out - target).pow(2)
if reduction == 'mean':
loss = loss.flatten(1).mean(dim=1)
if self.min_snr_gamma is not None:
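            # Min-SNR-gamma weighting (Hang et al., 2023): clamp each
            # timestep's SNR at gamma so that easy, high-SNR timesteps do
            # not dominate the training loss.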
alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
snrs = (alphas / sigmas).clamp(min=1e-20)
min_snrs = snrs.clamp(max=self.min_snr_gamma)
weights = min_snrs / snrs
else:
weights = 1
loss = loss * weights
return loss

    def get_sampler(self, sampler):
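        """Instantiate a sampler from a registered name or a config dict."""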
        if isinstance(sampler, str):
            if sampler not in DIFFUSION_SAMPLERS.class_map:
                msg = (f'{sampler} not in the defined samplers list '
                       f'{DIFFUSION_SAMPLERS.class_map.keys()}')
                if self.logger is not None:
                    self.logger.info(msg)
                else:
                    print(msg)
                return None
sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
logger=self.logger)
elif isinstance(sampler, (Config, dict, OrderedDict)):
if isinstance(sampler, (dict, OrderedDict)):
sampler = Config(
cfg_dict={k.upper(): v
for k, v in dict(sampler).items()},
load=False)
sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
else:
raise NotImplementedError
return sampler_ins

    def __repr__(self) -> str:
return f'{self.__class__.__name__}' + ' ' + super().__repr__()

    @staticmethod
def get_config_template():
return dict_to_yaml('DIFFUSIONS',
__class__.__name__,
ACEDiffusion.para_dict,
set_name=True)
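

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. Every NAME
    # below is a hypothetical placeholder: substitute a scheduler and a
    # sampler actually registered in NOISE_SCHEDULERS / DIFFUSION_SAMPLERS,
    # and a real denoising network in place of `toy_model`.
    cfg = Config(
        cfg_dict={
            'NOISE_SCHEDULER': {
                'NAME': 'MyNoiseScheduler'  # hypothetical registry name
            },
            'PREDICTION_TYPE': 'eps',
        },
        load=False)
    diffusion = ACEDiffusion(cfg)

    def toy_model(x, t, **kwargs):
        # Stand-in denoiser that predicts zero noise everywhere.
        return torch.zeros_like(x)

    noise = torch.randn(1, 4, 32, 32)
    x_0 = diffusion.sample(noise,
                           toy_model,
                           steps=20,
                           sampler='my_sampler')  # hypothetical name
    print(x_0.shape)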