# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from collections import OrderedDict

import torch
from tqdm import trange

from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
                                            NOISE_SCHEDULERS)
from scepter.modules.utils.config import Config, dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS


class ACEDiffusion(object):
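    """Diffusion wrapper pairing a training loss with an iterative sampler.

    A noise scheduler (and optionally a separate sampler scheduler) is built
    from the config. The class supports 'eps', 'x0' and 'v' prediction
    targets, optional min-SNR-gamma loss weighting, and classifier-free
    guidance with an optional dynamic scale and guidance rescaling at
    sampling time.
    """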
    para_dict = {
        'NOISE_SCHEDULER': {},
        'SAMPLER_SCHEDULER': {},
        'MIN_SNR_GAMMA': {
            'value': None,
            'description': 'The minimum SNR gamma value for the loss function.'
        },
        'PREDICTION_TYPE': {
            'value': 'eps',
            'description': 'The type of prediction to use for the loss function.'
        }
    }

    def __init__(self, cfg, logger=None):
        super(ACEDiffusion, self).__init__()
        self.logger = logger
        self.cfg = cfg
        self.init_params()

    def init_params(self):
        self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
        self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
        self.noise_scheduler = NOISE_SCHEDULERS.build(
            self.cfg.NOISE_SCHEDULER, logger=self.logger)
        self.sampler_scheduler = NOISE_SCHEDULERS.build(
            self.cfg.get('SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
            logger=self.logger)
        self.num_timesteps = self.noise_scheduler.num_timesteps
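        # Rank-0 only: dump visualizations of both schedules to WORK_DIR.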
        if self.cfg.have('WORK_DIR') and we.rank == 0:
            schedule_visualization = os.path.join(self.cfg.WORK_DIR,
                                                  'noise_schedule.png')
            with FS.put_to(schedule_visualization) as local_path:
                self.noise_scheduler.plot_noise_sampling_map(local_path)
            schedule_visualization = os.path.join(self.cfg.WORK_DIR,
                                                  'sampler_schedule.png')
            with FS.put_to(schedule_visualization) as local_path:
                self.sampler_scheduler.plot_noise_sampling_map(local_path)

    def sample(self,
               noise,
               model,
               model_kwargs={},
               steps=20,
               sampler=None,
               use_dynamic_cfg=False,
               guide_scale=None,
               guide_rescale=None,
               show_progress=False,
               return_intermediate=None,
               intermediate_callback=None,
               **kwargs):
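        """Iteratively denoise `noise` into a sample with the chosen sampler.

        When `guide_scale` is set (and != 1.0), `model_kwargs` is expected to
        be a pair of kwargs dicts (conditional first, unconditional second)
        and classifier-free guidance is applied. `use_dynamic_cfg` ramps the
        guidance scale over the steps with a cosine schedule, and
        `guide_rescale` blends the guided output's per-sample std back toward
        the conditional branch's. `return_intermediate` may be None, 'x0' or
        'xt' to additionally collect intermediate results.
        """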
        assert isinstance(steps, (int, torch.LongTensor))
        assert return_intermediate in (None, 'x0', 'xt')
        assert isinstance(sampler, (str, dict, Config))
        intermediates = []
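
        # `callback_fn` runs the model at one timestep and converts its raw
        # prediction into an x0 estimate, applying classifier-free guidance
        # (and optional rescaling) along the way.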
        def callback_fn(x_t, t, sigma=None, alpha=None):
            timestamp = t
            t = t.repeat(len(x_t)).round().long().to(x_t.device)
            sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
            alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
            if guide_scale is None or guide_scale == 1.0:
                out = model(x=x_t, t=t, **model_kwargs)
            else:
                if use_dynamic_cfg:
                    guidance_scale = 1 + guide_scale * (
                        (1 - math.cos(math.pi * (
                            (steps - timestamp.item()) / steps)**5.0)) / 2)
                else:
                    guidance_scale = guide_scale
                y_out = model(x=x_t, t=t, **model_kwargs[0])
                u_out = model(x=x_t, t=t, **model_kwargs[1])
                out = u_out + guidance_scale * (y_out - u_out)
                if guide_rescale is not None and guide_rescale > 0.0:
                    ratio = (y_out.flatten(1).std(dim=1) /
                             (out.flatten(1).std(dim=1) + 1e-12)).view(
                                 (-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
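            # Convert the raw prediction to an x0 estimate according to the
            # parameterization: eps -> (x_t - sigma * eps) / alpha,
            # v -> alpha * x_t - sigma * v.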
            if self.prediction_type == 'x0':
                x0 = out
            elif self.prediction_type == 'eps':
                x0 = (x_t - sigma * out) / alpha
            elif self.prediction_type == 'v':
                x0 = alpha * x_t - sigma * out
            else:
                raise NotImplementedError(
                    f'prediction_type {self.prediction_type} not implemented')
            return x0

        sampler_ins = self.get_sampler(sampler)
        # this is ignored for schnell
        # (note: 'preprare_sampler' keeps the sampler API's own spelling)
        sampler_output = sampler_ins.preprare_sampler(
            noise,
            steps=steps,
            prediction_type=self.prediction_type,
            scheduler_ins=self.sampler_scheduler,
            callback_fn=callback_fn)
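        # Advance the sampler one step at a time, surfacing its status
        # message on the progress bar.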
        pbar = trange(steps, disable=not show_progress)
        for _ in pbar:
            pbar.set_description(sampler_output.msg)
            sampler_output = sampler_ins.step(sampler_output)
            if return_intermediate == 'x0':
                intermediates.append(sampler_output.x_0)
            elif return_intermediate == 'xt':
                intermediates.append(sampler_output.x_t)
            if intermediate_callback is not None:
                intermediate_callback(intermediates[-1])
        return (sampler_output.x_0, intermediates
                ) if return_intermediate is not None else sampler_output.x_0
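
    # A minimal usage sketch (illustrative only: the scheduler and sampler
    # names and the `unet` callable below are assumptions, not part of this
    # module):
    #
    #   diffusion = ACEDiffusion(
    #       Config(cfg_dict={'NOISE_SCHEDULER': {'NAME': 'ScaledLinearScheduler'},
    #                        'PREDICTION_TYPE': 'eps'}, load=False))
    #   latent = diffusion.sample(noise=torch.randn(1, 4, 64, 64),
    #                             model=unet,
    #                             model_kwargs=[{'cond': pos}, {'cond': neg}],
    #                             steps=20,
    #                             sampler='ddim',
    #                             guide_scale=7.5)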

    def loss(self,
             x_0,
             model,
             model_kwargs={},
             reduction='mean',
             noise=None,
             **kwargs):
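        """Diffusion training loss.

        Noises `x_0` with the noise scheduler, runs the model on `x_t`, and
        regresses it with an MSE loss against the target implied by
        `self.prediction_type` ('eps' -> noise, 'x0' -> x_0,
        'v' -> alpha * noise - sigma * x_0), optionally reweighted by
        min(SNR, gamma) / SNR when MIN_SNR_GAMMA is set.
        """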
        # use noise scheduler to add noise
        if noise is None:
            noise = torch.randn_like(x_0)
        schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
        x_t, t, sigma, alpha = (schedule_output.x_t, schedule_output.t,
                                schedule_output.sigma, schedule_output.alpha)
        out = model(x=x_t, t=t, **model_kwargs)
        # mse loss
        target = {
            'eps': noise,
            'x0': x_0,
            'v': alpha * noise - sigma * x_0
        }[self.prediction_type]
        loss = (out - target).pow(2)
        if reduction == 'mean':
            loss = loss.flatten(1).mean(dim=1)
        if self.min_snr_gamma is not None:
            alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
            sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
            snrs = (alphas / sigmas).clamp(min=1e-20)
            min_snrs = snrs.clamp(max=self.min_snr_gamma)
            weights = min_snrs / snrs
        else:
            weights = 1
        loss = loss * weights
        return loss

    def get_sampler(self, sampler):
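        """Instantiate a sampler from a name, dict/OrderedDict, or Config."""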
        if isinstance(sampler, str):
            if sampler not in DIFFUSION_SAMPLERS.class_map:
                msg = (f'{sampler} not in the defined samplers list '
                       f'{DIFFUSION_SAMPLERS.class_map.keys()}')
                if self.logger is not None:
                    self.logger.info(msg)
                else:
                    print(msg)
                return None
            sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
            sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
                                                   logger=self.logger)
        elif isinstance(sampler, (Config, dict, OrderedDict)):
            if isinstance(sampler, (dict, OrderedDict)):
                sampler = Config(cfg_dict={
                    k.upper(): v
                    for k, v in dict(sampler).items()
                },
                                 load=False)
            sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
        else:
            raise NotImplementedError(
                f'sampler must be a str, dict or Config, got {type(sampler)}')
        return sampler_ins

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}' + ' ' + super().__repr__()

    @staticmethod
    def get_config_template():
        return dict_to_yaml('DIFFUSIONS',
                            __class__.__name__,
                            ACEDiffusion.para_dict,
                            set_name=True)
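

if __name__ == '__main__':
    # Illustrative smoke test (an assumed entry point, not required by the
    # module): print the YAML config template for this class.
    print(ACEDiffusion.get_config_template())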