from functools import partial
import json
import math
from pathlib import Path

from jsonmerge import merge

from . import augmentation, layers, models, utils


def round_to_power_of_two(x, tol):
    approxs = []
    for i in range(math.ceil(math.log2(x))):
        mult = 2**i
        approxs.append(round(x / mult) * mult)
    for approx in reversed(approxs):
        error = abs((approx - x) / x)
        if error <= tol:
            return approx
    return approxs[0]


def load_config(path_or_dict):
    defaults_image_v1 = {
        'model': {
            'patch_size': 1,
            'augment_wrapper': True,
            'mapping_cond_dim': 0,
            'unet_cond_dim': 0,
            'cross_cond_dim': 0,
            'cross_attn_depths': None,
            'skip_stages': 0,
            'has_variance': False,
        },
        'optimizer': {
            'type': 'adamw',
            'lr': 1e-4,
            'betas': [0.95, 0.999],
            'eps': 1e-6,
            'weight_decay': 1e-3,
        },
    }
    defaults_image_transformer_v1 = {
        'model': {
            'd_ff': 0,
            'augment_wrapper': False,
            'skip_stages': 0,
            'has_variance': False,
        },
        'optimizer': {
            'type': 'adamw',
            'lr': 5e-4,
            'betas': [0.9, 0.99],
            'eps': 1e-8,
            'weight_decay': 1e-4,
        },
    }
    defaults_image_transformer_v2 = {
        'model': {
            'mapping_width': 256,
            'mapping_depth': 2,
            'mapping_d_ff': None,
            'mapping_cond_dim': 0,
            'mapping_dropout_rate': 0.,
            'd_ffs': None,
            'self_attns': None,
            'dropout_rate': None,
            'augment_wrapper': False,
            'skip_stages': 0,
            'has_variance': False,
        },
        'optimizer': {
            'type': 'adamw',
            'lr': 5e-4,
            'betas': [0.9, 0.99],
            'eps': 1e-8,
            'weight_decay': 1e-4,
        },
    }
    defaults = {
        'model': {
            'sigma_data': 1.,
            'dropout_rate': 0.,
            'augment_prob': 0.,
            'loss_config': 'karras',
            'loss_weighting': 'karras',
            'loss_scales': 1,
        },
        'dataset': {
            'type': 'imagefolder',
            'num_classes': 0,
            'cond_dropout_rate': 0.1,
        },
        'optimizer': {
            'type': 'adamw',
            'lr': 1e-4,
            'betas': [0.9, 0.999],
            'eps': 1e-8,
            'weight_decay': 1e-4,
        },
        'lr_sched': {
            'type': 'constant',
            'warmup': 0.,
        },
        'ema_sched': {
            'type': 'inverse',
            'power': 0.6667,
            'max_value': 0.9999
        },
    }
    if not isinstance(path_or_dict, dict):
        file = Path(path_or_dict)
        if file.suffix == '.safetensors':
            metadata = utils.get_safetensors_metadata(file)
            config = json.loads(metadata['config'])
        else:
            config = json.loads(file.read_text())
    else:
        config = path_or_dict
    if config['model']['type'] == 'image_v1':
        config = merge(defaults_image_v1, config)
    elif config['model']['type'] == 'image_transformer_v1':
        config = merge(defaults_image_transformer_v1, config)
        if not config['model']['d_ff']:
            config['model']['d_ff'] = round_to_power_of_two(config['model']['width'] * 8 / 3, tol=0.05)
    elif config['model']['type'] == 'image_transformer_v2':
        config = merge(defaults_image_transformer_v2, config)
        if not config['model']['mapping_d_ff']:
            config['model']['mapping_d_ff'] = config['model']['mapping_width'] * 3
        if not config['model']['d_ffs']:
            d_ffs = []
            for width in config['model']['widths']:
                d_ffs.append(width * 3)
            config['model']['d_ffs'] = d_ffs
        if not config['model']['self_attns']:
            self_attns = []
            default_neighborhood = {"type": "neighborhood", "d_head": 64, "kernel_size": 7}
            default_global = {"type": "global", "d_head": 64}
            for i in range(len(config['model']['widths'])):
                self_attns.append(default_neighborhood if i < len(config['model']['widths']) - 1 else default_global)
            config['model']['self_attns'] = self_attns
        if config['model']['dropout_rate'] is None:
            config['model']['dropout_rate'] = [0.0] * len(config['model']['widths'])
        elif isinstance(config['model']['dropout_rate'], float):
            config['model']['dropout_rate'] = [config['model']['dropout_rate']] * len(config['model']['widths'])
    return merge(defaults, config)
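

# Worked example of the feed-forward defaults above (values computed here, not taken from
# any shipped config): for image_transformer_v1 with width=512, width * 8 / 3 ~= 1365.33,
# and round_to_power_of_two(1365.33, tol=0.05) scans multiples of decreasing powers of two
# (1024, 1536, 1280, 1408, ...) and returns 1408 = 11 * 128, the first candidate within the
# 5% tolerance (relative error ~= 3.1%). For image_transformer_v2, mapping_d_ff defaults to
# mapping_width * 3 (256 -> 768) and each level's d_ff defaults to 3x that level's width.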


def make_model(config):
    dataset_config = config['dataset']
    num_classes = dataset_config['num_classes']
    config = config['model']
    if config['type'] == 'image_v1':
        model = models.ImageDenoiserModelV1(
            config['input_channels'],
            config['mapping_out'],
            config['depths'],
            config['channels'],
            config['self_attn_depths'],
            config['cross_attn_depths'],
            patch_size=config['patch_size'],
            dropout_rate=config['dropout_rate'],
            mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
            unet_cond_dim=config['unet_cond_dim'],
            cross_cond_dim=config['cross_cond_dim'],
            skip_stages=config['skip_stages'],
            has_variance=config['has_variance'],
        )
        if config['augment_wrapper']:
            model = augmentation.KarrasAugmentWrapper(model)
    elif config['type'] == 'image_transformer_v1':
        model = models.ImageTransformerDenoiserModelV1(
            n_layers=config['depth'],
            d_model=config['width'],
            d_ff=config['d_ff'],
            in_features=config['input_channels'],
            out_features=config['input_channels'],
            patch_size=config['patch_size'],
            num_classes=num_classes + 1 if num_classes else 0,
            dropout=config['dropout_rate'],
            sigma_data=config['sigma_data'],
        )
    elif config['type'] == 'image_transformer_v2':
        assert len(config['widths']) == len(config['depths'])
        assert len(config['widths']) == len(config['d_ffs'])
        assert len(config['widths']) == len(config['self_attns'])
        assert len(config['widths']) == len(config['dropout_rate'])
        levels = []
        for depth, width, d_ff, self_attn, dropout in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'], config['dropout_rate']):
            if self_attn['type'] == 'global':
                self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64))
            elif self_attn['type'] == 'neighborhood':
                self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7))
            elif self_attn['type'] == 'shifted-window':
                self_attn = models.image_transformer_v2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size'])
            elif self_attn['type'] == 'none':
                self_attn = models.image_transformer_v2.NoAttentionSpec()
            else:
                raise ValueError(f'unsupported self attention type {self_attn["type"]}')
            levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn, dropout))
        mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff'], config['mapping_dropout_rate'])
        model = models.ImageTransformerDenoiserModelV2(
            levels=levels,
            mapping=mapping,
            in_channels=config['input_channels'],
            out_channels=config['input_channels'],
            patch_size=config['patch_size'],
            num_classes=num_classes + 1 if num_classes else 0,
            mapping_cond_dim=config['mapping_cond_dim'],
        )
    else:
        raise ValueError(f'unsupported model type {config["type"]}')
    return model
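

# Note (an inference from this file, not something it states): for both transformer model
# types, a class-conditional dataset causes the model to be built with num_classes + 1
# rather than num_classes, leaving one spare class index -- presumably the "no class" label
# applied when conditioning is dropped with the dataset's cond_dropout_rate.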


def make_denoiser_wrapper(config):
    config = config['model']
    sigma_data = config.get('sigma_data', 1.)
    has_variance = config.get('has_variance', False)
    loss_config = config.get('loss_config', 'karras')
    if loss_config == 'karras':
        weighting = config.get('loss_weighting', 'karras')
        scales = config.get('loss_scales', 1)
        if not has_variance:
            return partial(layers.Denoiser, sigma_data=sigma_data, weighting=weighting, scales=scales)
        return partial(layers.DenoiserWithVariance, sigma_data=sigma_data, weighting=weighting)
    if loss_config == 'simple':
        if has_variance:
            raise ValueError('Simple loss config does not support a variance output')
        return partial(layers.SimpleLossDenoiser, sigma_data=sigma_data)
    raise ValueError('Unknown loss config type')


def make_sample_density(config):
    sd_config = config['sigma_sample_density']
    sigma_data = config['sigma_data']
    if sd_config['type'] == 'lognormal':
        loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
        scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
        return partial(utils.rand_log_normal, loc=loc, scale=scale)
    if sd_config['type'] == 'loglogistic':
        loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
        scale = sd_config['scale'] if 'scale' in sd_config else 0.5
        min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
        max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
        return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'loguniform':
        min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
        max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
        return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
    if sd_config['type'] in {'v-diffusion', 'cosine'}:
        min_value = sd_config['min_value'] if 'min_value' in sd_config else 1e-3
        max_value = sd_config['max_value'] if 'max_value' in sd_config else 1e3
        return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'split-lognormal':
        loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
        scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
        scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
        return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
    if sd_config['type'] == 'cosine-interpolated':
        min_value = sd_config.get('min_value', min(config['sigma_min'], 1e-3))
        max_value = sd_config.get('max_value', max(config['sigma_max'], 1e3))
        image_d = sd_config.get('image_d', max(config['input_size']))
        noise_d_low = sd_config.get('noise_d_low', 32)
        noise_d_high = sd_config.get('noise_d_high', max(config['input_size']))
        return partial(utils.rand_cosine_interpolated, image_d=image_d, noise_d_low=noise_d_low, noise_d_high=noise_d_high, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
    raise ValueError('Unknown sample density type')
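

# Minimal usage sketch, not part of the original module's public surface: the dict below is
# a hypothetical configuration whose keys and values mirror the ones this file reads; real
# configs ship as JSON files and may require additional keys. Executing this file as a
# module (python -m <package>.config, since it uses relative imports) prints the fully
# merged config that load_config() produces; make_model(), make_denoiser_wrapper() and
# make_sample_density() would then consume that merged dict during training.
if __name__ == '__main__':
    _example_config = {
        'model': {
            'type': 'image_transformer_v2',
            'input_channels': 3,
            'input_size': [256, 256],
            'patch_size': [4, 4],
            'depths': [2, 2, 4],
            'widths': [128, 256, 512],
            'sigma_data': 1.,
            'sigma_min': 1e-2,
            'sigma_max': 160,
            'sigma_sample_density': {'type': 'cosine-interpolated'},
        },
        'dataset': {'type': 'imagefolder', 'location': '/path/to/images'},
    }
    # load_config() accepts a dict as well as a path; the per-type and global defaults are
    # merged in, and mapping_d_ff/d_ffs/self_attns/dropout_rate are filled per level.
    print(json.dumps(load_config(_example_config), indent=4))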