# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from collections import OrderedDict

import torch
from tqdm import trange

from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
                                            NOISE_SCHEDULERS)
from scepter.modules.utils.config import Config, dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS


@DIFFUSIONS.register_class()
class ACEDiffusion(object):
    para_dict = {
        'NOISE_SCHEDULER': {},
        'SAMPLER_SCHEDULER': {},
        'MIN_SNR_GAMMA': {
            'value': None,
            'description': 'The minimum SNR gamma value for the loss function.'
        },
        'PREDICTION_TYPE': {
            'value': 'eps',
            'description': 'The type of prediction to use for the loss function.'
        }
    }

    def __init__(self, cfg, logger=None):
        super(ACEDiffusion, self).__init__()
        self.logger = logger
        self.cfg = cfg
        self.init_params()

    def init_params(self):
        self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
        self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
        self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
                                                      logger=self.logger)
        # Fall back to the training noise scheduler when no dedicated
        # sampling scheduler is configured.
        self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
            'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
                                                        logger=self.logger)
        self.num_timesteps = self.noise_scheduler.num_timesteps
        # Dump schedule visualizations once, from the rank-0 process only.
        if self.cfg.have('WORK_DIR') and we.rank == 0:
            schedule_visualization = os.path.join(self.cfg.WORK_DIR,
                                                  'noise_schedule.png')
            with FS.put_to(schedule_visualization) as local_path:
                self.noise_scheduler.plot_noise_sampling_map(local_path)
            schedule_visualization = os.path.join(self.cfg.WORK_DIR,
                                                  'sampler_schedule.png')
            with FS.put_to(schedule_visualization) as local_path:
                self.sampler_scheduler.plot_noise_sampling_map(local_path)

    def sample(self,
               noise,
               model,
               model_kwargs={},
               steps=20,
               sampler=None,
               use_dynamic_cfg=False,
               guide_scale=None,
               guide_rescale=None,
               show_progress=False,
               return_intermediate=None,
               intermediate_callback=None,
               **kwargs):
        assert isinstance(steps, (int, torch.LongTensor))
        assert return_intermediate in (None, 'x0', 'xt')
        assert isinstance(sampler, (str, dict, Config))
        intermediates = []

        def callback_fn(x_t, t, sigma=None, alpha=None):
            timestamp = t
            t = t.repeat(len(x_t)).round().long().to(x_t.device)
            sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
            alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
            if guide_scale is None or guide_scale == 1.0:
                out = model(x=x_t, t=t, **model_kwargs)
            else:
                if use_dynamic_cfg:
                    # Dynamic CFG: ramp the guidance scale up over the course
                    # of sampling with a cosine schedule.
                    guidance_scale = 1 + guide_scale * (
                        (1 - math.cos(math.pi * (
                            (steps - timestamp.item()) / steps)**5.0)) / 2)
                else:
                    guidance_scale = guide_scale
                # Classifier-free guidance: model_kwargs holds a pair of
                # [conditional, unconditional] kwargs dicts.
                y_out = model(x=x_t, t=t, **model_kwargs[0])
                u_out = model(x=x_t, t=t, **model_kwargs[1])
                out = u_out + guidance_scale * (y_out - u_out)

                # Rescale the guided output towards the std of the
                # conditional branch to counter CFG over-saturation.
                if guide_rescale is not None and guide_rescale > 0.0:
                    ratio = (y_out.flatten(1).std(dim=1) /
                             (out.flatten(1).std(dim=1) + 1e-12)).view(
                                 (-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

            # Convert the network output into an x_0 prediction according to
            # the configured parameterization.
            if self.prediction_type == 'x0':
                x0 = out
            elif self.prediction_type == 'eps':
                x0 = (x_t - sigma * out) / alpha
            elif self.prediction_type == 'v':
                x0 = alpha * x_t - sigma * out
            else:
                raise NotImplementedError(
                    f'prediction_type {self.prediction_type} not implemented')
            return x0

        sampler_ins = self.get_sampler(sampler)
        # this is ignored for schnell
        # (Note: 'preprare_sampler' is the method name as spelled in the
        # sampler classes this registry builds.)
        sampler_output = sampler_ins.preprare_sampler(
            noise,
            steps=steps,
            prediction_type=self.prediction_type,
            scheduler_ins=self.sampler_scheduler,
            callback_fn=callback_fn)

        pbar = trange(steps, disable=not show_progress)
        for _ in pbar:
            pbar.set_description(sampler_output.msg)
            sampler_output = sampler_ins.step(sampler_output)
            if return_intermediate == 'x0':
                intermediates.append(sampler_output.x_0)
            elif return_intermediate == 'xt':
                intermediates.append(sampler_output.x_t)
            if intermediate_callback is not None:
                intermediate_callback(intermediates[-1])
        return (sampler_output.x_0, intermediates
                ) if return_intermediate is not None else sampler_output.x_0

    def loss(self,
             x_0,
             model,
             model_kwargs={},
             reduction='mean',
             noise=None,
             **kwargs):
        # Use the noise scheduler to produce x_t = alpha * x_0 + sigma * noise
        # at a sampled timestep t.
        if noise is None:
            noise = torch.randn_like(x_0)
        schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
        x_t, t, sigma, alpha = (schedule_output.x_t, schedule_output.t,
                                schedule_output.sigma, schedule_output.alpha)
        out = model(x=x_t, t=t, **model_kwargs)

        # MSE loss against the target implied by the prediction type.
        target = {
            'eps': noise,
            'x0': x_0,
            'v': alpha * noise - sigma * x_0
        }[self.prediction_type]
        loss = (out - target).pow(2)
        if reduction == 'mean':
            loss = loss.flatten(1).mean(dim=1)
            if self.min_snr_gamma is not None:
                # Min-SNR loss weighting: clamp the per-timestep SNR at gamma
                # so that low-noise timesteps do not dominate training.
                alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
                sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
                snrs = (alphas / sigmas).clamp(min=1e-20)
                min_snrs = snrs.clamp(max=self.min_snr_gamma)
                weights = min_snrs / snrs
            else:
                weights = 1
            loss = loss * weights
        return loss

    def get_sampler(self, sampler):
        if isinstance(sampler, str):
            if sampler not in DIFFUSION_SAMPLERS.class_map:
                msg = (f'{sampler} not in the defined samplers list '
                       f'{DIFFUSION_SAMPLERS.class_map.keys()}')
                if self.logger is not None:
                    self.logger.info(msg)
                else:
                    print(msg)
                return None
            sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
            sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
                                                   logger=self.logger)
        elif isinstance(sampler, (Config, dict, OrderedDict)):
            if isinstance(sampler, (dict, OrderedDict)):
                sampler = Config(cfg_dict={
                    k.upper(): v
                    for k, v in dict(sampler).items()
                },
                                 load=False)
            sampler_ins = DIFFUSION_SAMPLERS.build(sampler,
                                                   logger=self.logger)
        else:
            raise NotImplementedError
        return sampler_ins

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}' + ' ' + super().__repr__()

    @staticmethod
    def get_config_template():
        return dict_to_yaml('DIFFUSIONS',
                            __class__.__name__,
                            ACEDiffusion.para_dict,
                            set_name=True)
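
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). A minimal example
# of driving ACEDiffusion for training and guided sampling. The scheduler name
# 'LinearScheduler' and the sampler name 'ddim' are placeholders; substitute
# names actually registered in NOISE_SCHEDULERS / DIFFUSION_SAMPLERS in your
# scepter installation. `model` is assumed to be a callable taking
# (x, t, **kwargs), as used by callback_fn above.
#
#   cfg = Config(cfg_dict={
#       'NOISE_SCHEDULER': {'NAME': 'LinearScheduler'},  # placeholder name
#       'PREDICTION_TYPE': 'eps',
#       'MIN_SNR_GAMMA': 5.0,
#   }, load=False)
#   diffusion = ACEDiffusion(cfg)
#
#   # Training: loss() returns a per-sample loss under reduction='mean'
#   # (optionally min-SNR weighted); reduce it to a scalar for backprop.
#   x_0 = torch.randn(4, 4, 32, 32)
#   loss = diffusion.loss(x_0, model, model_kwargs={'cond': cond}).mean()
#
#   # Guided sampling: with guide_scale set, model_kwargs is a pair of
#   # [conditional, unconditional] kwargs dicts for classifier-free guidance.
#   noise = torch.randn(1, 4, 32, 32)
#   image = diffusion.sample(noise, model,
#                            sampler='ddim',  # placeholder name
#                            model_kwargs=[{'cond': c}, {'cond': uc}],
#                            steps=20, guide_scale=4.5, show_progress=True)
# ---------------------------------------------------------------------------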