from functools import partial
import json
import math
from pathlib import Path
from jsonmerge import merge
from . import augmentation, layers, models, utils
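
# This module handles configuration for the denoiser models in this package:
# loading a config from a JSON file, a dict, or safetensors metadata, merging
# it with per-architecture and global defaults, and building the model, the
# denoiser loss wrapper, and the noise-level sample density from it.
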
def round_to_power_of_two(x, tol):
approxs = []
for i in range(math.ceil(math.log2(x))):
mult = 2**i
approxs.append(round(x / mult) * mult)
for approx in reversed(approxs):
error = abs((approx - x) / x)
if error <= tol:
return approx
return approxs[0]
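
# Illustrative note (not part of the original module): despite its name, the
# helper above snaps x to a multiple of a power of two. It returns
# round(x / 2**i) * 2**i for the largest i whose relative error stays within
# tol, falling back to plain round(x) if no candidate qualifies. For example,
# with an image_transformer_v1 width of 128 (a value assumed here only for
# illustration), the d_ff default computed in load_config below becomes:
#
#     >>> round_to_power_of_two(128 * 8 / 3, tol=0.05)   # 341.33, snapped to 11 * 32
#     352
#     >>> round_to_power_of_two(768 * 8 / 3, tol=0.05)   # 2048 is already 2 * 1024
#     2048
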
def load_config(path_or_dict):
defaults_image_v1 = {
'model': {
'patch_size': 1,
'augment_wrapper': True,
'mapping_cond_dim': 0,
'unet_cond_dim': 0,
'cross_cond_dim': 0,
'cross_attn_depths': None,
'skip_stages': 0,
'has_variance': False,
},
'optimizer': {
'type': 'adamw',
'lr': 1e-4,
'betas': [0.95, 0.999],
'eps': 1e-6,
'weight_decay': 1e-3,
},
}
defaults_image_transformer_v1 = {
'model': {
'd_ff': 0,
'augment_wrapper': False,
'skip_stages': 0,
'has_variance': False,
},
'optimizer': {
'type': 'adamw',
'lr': 5e-4,
'betas': [0.9, 0.99],
'eps': 1e-8,
'weight_decay': 1e-4,
},
}
defaults_image_transformer_v2 = {
'model': {
'mapping_width': 256,
'mapping_depth': 2,
'mapping_d_ff': None,
'mapping_cond_dim': 0,
'mapping_dropout_rate': 0.,
'd_ffs': None,
'self_attns': None,
'dropout_rate': None,
'augment_wrapper': False,
'skip_stages': 0,
'has_variance': False,
},
'optimizer': {
'type': 'adamw',
'lr': 5e-4,
'betas': [0.9, 0.99],
'eps': 1e-8,
'weight_decay': 1e-4,
},
}
defaults = {
'model': {
'sigma_data': 1.,
'dropout_rate': 0.,
'augment_prob': 0.,
'loss_config': 'karras',
'loss_weighting': 'karras',
'loss_scales': 1,
},
'dataset': {
'type': 'imagefolder',
'num_classes': 0,
'cond_dropout_rate': 0.1,
},
'optimizer': {
'type': 'adamw',
'lr': 1e-4,
'betas': [0.9, 0.999],
'eps': 1e-8,
'weight_decay': 1e-4,
},
'lr_sched': {
'type': 'constant',
'warmup': 0.,
},
'ema_sched': {
'type': 'inverse',
'power': 0.6667,
'max_value': 0.9999,
},
}
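    # The user config comes from a JSON file, from the 'config' key of a
    # safetensors file's metadata, or from a dict passed in directly; it is
    # merged on top of the matching per-architecture defaults and finally on
    # top of the global defaults, so explicitly configured values win.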
if not isinstance(path_or_dict, dict):
file = Path(path_or_dict)
if file.suffix == '.safetensors':
metadata = utils.get_safetensors_metadata(file)
config = json.loads(metadata['config'])
else:
config = json.loads(file.read_text())
else:
config = path_or_dict
if config['model']['type'] == 'image_v1':
config = merge(defaults_image_v1, config)
elif config['model']['type'] == 'image_transformer_v1':
config = merge(defaults_image_transformer_v1, config)
if not config['model']['d_ff']:
config['model']['d_ff'] = round_to_power_of_two(config['model']['width'] * 8 / 3, tol=0.05)
elif config['model']['type'] == 'image_transformer_v2':
config = merge(defaults_image_transformer_v2, config)
if not config['model']['mapping_d_ff']:
config['model']['mapping_d_ff'] = config['model']['mapping_width'] * 3
if not config['model']['d_ffs']:
d_ffs = []
for width in config['model']['widths']:
d_ffs.append(width * 3)
config['model']['d_ffs'] = d_ffs
if not config['model']['self_attns']:
self_attns = []
default_neighborhood = {"type": "neighborhood", "d_head": 64, "kernel_size": 7}
default_global = {"type": "global", "d_head": 64}
for i in range(len(config['model']['widths'])):
self_attns.append(default_neighborhood if i < len(config['model']['widths']) - 1 else default_global)
config['model']['self_attns'] = self_attns
if config['model']['dropout_rate'] is None:
config['model']['dropout_rate'] = [0.0] * len(config['model']['widths'])
elif isinstance(config['model']['dropout_rate'], float):
config['model']['dropout_rate'] = [config['model']['dropout_rate']] * len(config['model']['widths'])
return merge(defaults, config)
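
# Hedged sketch (not part of the original module): a minimal dict that
# exercises the image_transformer_v2 defaults above. Every value here is an
# assumption chosen for illustration; real configs normally live in JSON files
# and are passed to load_config by path.
def _example_transformer_v2_config():
    config = load_config({
        'model': {
            'type': 'image_transformer_v2',
            'input_channels': 3,
            'input_size': [256, 256],
            'patch_size': [4, 4],
            'depths': [2, 2, 4],
            'widths': [128, 256, 512],
        },
    })
    # Among other defaults, the merged config gains:
    #   model.d_ffs        = [384, 768, 1536]  (width * 3 per level)
    #   model.self_attns   = [neighborhood, neighborhood, global], d_head 64
    #   model.dropout_rate = [0.0, 0.0, 0.0]
    #   model.mapping_d_ff = 768  (mapping_width 256 * 3)
    return config
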
def make_model(config):
dataset_config = config['dataset']
num_classes = dataset_config['num_classes']
config = config['model']
if config['type'] == 'image_v1':
model = models.ImageDenoiserModelV1(
config['input_channels'],
config['mapping_out'],
config['depths'],
config['channels'],
config['self_attn_depths'],
config['cross_attn_depths'],
patch_size=config['patch_size'],
dropout_rate=config['dropout_rate'],
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
unet_cond_dim=config['unet_cond_dim'],
cross_cond_dim=config['cross_cond_dim'],
skip_stages=config['skip_stages'],
has_variance=config['has_variance'],
)
if config['augment_wrapper']:
model = augmentation.KarrasAugmentWrapper(model)
elif config['type'] == 'image_transformer_v1':
model = models.ImageTransformerDenoiserModelV1(
n_layers=config['depth'],
d_model=config['width'],
d_ff=config['d_ff'],
in_features=config['input_channels'],
out_features=config['input_channels'],
patch_size=config['patch_size'],
num_classes=num_classes + 1 if num_classes else 0,
dropout=config['dropout_rate'],
sigma_data=config['sigma_data'],
)
elif config['type'] == 'image_transformer_v2':
assert len(config['widths']) == len(config['depths'])
assert len(config['widths']) == len(config['d_ffs'])
assert len(config['widths']) == len(config['self_attns'])
assert len(config['widths']) == len(config['dropout_rate'])
levels = []
for depth, width, d_ff, self_attn, dropout in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'], config['dropout_rate']):
if self_attn['type'] == 'global':
self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64))
elif self_attn['type'] == 'neighborhood':
self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7))
elif self_attn['type'] == 'shifted-window':
self_attn = models.image_transformer_v2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size'])
elif self_attn['type'] == 'none':
self_attn = models.image_transformer_v2.NoAttentionSpec()
else:
raise ValueError(f'unsupported self attention type {self_attn["type"]}')
levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn, dropout))
mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff'], config['mapping_dropout_rate'])
model = models.ImageTransformerDenoiserModelV2(
levels=levels,
mapping=mapping,
in_channels=config['input_channels'],
out_channels=config['input_channels'],
patch_size=config['patch_size'],
num_classes=num_classes + 1 if num_classes else 0,
mapping_cond_dim=config['mapping_cond_dim'],
)
else:
raise ValueError(f'unsupported model type {config["type"]}')
return model
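
# Hedged sketch (illustrative only): constructing the inner model from the
# hypothetical _example_transformer_v2_config helper sketched after
# load_config. For image_transformer_v2, the per-level self_attns dicts are
# converted into the LevelSpec and attention spec objects used in the branch
# above; for image_v1, augment_wrapper reserves 9 extra mapping-conditioning
# channels and wraps the model in augmentation.KarrasAugmentWrapper.
def _example_make_model():
    config = _example_transformer_v2_config()
    # Three levels: two neighborhood-attention levels and one global level.
    return make_model(config)
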
def make_denoiser_wrapper(config):
config = config['model']
sigma_data = config.get('sigma_data', 1.)
has_variance = config.get('has_variance', False)
loss_config = config.get('loss_config', 'karras')
if loss_config == 'karras':
weighting = config.get('loss_weighting', 'karras')
scales = config.get('loss_scales', 1)
if not has_variance:
return partial(layers.Denoiser, sigma_data=sigma_data, weighting=weighting, scales=scales)
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data, weighting=weighting)
if loss_config == 'simple':
if has_variance:
raise ValueError('Simple loss config does not support a variance output')
return partial(layers.SimpleLossDenoiser, sigma_data=sigma_data)
raise ValueError('Unknown loss config type')
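
# Hedged sketch (illustrative only): the factory above returns a partially
# applied wrapper class, so the usual pattern is to call it on the inner
# model. The _example_* helpers are the hypothetical sketches from earlier in
# this file, not part of the original module.
def _example_make_denoiser():
    config = _example_transformer_v2_config()
    inner_model = make_model(config)
    # With the default loss_config='karras' and has_variance=False, this binds
    # layers.Denoiser to the config's sigma_data, weighting, and scales.
    return make_denoiser_wrapper(config)(inner_model)
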
def make_sample_density(config):
sd_config = config['sigma_sample_density']
sigma_data = config['sigma_data']
if sd_config['type'] == 'lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
return partial(utils.rand_log_normal, loc=loc, scale=scale)
if sd_config['type'] == 'loglogistic':
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'loguniform':
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
if sd_config['type'] in {'v-diffusion', 'cosine'}:
min_value = sd_config['min_value'] if 'min_value' in sd_config else 1e-3
max_value = sd_config['max_value'] if 'max_value' in sd_config else 1e3
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'split-lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
if sd_config['type'] == 'cosine-interpolated':
min_value = sd_config.get('min_value', min(config['sigma_min'], 1e-3))
max_value = sd_config.get('max_value', max(config['sigma_max'], 1e3))
image_d = sd_config.get('image_d', max(config['input_size']))
noise_d_low = sd_config.get('noise_d_low', 32)
noise_d_high = sd_config.get('noise_d_high', max(config['input_size']))
return partial(utils.rand_cosine_interpolated, image_d=image_d, noise_d_low=noise_d_low, noise_d_high=noise_d_high, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
raise ValueError('Unknown sample density type')
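
# Hedged sketch (illustrative only): make_sample_density takes the 'model'
# section of a loaded config and returns a callable that draws noise levels
# for training. The fragment below, including the lognormal parameters, is an
# assumption made up for the example; the 'loguniform' and 'cosine-interpolated'
# types would also read sigma_min, sigma_max, and (for cosine-interpolated)
# input_size.
def _example_sample_density():
    model_config = {
        'sigma_data': 1.0,
        'sigma_sample_density': {'type': 'lognormal', 'mean': -1.2, 'std': 1.2},
    }
    # Returns utils.rand_log_normal with loc=-1.2 and scale=1.2 already bound;
    # the remaining call arguments (typically the sample shape) are whatever
    # utils.rand_log_normal accepts.
    return make_sample_density(model_config)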