Spaces:
Runtime error
Runtime error
import datetime | |
import gc | |
import html | |
import json | |
import os | |
import sys | |
import time | |
import traceback | |
from collections import defaultdict, deque | |
import torch | |
import tqdm | |
from modules import shared, sd_models, devices, processing, sd_samplers | |
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork | |
from modules.textual_inversion import textual_inversion | |
from modules.textual_inversion.learn_schedule import LearnRateScheduler | |
from ..tbutils import tensorboard_setup, tensorboard_add, tensorboard_add_image, tensorboard_log_hyperparameter | |
from .textual_inversion import validate_train_inputs, write_loss | |
from ..hypernetwork import Hypernetwork, load_hypernetwork | |
from . import sd_hijack_checkpoint | |
from ..hnutil import optim_to | |
from ..ui import create_hypernetwork_load | |
from ..scheduler import CosineAnnealingWarmUpRestarts | |
from .dataset import PersonalizedBase, PersonalizedDataLoader | |
from ..ddpm_hijack import set_scheduler | |
def get_lr_from_optimizer(optimizer: torch.optim.Optimizer): | |
return optimizer.param_groups[0].get('d', 1) * optimizer.param_groups[0].get('lr', 1) | |
def set_accessible(obj): | |
setattr(shared, 'accessible_hypernetwork', obj) | |
if hasattr(shared, 'loaded_hypernetworks'): | |
shared.loaded_hypernetworks.clear() | |
shared.loaded_hypernetworks = [obj,] | |
def remove_accessible(): | |
delattr(shared, 'accessible_hypernetwork') | |
if hasattr(shared, 'loaded_hypernetworks'): | |
shared.loaded_hypernetworks.clear() | |
def get_training_option(filename): | |
print(filename) | |
if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile( | |
os.path.join(shared.cmd_opts.hypernetwork_dir, filename)): | |
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename) | |
elif os.path.exists(filename) and os.path.isfile(filename): | |
filename = filename | |
elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile( | |
os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')): | |
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json') | |
else: | |
return False | |
print(f"Loading setting from {filename}!") | |
with open(filename, 'r') as file: | |
obj = json.load(file) | |
return obj | |
def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_parameter=False, use_dadaptation=False, dadapt_growth_factor=-1, **adamW_kwarg_dict): | |
""" returns hypernetwork object binded with optimizer""" | |
hypernetwork = load_hypernetwork(hypernetwork_name) | |
hypernetwork.to(devices.device) | |
assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!" | |
if not isinstance(hypernetwork, Hypernetwork): | |
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!") | |
set_accessible(hypernetwork) | |
weights = hypernetwork.weights(True) | |
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] | |
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') | |
# Here we use optimizer from saved HN, or we can specify as UI option. | |
if hypernetwork.optimizer_name == 'DAdaptAdamW': | |
use_dadaptation = True | |
optimizer = None | |
optimizer_name = 'AdamW' | |
# Here we use optimizer from saved HN, or we can specify as UI option. | |
if hypernetwork.optimizer_name in optimizer_dict: | |
if use_adamw_parameter: | |
if hypernetwork.optimizer_name != 'AdamW' and hypernetwork.optimizer_name != 'DAdaptAdamW': | |
raise NotImplementedError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!") | |
if use_dadaptation: | |
from .dadapt_test.install import get_dadapt_adam | |
optim_class = get_dadapt_adam(hypernetwork.optimizer_name) | |
if optim_class != torch.optim.AdamW: | |
print('Optimizer class is ' + str(optim_class)) | |
optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, **adamW_kwarg_dict) | |
hypernetwork.optimizer_name = 'DAdaptAdamW' | |
else: | |
optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict) | |
else: | |
optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict) | |
else: | |
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=learn_rate) | |
optimizer_name = hypernetwork.optimizer_name | |
else: | |
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") | |
if use_dadaptation: | |
from .dadapt_test.install import get_dadapt_adam | |
optim_class = get_dadapt_adam(hypernetwork.optimizer_name) | |
if optim_class != torch.optim.AdamW: | |
optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, **adamW_kwarg_dict) | |
optimizer_name = 'DAdaptAdamW' | |
hypernetwork.optimizer_name = 'DAdaptAdamW' | |
if optimizer is None: | |
optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict) | |
optimizer_name = 'AdamW' | |
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. | |
try: | |
optimizer.load_state_dict(hypernetwork.optimizer_state_dict) | |
optim_to(optimizer, devices.device) | |
print('Loaded optimizer successfully!') | |
except RuntimeError as e: | |
print("Cannot resume from saved optimizer!") | |
print(e) | |
return hypernetwork, optimizer, weights, optimizer_name | |
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, | |
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, | |
create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, | |
preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, | |
preview_width, preview_height, | |
use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7, | |
gamma_rate=1, save_when_converge=False, create_when_converge=False, | |
move_optimizer=True, | |
use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99, | |
adamw_eps=1e-8, | |
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01, | |
optional_gradient_norm_type=2, latent_sampling_std=-1, | |
noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128, | |
load_training_options='', loss_opt='loss_simple', use_dadaptation=False, dadapt_growth_factor=-1, use_weight=False | |
): | |
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. | |
from modules import images | |
if load_training_options != '': | |
dump: dict = get_training_option(load_training_options) | |
if dump and dump is not None: | |
print(f"Loading from {load_training_options}") | |
learn_rate = dump['learn_rate'] | |
batch_size = dump['batch_size'] | |
gradient_step = dump['gradient_step'] | |
training_width = dump['training_width'] | |
training_height = dump['training_height'] | |
steps = dump['steps'] | |
shuffle_tags = dump['shuffle_tags'] | |
tag_drop_out = dump['tag_drop_out'] | |
save_when_converge = dump['save_when_converge'] | |
create_when_converge = dump['create_when_converge'] | |
latent_sampling_method = dump['latent_sampling_method'] | |
template_file = dump['template_file'] | |
use_beta_scheduler = dump['use_beta_scheduler'] | |
beta_repeat_epoch = dump['beta_repeat_epoch'] | |
epoch_mult = dump['epoch_mult'] | |
warmup = dump['warmup'] | |
min_lr = dump['min_lr'] | |
gamma_rate = dump['gamma_rate'] | |
use_adamw_parameter = dump['use_beta_adamW_checkbox'] | |
adamw_weight_decay = dump['adamw_weight_decay'] | |
adamw_beta_1 = dump['adamw_beta_1'] | |
adamw_beta_2 = dump['adamw_beta_2'] | |
adamw_eps = dump['adamw_eps'] | |
use_grad_opts = dump['show_gradient_clip_checkbox'] | |
gradient_clip_opt = dump['gradient_clip_opt'] | |
optional_gradient_clip_value = dump['optional_gradient_clip_value'] | |
optional_gradient_norm_type = dump['optional_gradient_norm_type'] | |
latent_sampling_std = dump.get('latent_sampling_std', -1) | |
noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False) | |
noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False) | |
noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128) | |
loss_opt = dump.get('loss_opt', 'loss_simple') | |
use_dadaptation = dump.get('use_dadaptation', False) | |
dadapt_growth_factor = dump.get('dadapt_growth_factor', -1) | |
use_weight = dump.get('use_weight', False) | |
try: | |
if use_adamw_parameter: | |
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in | |
[adamw_weight_decay, adamw_beta_1, | |
adamw_beta_2, adamw_eps]] | |
assert 0 <= adamw_weight_decay, "Weight decay paramter should be larger or equal than zero!" | |
assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, | |
adamw_eps])), "Cannot use negative or >1 number for adamW parameters!" | |
adamW_kwarg_dict = { | |
'weight_decay': adamw_weight_decay, | |
'betas': (adamw_beta_1, adamw_beta_2), | |
'eps': adamw_eps | |
} | |
print('Using custom AdamW parameters') | |
else: | |
adamW_kwarg_dict = { | |
'weight_decay': 0.01, | |
'betas': (0.9, 0.99), | |
'eps': 1e-8 | |
} | |
if use_beta_scheduler: | |
print("Using Beta Scheduler") | |
beta_repeat_epoch = int(float(beta_repeat_epoch)) | |
assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!" | |
min_lr = float(min_lr) | |
assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!" | |
gamma_rate = float(gamma_rate) | |
print(f"Using learn rate decay(per cycle) of {gamma_rate}") | |
assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!" | |
epoch_mult = float(epoch_mult) | |
assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!" | |
warmup = int(float(warmup)) | |
assert warmup >= 1, "Warmup epoch should be larger than 0!" | |
print(f"Save when converges : {save_when_converge}") | |
print(f"Generate image when converges : {create_when_converge}") | |
else: | |
beta_repeat_epoch = 4000 | |
epoch_mult = 1 | |
warmup = 10 | |
min_lr = 1e-7 | |
gamma_rate = 1 | |
save_when_converge = False | |
create_when_converge = False | |
except ValueError as e: | |
raise RuntimeError("Cannot use advanced LR scheduler settings! "+ str(e)) | |
if noise_training_scheduler_enabled: | |
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) | |
print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!") | |
else: | |
set_scheduler(-1, False, False) | |
if use_grad_opts and gradient_clip_opt != "None": | |
try: | |
optional_gradient_clip_value = float(optional_gradient_clip_value) | |
except ValueError: | |
raise RuntimeError(f"Cannot convert invalid gradient clipping value {optional_gradient_clip_value})") | |
if gradient_clip_opt == "Norm": | |
try: | |
grad_norm = int(float(optional_gradient_norm_type)) | |
except ValueError: | |
raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type})") | |
assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}" | |
print( | |
f"Using gradient clipping by Norm, norm type {optional_gradient_norm_type}, norm limit {optional_gradient_clip_value}") | |
def gradient_clipping(arg1): | |
torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, optional_gradient_norm_type) | |
return | |
else: | |
print(f"Using gradient clipping by Value, limit {optional_gradient_clip_value}") | |
def gradient_clipping(arg1): | |
torch.nn.utils.clip_grad_value_(arg1, optional_gradient_clip_value) | |
return | |
else: | |
def gradient_clipping(arg1): | |
return | |
save_hypernetwork_every = save_hypernetwork_every or 0 | |
create_image_every = create_image_every or 0 | |
if not os.path.isfile(template_file): | |
template_file = textual_inversion.textual_inversion_templates.get(template_file, None) | |
if template_file is not None: | |
template_file = template_file.path | |
else: | |
raise AssertionError(f"Cannot find {template_file}!") | |
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") | |
shared.state.job = "train-hypernetwork" | |
shared.state.textinfo = "Initializing hypernetwork training..." | |
shared.state.job_count = steps | |
tmp_scheduler = LearnRateScheduler(learn_rate, steps, 0) | |
hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, use_dadaptation,dadapt_growth_factor, **adamW_kwarg_dict) | |
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] | |
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') | |
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) | |
unload = shared.opts.unload_models_when_training | |
if save_hypernetwork_every > 0 or save_when_converge: | |
hypernetwork_dir = os.path.join(log_directory, "hypernetworks") | |
os.makedirs(hypernetwork_dir, exist_ok=True) | |
else: | |
hypernetwork_dir = None | |
if create_image_every > 0 or create_when_converge: | |
images_dir = os.path.join(log_directory, "images") | |
os.makedirs(images_dir, exist_ok=True) | |
else: | |
images_dir = None | |
checkpoint = sd_models.select_checkpoint() | |
initial_step = hypernetwork.step or 0 | |
if initial_step >= steps: | |
shared.state.textinfo = f"Model has already been trained beyond specified max steps" | |
return hypernetwork, filename | |
scheduler = LearnRateScheduler(learn_rate, steps, initial_step) | |
if shared.opts.training_enable_tensorboard: | |
print("Tensorboard logging enabled") | |
tensorboard_writer = tensorboard_setup(log_directory) | |
else: | |
tensorboard_writer = None | |
# dataset loading may take a while, so input validations and early returns should be done before this | |
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." | |
detach_grad = shared.opts.disable_ema # test code that removes EMA | |
if detach_grad: | |
print("Disabling training for staged models!") | |
shared.sd_model.cond_stage_model.requires_grad_(False) | |
shared.sd_model.first_stage_model.requires_grad_(False) | |
torch.cuda.empty_cache() | |
pin_memory = shared.opts.pin_memory | |
ds = PersonalizedBase(data_root=data_root, width=training_width, | |
height=training_height, | |
repeats=shared.opts.training_image_repeats_per_epoch, | |
placeholder_token=hypernetwork_name, model=shared.sd_model, | |
cond_model=shared.sd_model.cond_stage_model, | |
device=devices.device, template_file=template_file, | |
include_cond=True, batch_size=batch_size, | |
gradient_step=gradient_step, shuffle_tags=shuffle_tags, | |
tag_drop_out=tag_drop_out, | |
latent_sampling_method=latent_sampling_method, | |
latent_sampling_std=latent_sampling_std, | |
use_weight=use_weight) | |
latent_sampling_method = ds.latent_sampling_method | |
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, | |
batch_size=ds.batch_size, pin_memory=pin_memory) | |
old_parallel_processing_allowed = shared.parallel_processing_allowed | |
if unload: | |
shared.parallel_processing_allowed = False | |
shared.sd_model.cond_stage_model.to(devices.cpu) | |
shared.sd_model.first_stage_model.to(devices.cpu) | |
if use_beta_scheduler: | |
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, | |
cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, | |
warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate) | |
scheduler_beta.last_epoch = hypernetwork.step - 1 | |
else: | |
scheduler_beta = None | |
for pg in optimizer.param_groups: | |
pg['lr'] = scheduler.learn_rate | |
scaler = torch.cuda.amp.GradScaler() | |
batch_size = ds.batch_size | |
gradient_step = ds.gradient_step | |
# n steps = batch_size * gradient_step * n image processed | |
steps_per_epoch = len(ds) // batch_size // gradient_step | |
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step | |
loss_step = 0 | |
_loss_step = 0 # internal | |
# size = len(ds.indexes) | |
loss_dict = defaultdict(lambda: deque(maxlen=1024)) | |
# losses = torch.zeros((size,)) | |
# previous_mean_losses = [0] | |
# previous_mean_loss = 0 | |
# print("Mean loss of {} elements".format(size)) | |
steps_without_grad = 0 | |
last_saved_file = "<none>" | |
last_saved_image = "<none>" | |
forced_filename = "<none>" | |
if hasattr(sd_hijack_checkpoint, 'add'): | |
sd_hijack_checkpoint.add() | |
pbar = tqdm.tqdm(total=steps - initial_step) | |
try: | |
for i in range((steps - initial_step) * gradient_step): | |
if scheduler.finished or hypernetwork.step > steps: | |
break | |
if shared.state.interrupted: | |
break | |
for j, batch in enumerate(dl): | |
# works as a drop_last=True for gradient accumulation | |
if j == max_steps_per_epoch: | |
break | |
if use_beta_scheduler: | |
scheduler_beta.step(hypernetwork.step) | |
else: | |
scheduler.apply(optimizer, hypernetwork.step) | |
if scheduler.finished: | |
break | |
if shared.state.interrupted: | |
break | |
with torch.autocast("cuda"): | |
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) | |
if use_weight: | |
w = batch.weight.to(devices.device, non_blocking=pin_memory) | |
if tag_drop_out != 0 or shuffle_tags: | |
shared.sd_model.cond_stage_model.to(devices.device) | |
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, | |
non_blocking=pin_memory) | |
shared.sd_model.cond_stage_model.to(devices.cpu) | |
else: | |
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) | |
if use_weight: | |
loss = shared.sd_model.weighted_forward(x, c, w)[0] | |
else: | |
_, losses = shared.sd_model.forward(x, c) | |
loss = losses['val/' + loss_opt] | |
for filenames in batch.filename: | |
loss_dict[filenames].append(loss.detach().item()) | |
loss /= gradient_step | |
assert not torch.isnan(loss), "Loss is NaN" | |
del x | |
del c | |
_loss_step += loss.item() | |
scaler.scale(loss).backward() | |
batch.latent_sample.to(devices.cpu) | |
# go back until we reach gradient accumulation steps | |
if (j + 1) % gradient_step != 0: | |
continue | |
gradient_clipping(weights) | |
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") | |
# scaler.unscale_(optimizer) | |
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") | |
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) | |
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") | |
try: | |
scaler.step(optimizer) | |
except AssertionError: | |
optimizer.param_groups[0]['capturable'] = True | |
scaler.step(optimizer) | |
scaler.update() | |
hypernetwork.step += 1 | |
pbar.update() | |
optimizer.zero_grad(set_to_none=True) | |
loss_step = _loss_step | |
_loss_step = 0 | |
steps_done = hypernetwork.step + 1 | |
epoch_num = hypernetwork.step // steps_per_epoch | |
epoch_step = hypernetwork.step % steps_per_epoch | |
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}" | |
pbar.set_description(description) | |
if hypernetwork_dir is not None and ( | |
(use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and save_when_converge) or ( | |
save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)): | |
# Before saving, change name to match current checkpoint. | |
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' | |
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') | |
hypernetwork.optimizer_name = optimizer_name | |
if shared.opts.save_optimizer_state: | |
hypernetwork.optimizer_state_dict = optimizer.state_dict() | |
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) | |
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. | |
write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, | |
{ | |
"loss": f"{loss_step:.7f}", | |
"learn_rate": get_lr_from_optimizer(optimizer) | |
}) | |
if shared.opts.training_enable_tensorboard: | |
epoch_num = hypernetwork.step // len(ds) | |
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 | |
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) | |
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, | |
learn_rate=scheduler.learn_rate if not use_beta_scheduler else | |
get_lr_from_optimizer(optimizer), epoch_num=epoch_num) | |
if images_dir is not None and ( | |
use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or ( | |
create_image_every > 0 and steps_done % create_image_every == 0): | |
set_scheduler(-1, False, False) | |
forced_filename = f'{hypernetwork_name}-{steps_done}' | |
last_saved_image = os.path.join(images_dir, forced_filename) | |
rng_state = torch.get_rng_state() | |
cuda_rng_state = None | |
if torch.cuda.is_available(): | |
cuda_rng_state = torch.cuda.get_rng_state_all() | |
hypernetwork.eval() | |
if move_optimizer: | |
optim_to(optimizer, devices.cpu) | |
gc.collect() | |
shared.sd_model.cond_stage_model.to(devices.device) | |
shared.sd_model.first_stage_model.to(devices.device) | |
p = processing.StableDiffusionProcessingTxt2Img( | |
sd_model=shared.sd_model, | |
do_not_save_grid=True, | |
do_not_save_samples=True, | |
) | |
if hasattr(p, 'disable_extra_networks'): | |
p.disable_extra_networks = True | |
is_patched = True | |
else: | |
is_patched = False | |
if preview_from_txt2img: | |
p.prompt = preview_prompt + (hypernetwork.extra_name() if not is_patched else "") | |
print(p.prompt) | |
p.negative_prompt = preview_negative_prompt | |
p.steps = preview_steps | |
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name | |
p.cfg_scale = preview_cfg_scale | |
p.seed = preview_seed | |
p.width = preview_width | |
p.height = preview_height | |
else: | |
p.prompt = batch.cond_text[0] + (hypernetwork.extra_name() if not is_patched else "") | |
p.steps = 20 | |
p.width = training_width | |
p.height = training_height | |
preview_text = p.prompt | |
processed = processing.process_images(p) | |
image = processed.images[0] if len(processed.images) > 0 else None | |
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: | |
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, | |
hypernetwork.step) | |
if unload: | |
shared.sd_model.cond_stage_model.to(devices.cpu) | |
shared.sd_model.first_stage_model.to(devices.cpu) | |
torch.set_rng_state(rng_state) | |
if torch.cuda.is_available(): | |
torch.cuda.set_rng_state_all(cuda_rng_state) | |
hypernetwork.train() | |
if move_optimizer: | |
optim_to(optimizer, devices.device) | |
if noise_training_scheduler_enabled: | |
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) | |
if image is not None: | |
if hasattr(shared.state, 'assign_current_image'): | |
shared.state.assign_current_image(image) | |
else: | |
shared.state.current_image = image | |
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, | |
shared.opts.samples_format, | |
processed.infotexts[0], p=p, | |
forced_filename=forced_filename, | |
save_to_dirs=False) | |
last_saved_image += f", prompt: {preview_text}" | |
set_accessible(hypernetwork) | |
shared.state.job_no = hypernetwork.step | |
shared.state.textinfo = f""" | |
<p> | |
Loss: {loss_step:.7f}<br/> | |
Step: {steps_done}<br/> | |
Last prompt: {html.escape(batch.cond_text[0])}<br/> | |
Last saved hypernetwork: {html.escape(last_saved_file)}<br/> | |
Last saved image: {html.escape(last_saved_image)}<br/> | |
</p> | |
""" | |
except Exception: | |
print(traceback.format_exc(), file=sys.stderr) | |
finally: | |
pbar.leave = False | |
pbar.close() | |
if hypernetwork is not None: | |
hypernetwork.eval() | |
shared.parallel_processing_allowed = old_parallel_processing_allowed | |
if hasattr(sd_hijack_checkpoint, 'remove'): | |
sd_hijack_checkpoint.remove() | |
set_scheduler(-1, False, False) | |
remove_accessible() | |
gc.collect() | |
torch.cuda.empty_cache() | |
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') | |
hypernetwork.optimizer_name = optimizer_name | |
if shared.opts.save_optimizer_state: | |
hypernetwork.optimizer_state_dict = optimizer.state_dict() | |
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) | |
del optimizer | |
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. | |
shared.sd_model.cond_stage_model.to(devices.device) | |
shared.sd_model.first_stage_model.to(devices.device) | |
return hypernetwork, filename | |
def internal_clean_training(hypernetwork_name, data_root, log_directory, | |
create_image_every, save_hypernetwork_every, | |
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, | |
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height, | |
move_optimizer=True, | |
load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1, | |
setting_tuple=None): | |
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. | |
from modules import images | |
base_hypernetwork_name = hypernetwork_name | |
manual_seed = int(manual_dataset_seed) | |
if setting_tuple is not None: | |
setting_suffix = f"_{setting_tuple[0]}_{setting_tuple[1]}" | |
else: | |
setting_suffix = time.strftime('%Y%m%d%H%M%S') | |
if load_hypernetworks_option != '': | |
dump_hyper: dict = get_training_option(load_hypernetworks_option) | |
hypernetwork_name = hypernetwork_name + setting_suffix | |
enable_sizes = dump_hyper['enable_sizes'] | |
overwrite_old = dump_hyper['overwrite_old'] | |
layer_structure = dump_hyper['layer_structure'] | |
activation_func = dump_hyper['activation_func'] | |
weight_init = dump_hyper['weight_init'] | |
add_layer_norm = dump_hyper['add_layer_norm'] | |
use_dropout = dump_hyper['use_dropout'] | |
dropout_structure = dump_hyper['dropout_structure'] | |
optional_info = dump_hyper['optional_info'] | |
weight_init_seed = dump_hyper['weight_init_seed'] | |
normal_std = dump_hyper['normal_std'] | |
skip_connection = dump_hyper['skip_connection'] | |
hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure, | |
activation_func, weight_init, add_layer_norm, use_dropout, | |
dropout_structure, optional_info, weight_init_seed, normal_std, | |
skip_connection) | |
else: | |
hypernetwork = load_hypernetwork(hypernetwork_name) | |
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix | |
hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt")) | |
shared.reload_hypernetworks() | |
hypernetwork = load_hypernetwork(hypernetwork_name) | |
if load_training_options != '': | |
dump: dict = get_training_option(load_training_options) | |
if dump and dump is not None: | |
learn_rate = dump['learn_rate'] | |
batch_size = dump['batch_size'] | |
gradient_step = dump['gradient_step'] | |
training_width = dump['training_width'] | |
training_height = dump['training_height'] | |
steps = dump['steps'] | |
shuffle_tags = dump['shuffle_tags'] | |
tag_drop_out = dump['tag_drop_out'] | |
save_when_converge = dump['save_when_converge'] | |
create_when_converge = dump['create_when_converge'] | |
latent_sampling_method = dump['latent_sampling_method'] | |
template_file = dump['template_file'] | |
use_beta_scheduler = dump['use_beta_scheduler'] | |
beta_repeat_epoch = dump['beta_repeat_epoch'] | |
epoch_mult = dump['epoch_mult'] | |
warmup = dump['warmup'] | |
min_lr = dump['min_lr'] | |
gamma_rate = dump['gamma_rate'] | |
use_adamw_parameter = dump['use_beta_adamW_checkbox'] | |
adamw_weight_decay = dump['adamw_weight_decay'] | |
adamw_beta_1 = dump['adamw_beta_1'] | |
adamw_beta_2 = dump['adamw_beta_2'] | |
adamw_eps = dump['adamw_eps'] | |
use_grad_opts = dump['show_gradient_clip_checkbox'] | |
gradient_clip_opt = dump['gradient_clip_opt'] | |
optional_gradient_clip_value = dump['optional_gradient_clip_value'] | |
optional_gradient_norm_type = dump['optional_gradient_norm_type'] | |
latent_sampling_std = dump.get('latent_sampling_std', -1) | |
noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False) | |
noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False) | |
noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128) | |
loss_opt = dump.get('loss_opt', 'loss_simple') | |
use_dadaptation = dump.get('use_dadaptation', False) | |
dadapt_growth_factor = dump.get('dadapt_growth_factor', -1) | |
use_weight = dump.get('use_weight', False) | |
else: | |
raise RuntimeError(f"Cannot load from {load_training_options}!") | |
else: | |
raise RuntimeError(f"Cannot load from {load_training_options}!") | |
try: | |
if use_adamw_parameter: | |
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in | |
[adamw_weight_decay, adamw_beta_1, | |
adamw_beta_2, adamw_eps]] | |
assert 0 <= adamw_weight_decay, "Weight decay paramter should be larger or equal than zero!" | |
assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, | |
adamw_eps])), "Cannot use negative or >1 number for adamW parameters!" | |
adamW_kwarg_dict = { | |
'weight_decay': adamw_weight_decay, | |
'betas': (adamw_beta_1, adamw_beta_2), | |
'eps': adamw_eps | |
} | |
print('Using custom AdamW parameters') | |
else: | |
adamW_kwarg_dict = { | |
'weight_decay': 0.01, | |
'betas': (0.9, 0.99), | |
'eps': 1e-8 | |
} | |
if use_beta_scheduler: | |
print("Using Beta Scheduler") | |
beta_repeat_epoch = int(float(beta_repeat_epoch)) | |
assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!" | |
min_lr = float(min_lr) | |
assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!" | |
gamma_rate = float(gamma_rate) | |
print(f"Using learn rate decay(per cycle) of {gamma_rate}") | |
assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!" | |
epoch_mult = float(epoch_mult) | |
assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!" | |
warmup = int(float(warmup)) | |
assert warmup >= 1, "Warmup epoch should be larger than 0!" | |
print(f"Save when converges : {save_when_converge}") | |
print(f"Generate image when converges : {create_when_converge}") | |
else: | |
beta_repeat_epoch = 4000 | |
epoch_mult = 1 | |
warmup = 10 | |
min_lr = 1e-7 | |
gamma_rate = 1 | |
save_when_converge = False | |
create_when_converge = False | |
except ValueError: | |
raise RuntimeError("Cannot use advanced LR scheduler settings!") | |
if use_grad_opts and gradient_clip_opt != "None": | |
try: | |
optional_gradient_clip_value = float(optional_gradient_clip_value) | |
except ValueError: | |
raise RuntimeError(f"Cannot convert invalid gradient clipping value {optional_gradient_clip_value})") | |
if gradient_clip_opt == "Norm": | |
try: | |
grad_norm = int(float(optional_gradient_norm_type)) | |
except ValueError: | |
raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type})") | |
assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}" | |
print( | |
f"Using gradient clipping by Norm, norm type {optional_gradient_norm_type}, norm limit {optional_gradient_clip_value}") | |
def gradient_clipping(arg1): | |
torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, optional_gradient_norm_type) | |
return | |
else: | |
print(f"Using gradient clipping by Value, limit {optional_gradient_clip_value}") | |
def gradient_clipping(arg1): | |
torch.nn.utils.clip_grad_value_(arg1, optional_gradient_clip_value) | |
return | |
else: | |
def gradient_clipping(arg1): | |
return | |
if noise_training_scheduler_enabled: | |
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) | |
print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!") | |
else: | |
set_scheduler(-1, False, False) | |
save_hypernetwork_every = save_hypernetwork_every or 0 | |
create_image_every = create_image_every or 0 | |
if not os.path.isfile(template_file): | |
template_file = textual_inversion.textual_inversion_templates.get(template_file, None) | |
if template_file is not None: | |
template_file = template_file.path | |
else: | |
raise AssertionError(f"Cannot find {template_file}!") | |
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") | |
hypernetwork.to(devices.device) | |
assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!" | |
if not isinstance(hypernetwork, Hypernetwork): | |
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!") | |
set_accessible(hypernetwork) | |
shared.state.job = "train-hypernetwork" | |
shared.state.textinfo = "Initializing hypernetwork training..." | |
shared.state.job_count = steps | |
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] | |
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') | |
base_log_directory = log_directory | |
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) | |
unload = shared.opts.unload_models_when_training | |
if save_hypernetwork_every > 0 or save_when_converge: | |
hypernetwork_dir = os.path.join(log_directory, "hypernetworks") | |
os.makedirs(hypernetwork_dir, exist_ok=True) | |
else: | |
hypernetwork_dir = None | |
if create_image_every > 0 or create_when_converge: | |
images_dir = os.path.join(log_directory, "images") | |
os.makedirs(images_dir, exist_ok=True) | |
else: | |
images_dir = None | |
checkpoint = sd_models.select_checkpoint() | |
initial_step = hypernetwork.step or 0 | |
if initial_step >= steps: | |
shared.state.textinfo = f"Model has already been trained beyond specified max steps" | |
return hypernetwork, filename | |
scheduler = LearnRateScheduler(learn_rate, steps, initial_step) | |
if shared.opts.training_enable_tensorboard: | |
print("Tensorboard logging enabled") | |
tensorboard_writer = tensorboard_setup(os.path.join(base_log_directory, base_hypernetwork_name)) | |
else: | |
tensorboard_writer = None | |
# dataset loading may take a while, so input validations and early returns should be done before this | |
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." | |
detach_grad = shared.opts.disable_ema # test code that removes EMA | |
if detach_grad: | |
print("Disabling training for staged models!") | |
shared.sd_model.cond_stage_model.requires_grad_(False) | |
shared.sd_model.first_stage_model.requires_grad_(False) | |
torch.cuda.empty_cache() | |
pin_memory = shared.opts.pin_memory | |
ds = PersonalizedBase(data_root=data_root, width=training_width, | |
height=training_height, | |
repeats=shared.opts.training_image_repeats_per_epoch, | |
placeholder_token=hypernetwork_name, model=shared.sd_model, | |
cond_model=shared.sd_model.cond_stage_model, | |
device=devices.device, template_file=template_file, | |
include_cond=True, batch_size=batch_size, | |
gradient_step=gradient_step, shuffle_tags=shuffle_tags, | |
tag_drop_out=tag_drop_out, | |
latent_sampling_method=latent_sampling_method, | |
latent_sampling_std=latent_sampling_std, | |
manual_seed=manual_seed, | |
use_weight=use_weight) | |
latent_sampling_method = ds.latent_sampling_method | |
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, | |
batch_size=ds.batch_size, pin_memory=pin_memory) | |
old_parallel_processing_allowed = shared.parallel_processing_allowed | |
if unload: | |
shared.parallel_processing_allowed = False | |
shared.sd_model.cond_stage_model.to(devices.cpu) | |
shared.sd_model.first_stage_model.to(devices.cpu) | |
weights = hypernetwork.weights(True) | |
optimizer_name = hypernetwork.optimizer_name | |
if hypernetwork.optimizer_name == 'DAdaptAdamW': | |
use_dadaptation = True | |
optimizer = None | |
# Here we use optimizer from saved HN, or we can specify as UI option. | |
if hypernetwork.optimizer_name in optimizer_dict: | |
if use_adamw_parameter: | |
if hypernetwork.optimizer_name != 'AdamW' and hypernetwork.optimizer_name != 'DAdaptAdamW': | |
raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!") | |
if use_dadaptation: | |
from .dadapt_test.install import get_dadapt_adam | |
optim_class = get_dadapt_adam(hypernetwork.optimizer_name) | |
if optim_class != torch.optim.AdamW: | |
optimizer = optim_class(params=weights, lr=scheduler.learn_rate, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, decouple=True, **adamW_kwarg_dict) | |
else: | |
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict) | |
else: | |
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict) | |
else: | |
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) | |
optimizer_name = hypernetwork.optimizer_name | |
else: | |
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") | |
if use_dadaptation: | |
from .dadapt_test.install import get_dadapt_adam | |
optim_class = get_dadapt_adam(hypernetwork.optimizer_name) | |
if optim_class != torch.optim.AdamW: | |
optimizer = optim_class(params=weights, lr=scheduler.learn_rate, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, decouple=True, **adamW_kwarg_dict) | |
optimizer_name = 'DAdaptAdamW' | |
if optimizer is None: | |
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict) | |
optimizer_name = 'AdamW' | |
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. | |
try: | |
optimizer.load_state_dict(hypernetwork.optimizer_state_dict) | |
except RuntimeError as e: | |
print("Cannot resume from saved optimizer!") | |
print(e) | |
optim_to(optimizer, devices.device) | |
if use_beta_scheduler: | |
scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, | |
cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, | |
warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate) | |
scheduler_beta.last_epoch = hypernetwork.step - 1 | |
else: | |
scheduler_beta = None | |
for pg in optimizer.param_groups: | |
pg['lr'] = scheduler.learn_rate | |
scaler = torch.cuda.amp.GradScaler() | |
batch_size = ds.batch_size | |
gradient_step = ds.gradient_step | |
# n steps = batch_size * gradient_step * n image processed | |
steps_per_epoch = len(ds) // batch_size // gradient_step | |
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step | |
loss_step = 0 | |
_loss_step = 0 # internal | |
# size = len(ds.indexes) | |
loss_dict = defaultdict(lambda: deque(maxlen=1024)) | |
# losses = torch.zeros((size,)) | |
# previous_mean_losses = [0] | |
# previous_mean_loss = 0 | |
# print("Mean loss of {} elements".format(size)) | |
steps_without_grad = 0 | |
last_saved_file = "<none>" | |
last_saved_image = "<none>" | |
forced_filename = "<none>" | |
if hasattr(sd_hijack_checkpoint, 'add'): | |
sd_hijack_checkpoint.add() | |
pbar = tqdm.tqdm(total=steps - initial_step) | |
try: | |
for i in range((steps - initial_step) * gradient_step): | |
if scheduler.finished or hypernetwork.step > steps: | |
break | |
if shared.state.interrupted: | |
break | |
for j, batch in enumerate(dl): | |
# works as a drop_last=True for gradient accumulation | |
if j == max_steps_per_epoch: | |
break | |
if use_beta_scheduler: | |
scheduler_beta.step(hypernetwork.step) | |
else: | |
scheduler.apply(optimizer, hypernetwork.step) | |
if scheduler.finished: | |
break | |
if shared.state.interrupted: | |
break | |
with torch.autocast("cuda"): | |
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) | |
if use_weight: | |
w = batch.weight.to(devices.device, non_blocking=pin_memory) | |
if tag_drop_out != 0 or shuffle_tags: | |
shared.sd_model.cond_stage_model.to(devices.device) | |
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, | |
non_blocking=pin_memory) | |
shared.sd_model.cond_stage_model.to(devices.cpu) | |
else: | |
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) | |
if use_weight: | |
loss = shared.sd_model.weighted_forward(x, c, w)[0] | |
else: | |
_, losses = shared.sd_model.forward(x, c) | |
loss = losses['val/' + loss_opt] | |
for filenames in batch.filename: | |
loss_dict[filenames].append(loss.detach().item()) | |
loss /= gradient_step | |
del x | |
del c | |
_loss_step += loss.item() | |
scaler.scale(loss).backward() | |
batch.latent_sample.to(devices.cpu) | |
# go back until we reach gradient accumulation steps | |
if (j + 1) % gradient_step != 0: | |
continue | |
gradient_clipping(weights) | |
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") | |
# scaler.unscale_(optimizer) | |
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") | |
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) | |
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") | |
try: | |
scaler.step(optimizer) | |
except AssertionError: | |
optimizer.param_groups[0]['capturable'] = True | |
scaler.step(optimizer) | |
scaler.update() | |
hypernetwork.step += 1 | |
pbar.update() | |
optimizer.zero_grad(set_to_none=True) | |
loss_step = _loss_step | |
_loss_step = 0 | |
steps_done = hypernetwork.step + 1 | |
epoch_num = hypernetwork.step // steps_per_epoch | |
epoch_step = hypernetwork.step % steps_per_epoch | |
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}" | |
pbar.set_description(description) | |
if hypernetwork_dir is not None and ( | |
(use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and save_when_converge) or ( | |
save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)): | |
# Before saving, change name to match current checkpoint. | |
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' | |
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') | |
hypernetwork.optimizer_name = optimizer_name | |
if shared.opts.save_optimizer_state: | |
hypernetwork.optimizer_state_dict = optimizer.state_dict() | |
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) | |
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. | |
write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, | |
{ | |
"loss": f"{loss_step:.7f}", | |
"learn_rate": get_lr_from_optimizer(optimizer) | |
}) | |
if shared.opts.training_enable_tensorboard: | |
epoch_num = hypernetwork.step // len(ds) | |
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 | |
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) | |
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, | |
learn_rate=scheduler.learn_rate if not use_beta_scheduler else | |
get_lr_from_optimizer(optimizer), epoch_num=epoch_num, base_name=hypernetwork_name) | |
if images_dir is not None and ( | |
use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or ( | |
create_image_every > 0 and steps_done % create_image_every == 0): | |
set_scheduler(-1, False, False) | |
forced_filename = f'{hypernetwork_name}-{steps_done}' | |
last_saved_image = os.path.join(images_dir, forced_filename) | |
rng_state = torch.get_rng_state() | |
cuda_rng_state = None | |
if torch.cuda.is_available(): | |
cuda_rng_state = torch.cuda.get_rng_state_all() | |
hypernetwork.eval() | |
if move_optimizer: | |
optim_to(optimizer, devices.cpu) | |
shared.sd_model.cond_stage_model.to(devices.device) | |
shared.sd_model.first_stage_model.to(devices.device) | |
p = processing.StableDiffusionProcessingTxt2Img( | |
sd_model=shared.sd_model, | |
do_not_save_grid=True, | |
do_not_save_samples=True, | |
) | |
if hasattr(p, 'disable_extra_networks'): | |
p.disable_extra_networks = True | |
is_patched = True | |
else: | |
is_patched = False | |
if preview_from_txt2img: | |
p.prompt = preview_prompt + (hypernetwork.extra_name() if not is_patched else "") | |
p.negative_prompt = preview_negative_prompt | |
p.steps = preview_steps | |
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name | |
p.cfg_scale = preview_cfg_scale | |
p.seed = preview_seed | |
p.width = preview_width | |
p.height = preview_height | |
else: | |
p.prompt = batch.cond_text[0] + (hypernetwork.extra_name() if not is_patched else "") | |
p.steps = 20 | |
p.width = training_width | |
p.height = training_height | |
preview_text = p.prompt | |
processed = processing.process_images(p) | |
image = processed.images[0] if len(processed.images) > 0 else None | |
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: | |
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, | |
hypernetwork.step, base_name=hypernetwork_name) | |
if unload: | |
shared.sd_model.cond_stage_model.to(devices.cpu) | |
shared.sd_model.first_stage_model.to(devices.cpu) | |
torch.set_rng_state(rng_state) | |
if torch.cuda.is_available(): | |
torch.cuda.set_rng_state_all(cuda_rng_state) | |
hypernetwork.train() | |
if move_optimizer: | |
optim_to(optimizer, devices.device) | |
if noise_training_scheduler_enabled: | |
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) | |
if image is not None: | |
if hasattr(shared.state, 'assign_current_image'): | |
shared.state.assign_current_image(image) | |
else: | |
shared.state.current_image = image | |
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, | |
shared.opts.samples_format, | |
processed.infotexts[0], p=p, | |
forced_filename=forced_filename, | |
save_to_dirs=False) | |
last_saved_image += f", prompt: {preview_text}" | |
set_accessible(hypernetwork) | |
shared.state.job_no = hypernetwork.step | |
shared.state.textinfo = f""" | |
<p> | |
Loss: {loss_step:.7f}<br/> | |
Step: {steps_done}<br/> | |
Last prompt: {html.escape(batch.cond_text[0])}<br/> | |
Last saved hypernetwork: {html.escape(last_saved_file)}<br/> | |
Last saved image: {html.escape(last_saved_image)}<br/> | |
</p> | |
""" | |
except Exception: | |
if pbar is not None: | |
pbar.set_description(traceback.format_exc()) | |
shared.state.textinfo = traceback.format_exc() | |
print(traceback.format_exc(), file=sys.stderr) | |
finally: | |
pbar.leave = False | |
pbar.close() | |
hypernetwork.eval() | |
set_scheduler(-1, False, False) | |
shared.parallel_processing_allowed = old_parallel_processing_allowed | |
remove_accessible() | |
if hasattr(sd_hijack_checkpoint, 'remove'): | |
sd_hijack_checkpoint.remove() | |
if shared.opts.training_enable_tensorboard: | |
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) if sum(len(x) for x in loss_dict.values()) > 0 else 0 | |
tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate, | |
GA_steps=gradient_step, | |
batch_size=batch_size, | |
layer_structure=hypernetwork.layer_structure, | |
activation=hypernetwork.activation_func, | |
weight_init=hypernetwork.weight_init, | |
dropout_structure=hypernetwork.dropout_structure, | |
max_steps=steps, | |
latent_sampling_method=latent_sampling_method, | |
template=template_file, | |
CosineAnnealing=use_beta_scheduler, | |
beta_repeat_epoch=beta_repeat_epoch, | |
epoch_mult=epoch_mult, | |
warmup=warmup, | |
min_lr=min_lr, | |
gamma_rate=gamma_rate, | |
adamW_opts=use_adamw_parameter, | |
adamW_decay=adamw_weight_decay, | |
adamW_beta_1=adamw_beta_1, | |
adamW_beta_2=adamw_beta_2, | |
adamW_eps=adamw_eps, | |
gradient_clip=gradient_clip_opt, | |
gradient_clip_value=optional_gradient_clip_value, | |
gradient_clip_norm_type=optional_gradient_norm_type, | |
loss=mean_loss, | |
base_hypernetwork_name=hypernetwork_name | |
) | |
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') | |
hypernetwork.optimizer_name = optimizer_name | |
if shared.opts.save_optimizer_state: | |
hypernetwork.optimizer_state_dict = optimizer.state_dict() | |
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) | |
del optimizer | |
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. | |
shared.sd_model.cond_stage_model.to(devices.device) | |
shared.sd_model.first_stage_model.to(devices.device) | |
gc.collect() | |
torch.cuda.empty_cache() | |
return hypernetwork, filename | |
def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directory, | |
create_image_every, save_hypernetwork_every, preview_from_txt2img, preview_prompt, | |
preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, | |
preview_seed, | |
preview_width, preview_height, | |
move_optimizer=True, | |
optional_new_hypernetwork_name='', load_hypernetworks_options='', | |
load_training_options='', manual_dataset_seed=-1): | |
load_hypernetworks_options = load_hypernetworks_options.split(',') | |
load_training_options = load_training_options.split(',') | |
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. | |
for _i, load_hypernetworks_option in enumerate(load_hypernetworks_options): | |
load_hypernetworks_option = load_hypernetworks_option.strip(' ') | |
if load_hypernetworks_option != '' and get_training_option(load_hypernetworks_option) is False: | |
print(f"Cannot load from {load_hypernetworks_option}!") | |
continue | |
for _j, load_training_option in enumerate(load_training_options): | |
load_training_option = load_training_option.strip(' ') | |
if get_training_option(load_training_option) is False: | |
print(f"Cannot load from {load_training_option}!") | |
continue | |
internal_clean_training( | |
hypernetwork_name if load_hypernetworks_option == '' else optional_new_hypernetwork_name, | |
data_root, | |
log_directory, | |
create_image_every, | |
save_hypernetwork_every, | |
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, | |
preview_cfg_scale, preview_seed, preview_width, preview_height, | |
move_optimizer, | |
load_hypernetworks_option, load_training_option, manual_dataset_seed, setting_tuple=(_i, _j)) | |
if shared.state.interrupted: | |
return None, None | |