|
from tkinter import filedialog, Tk |
|
from easygui import msgbox |
|
import os |
|
import re |
|
import gradio as gr |
|
import easygui |
|
import shutil |
|
import sys |
|
import json |
|
|
|
from library.custom_logging import setup_logging |
|
from datetime import datetime |
|
|
|
|
|
# Module-wide logger, configured by the project's custom_logging helper.
log = setup_logging()

# Emoji glyphs; presumably used as compact button labels elsewhere in the GUI.
folder_symbol = '\U0001f4c2'  # "open file folder" emoji
refresh_symbol = '\U0001f504'  # "anticlockwise arrows" emoji
save_style_symbol = '\U0001f4be'  # "floppy disk" emoji
document_symbol = '\U0001F4C4'  # "page facing up" emoji

# Stable Diffusion 2.x base presets: set_pretrained_model_name_or_path_input
# sets v2=True and v_parameterization=False when one of these is selected.
V2_BASE_MODELS = [
    'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
    'stabilityai/stable-diffusion-2-1-base',
    'stabilityai/stable-diffusion-2-base',
]

# Stable Diffusion 2.x presets that need v-parameterization: selecting one
# sets both v2=True and v_parameterization=True.
V_PARAMETERIZATION_MODELS = [
    'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
    'stabilityai/stable-diffusion-2-1',
    'stabilityai/stable-diffusion-2',
]

# Stable Diffusion 1.x presets: v2, v_parameterization and sdxl are all off.
V1_MODELS = [
    'CompVis/stable-diffusion-v1-4',
    'runwayml/stable-diffusion-v1-5',
]

# SDXL presets: selecting one sets sdxl=True (v2/v_parameterization off).
SDXL_MODELS = [
    'stabilityai/stable-diffusion-xl-base-0.9',
    'stabilityai/stable-diffusion-xl-refiner-0.9'
]

# Every known preset; update_my_data treats any other model path as 'custom'.
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDXL_MODELS

# If any of these environment variables is set, the Tk file/folder dialogs in
# this module are skipped and the current path is returned unchanged.
# NOTE(review): presumably these mark hosted environments (Google Colab,
# RunPod) that have no local display — confirm.
ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
|
|
|
|
def check_if_model_exist(
    output_name, output_dir, save_model_as, headless=False
):
    """Ask before overwriting an existing model with the same output name.

    Args:
        output_name: Base name of the model that is about to be written.
        output_dir: Directory the model is written to.
        save_model_as: Save format. 'diffusers' / 'diffusers_safetensors'
            are checked as a folder ``output_dir/output_name``; 'ckpt' /
            'safetensors' as a file ``output_dir/output_name.<ext>``. Any
            other value (e.g. "same as source model") cannot be checked.
        headless: When True no dialog can be shown, so the check is skipped.

    Returns:
        True if a model already exists and the user declined to overwrite
        it (the caller should abort training), False otherwise.
    """
    if headless:
        log.info(
            'Headless mode, skipping verification if model already exist... if model already exist it will be overwritten...'
        )
        return False

    # The correct option name is 'diffusers_safetensors'. The misspelled
    # 'diffusers_safetendors' used to be the only spelling tested here; it is
    # kept so any caller still passing it behaves exactly as before.
    if save_model_as in (
        'diffusers',
        'diffusers_safetensors',
        'diffusers_safetendors',
    ):
        ckpt_folder = os.path.join(output_dir, output_name)
        if os.path.isdir(ckpt_folder):
            msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                log.info(
                    'Aborting training due to existing model with same name...'
                )
                return True
    elif save_model_as in ['ckpt', 'safetensors']:
        ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
        if os.path.isfile(ckpt_file):
            msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                log.info(
                    'Aborting training due to existing model with same name...'
                )
                return True
    else:
        log.info(
            'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...'
        )
        return False

    return False
|
|
|
|
|
def output_message(msg='', title='', headless=False):
    """Show msg to the user: a message box in GUI mode, a log line when headless.

    Args:
        msg: Text to display.
        title: Message box title (ignored in headless mode).
        headless: When True, only log the message.
    """
    if not headless:
        msgbox(msg=msg, title=title)
        return
    log.info(msg)
|
|
|
|
|
def update_my_data(my_data):
    """Migrate a loaded configuration dict to the current GUI schema, in place.

    - Derives 'optimizer' from the legacy 'use_8bit_adam' flag if missing.
    - Sets 'model_list' to 'custom' when it is empty or when the stored
      'pretrained_model_name_or_path' is not a known preset.
    - Coerces integer and float fields stored as strings; missing or falsy
      values become 0. Unparsable text is left untouched.
    - Renames the legacy 'LoCon' LoRA type to 'LyCORIS/LoCon'.
    - Forces 'save_model_as' to 'safetensors' for LoRA/TI configs that use
      a format other than safetensors/ckpt.

    Args:
        my_data: Configuration dictionary (mutated).

    Returns:
        The same dictionary.
    """
    # Older configs only stored the boolean use_8bit_adam flag.
    use_8bit_adam = my_data.get('use_8bit_adam', False)
    my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')

    model_list = my_data.get('model_list', [])
    pretrained_model_name_or_path = my_data.get(
        'pretrained_model_name_or_path', ''
    )
    if (
        not model_list
        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
    ):
        my_data['model_list'] = 'custom'

    # Integer-valued fields.
    for key in ['epoch', 'save_every_n_epochs', 'lr_warmup']:
        value = my_data.get(key, 0)
        if isinstance(value, str) and value.strip().isdigit():
            my_data[key] = int(value)
        elif not value:
            my_data[key] = 0

    # Float-valued fields. Values like '0.0001' or '1e-4' must go through
    # float(); str.isdigit() rejects any string containing '.', '-' or 'e'.
    for key in ['noise_offset', 'learning_rate', 'text_encoder_lr', 'unet_lr']:
        value = my_data.get(key, 0)
        if isinstance(value, str) and value.strip():
            try:
                my_data[key] = float(value)
            except ValueError:
                # Unparsable text is deliberately left as-is (previous behavior).
                pass
        elif not value:
            my_data[key] = 0

    # The 'LoCon' LoRA type was renamed.
    if my_data.get('LoRA_type', 'Standard') == 'LoCon':
        my_data['LoRA_type'] = 'LyCORIS/LoCon'

    # A LoRA_type key marks a LoRA config and num_vectors_per_token a TI
    # config (see the log messages); both only support safetensors/ckpt.
    if 'save_model_as' in my_data:
        if (
            my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')
        ) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
            message = 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
            if my_data.get('LoRA_type'):
                log.info(message.format('LoRA'))
            if my_data.get('num_vectors_per_token'):
                log.info(message.format('TI'))
            my_data['save_model_as'] = 'safetensors'

    return my_data
|
|
|
|
|
def get_dir_and_file(file_path):
    """Return ``(directory, file_name)`` for file_path, as split by os.path.split.

    A path ending in a separator yields an empty file name; a bare file
    name yields an empty directory.
    """
    return os.path.split(file_path)
|
|
|
|
|
def get_file_path(
    file_path='', default_extension='.json', extension_name='Config files'
):
    """Ask the user to pick an existing file with a native Tk dialog.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    ``ENV_EXCLUSION`` environment variable is set or on macOS.

    Args:
        file_path: Current path. Its directory and file name pre-seed the
            dialog, and it is returned if the dialog is cancelled.
        default_extension: Extension for the primary file-type filter.
        extension_name: Label shown for that filter.

    Returns:
        The selected path, or ``file_path`` if cancelled or skipped.
    """
    if (
        any(var in os.environ for var in ENV_EXCLUSION)
        or sys.platform == 'darwin'
    ):
        return file_path

    current_file_path = file_path
    initial_dir, initial_file = get_dir_and_file(file_path)

    # A hidden, always-on-top root window keeps the dialog in front of the GUI.
    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        file_path = filedialog.askopenfilename(
            filetypes=(
                (extension_name, f'*{default_extension}'),
                ('All files', '*.*'),
            ),
            defaultextension=default_extension,
            initialfile=initial_file,
            initialdir=initial_dir,
        )
    finally:
        # Tear the root down even if the dialog raises.
        root.destroy()

    # Cancelling yields '' (and can yield an empty tuple on some Tk builds).
    if not file_path:
        file_path = current_file_path

    return file_path
|
|
|
|
|
def get_any_file_path(file_path=''):
    """Ask the user to pick any existing file with a native Tk dialog.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    ``ENV_EXCLUSION`` environment variable is set or on macOS.

    Args:
        file_path: Current path. Its directory and file name pre-seed the
            dialog, and it is returned if the dialog is cancelled.

    Returns:
        The selected path, or ``file_path`` if cancelled or skipped.
    """
    if (
        any(var in os.environ for var in ENV_EXCLUSION)
        or sys.platform == 'darwin'
    ):
        return file_path

    current_file_path = file_path
    initial_dir, initial_file = get_dir_and_file(file_path)

    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        file_path = filedialog.askopenfilename(
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        root.destroy()

    # Cancelling yields '' but can yield an empty tuple on some Tk builds;
    # an `== ''` check would hand that tuple back to the caller.
    if not file_path:
        file_path = current_file_path

    return file_path
|
|
|
|
|
def remove_doublequote(file_path):
    """Remove every double-quote character from file_path.

    Presumably meant for paths pasted with surrounding quotes. ``None`` is
    passed through unchanged.
    """
    if file_path is not None:
        file_path = file_path.replace('"', '')

    return file_path
|
|
|
|
|
def get_folder_path(folder_path=''):
    """Ask the user to pick a folder with a native Tk dialog.

    The dialog is skipped, and ``folder_path`` returned unchanged, when any
    ``ENV_EXCLUSION`` environment variable is set or on macOS.

    Args:
        folder_path: Current folder. The dialog opens in its parent directory
            (the directory part of os.path.split), and it is returned if the
            dialog is cancelled.

    Returns:
        The selected folder, or ``folder_path`` if cancelled or skipped.
    """
    if (
        any(var in os.environ for var in ENV_EXCLUSION)
        or sys.platform == 'darwin'
    ):
        return folder_path

    current_folder_path = folder_path
    initial_dir, _ = get_dir_and_file(folder_path)

    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        folder_path = filedialog.askdirectory(initialdir=initial_dir)
    finally:
        root.destroy()

    # Cancelling yields '' (and can yield an empty tuple on some Tk builds).
    if not folder_path:
        folder_path = current_folder_path

    return folder_path
|
|
|
|
|
def get_saveasfile_path(
    file_path='', defaultextension='.json', extension_name='Config files'
):
    """Ask the user for a "Save as" destination with a native Tk dialog.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    ``ENV_EXCLUSION`` environment variable is set or on macOS.

    Args:
        file_path: Current path. Its directory and file name pre-seed the
            dialog, and it is returned if the dialog is cancelled.
        defaultextension: Extension appended when the user omits one; also
            used as the primary file-type filter.
        extension_name: Label shown for that filter.

    Returns:
        The chosen destination path, or ``file_path`` if cancelled or skipped.
    """
    if (
        any(var in os.environ for var in ENV_EXCLUSION)
        or sys.platform == 'darwin'
    ):
        return file_path

    current_file_path = file_path
    initial_dir, initial_file = get_dir_and_file(file_path)

    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        # asksaveasfilename only returns the chosen path. asksaveasfile would
        # open (and truncate) the file in 'w' mode and return a handle that
        # nothing here ever closes.
        save_file_path = filedialog.asksaveasfilename(
            filetypes=(
                (f'{extension_name}', f'{defaultextension}'),
                ('All files', '*'),
            ),
            defaultextension=defaultextension,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        root.destroy()

    # Cancelling yields '' (and can yield an empty tuple on some Tk builds).
    if not save_file_path:
        file_path = current_file_path
    else:
        log.info(save_file_path)
        file_path = save_file_path

    return file_path
|
|
|
|
|
def get_saveasfilename_path(
    file_path='', extensions='*', extension_name='Config files'
):
    """Ask the user for a "Save as" file name with a native Tk dialog.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    ``ENV_EXCLUSION`` environment variable is set or on macOS. No file is
    opened or created.

    Args:
        file_path: Current path. Its directory and file name pre-seed the
            dialog, and it is returned if the dialog is cancelled.
        extensions: File-type pattern for the primary filter, also passed as
            ``defaultextension``.
        extension_name: Label shown for that filter.

    Returns:
        The chosen path, or ``file_path`` if cancelled or skipped.
    """
    if (
        any(var in os.environ for var in ENV_EXCLUSION)
        or sys.platform == 'darwin'
    ):
        return file_path

    current_file_path = file_path
    initial_dir, initial_file = get_dir_and_file(file_path)

    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()
        save_file_path = filedialog.asksaveasfilename(
            filetypes=(
                (f'{extension_name}', f'{extensions}'),
                ('All files', '*'),
            ),
            defaultextension=extensions,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        root.destroy()

    # Cancelling yields '' (and can yield an empty tuple on some Tk builds).
    if not save_file_path:
        file_path = current_file_path
    else:
        file_path = save_file_path

    return file_path
|
|
|
|
|
def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption',
) -> None:
    """
    Add a prefix and/or postfix to the caption of every image in a folder.

    For each .jpg/.jpeg/.png/.webp image (extension matched case-insensitively)
    in ``folder``, the caption file ``<image stem><caption_file_ext>`` is
    rewritten as ``prefix content postfix``. The existing content is
    right-stripped, and a space separates each non-empty affix from it.
    Images that have no caption file get a new one containing only the
    prefix and/or postfix. Nothing happens when both affixes are empty.

    Args:
        folder (str): Path to the folder containing the images and captions.
        prefix (str, optional): Text to put in front of each caption.
        postfix (str, optional): Text to append to each caption.
        caption_file_ext (str, optional): Extension of the caption files.
    """
    if prefix == '' and postfix == '':
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    prefix_separator = ' ' if prefix else ''
    postfix_separator = ' ' if postfix else ''

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
        caption_file_path = os.path.join(folder, caption_file_name)

        if not os.path.exists(caption_file_path):
            separator = ' ' if prefix and postfix else ''
            new_content = f'{prefix}{separator}{postfix}'
        else:
            with open(caption_file_path, 'r', encoding='utf8') as f:
                content = f.read().rstrip()
            new_content = (
                f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
            )

        # Mode 'w' truncates first. The previous seek(0)+write in 'r+' mode
        # left the tail of a longer original caption behind the new text.
        with open(caption_file_path, 'w', encoding='utf8') as f:
            f.write(new_content)
|
|
|
|
|
def has_ext_files(folder_path: str, file_extension: str) -> bool:
    """
    Tell whether folder_path directly contains a file name ending in file_extension.

    Args:
        folder_path (str): Folder to scan (not recursive).
        file_extension (str): Suffix to look for; compared case-sensitively.

    Returns:
        bool: True as soon as one matching name is found, otherwise False.
    """
    entries = os.listdir(folder_path)
    return any(entry.endswith(file_extension) for entry in entries)
|
|
|
|
|
def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
) -> None:
    """
    Find and replace text in caption files within a folder.

    Shows a message box and returns if the folder has no file with
    ``caption_file_ext``. Does nothing when ``search_text`` is empty.
    Only files whose content actually changes are rewritten.

    Args:
        folder_path (str, optional): Path to the folder containing caption files.
        caption_file_ext (str, optional): Extension of the caption files.
        search_text (str, optional): Text to search for in the caption files.
        replace_text (str, optional): Text to replace the search text with.
    """
    log.info('Running caption find/replace')

    if not has_ext_files(folder_path, caption_file_ext):
        msgbox(
            f'No files with extension {caption_file_ext} were found in {folder_path}...'
        )
        return

    if search_text == '':
        return

    caption_files = [
        f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
    ]

    for caption_file in caption_files:
        caption_path = os.path.join(folder_path, caption_file)

        # Read and write with the same explicit encoding. 'surrogateescape'
        # round-trips bytes that are not valid UTF-8 instead of silently
        # dropping them. The previous errors='ignore' + platform-default
        # encoding could permanently corrupt non-ASCII captions on rewrite.
        with open(
            caption_path, 'r', encoding='utf8', errors='surrogateescape'
        ) as f:
            content = f.read()

        new_content = content.replace(search_text, replace_text)
        if new_content == content:
            continue  # no match: leave the file untouched

        with open(
            caption_path, 'w', encoding='utf8', errors='surrogateescape'
        ) as f:
            f.write(new_content)
|
|
|
|
|
def color_aug_changed(color_aug):
    """Return a gradio Checkbox update reacting to the color-augmentation toggle.

    When color augmentation is on, the user is told that "Cache latent" is
    being disabled, and the returned update unchecks and locks the checkbox
    (presumably the cache-latents one, per the message). When it is off, the
    checkbox is re-checked and unlocked.
    """
    if not color_aug:
        return gr.Checkbox.update(value=True, interactive=True)

    msgbox(
        'Disabling "Cache latent" because "Color augmentation" has been selected...'
    )
    return gr.Checkbox.update(value=False, interactive=False)
|
|
|
|
|
def save_inference_file(output_dir, v2, v_parameterization, output_name):
    """Copy a v2 inference YAML next to every saved model file.

    Every regular file in output_dir whose name starts with output_name gets
    a sibling ``<stem>.yaml``. The template is ./v2_inference/v2-inference-v.yaml
    when both v2 and v_parameterization are set, and
    ./v2_inference/v2-inference.yaml when only v2 is set. Nothing is copied
    otherwise. output_dir is always listed, so a missing directory raises.
    """
    for entry in os.listdir(output_dir):
        if not entry.startswith(output_name):
            continue
        if not os.path.isfile(os.path.join(output_dir, entry)):
            continue

        stem = os.path.splitext(entry)[0]
        destination = f'{output_dir}/{stem}.yaml'

        if v2 and v_parameterization:
            log.info(
                f'Saving v2-inference-v.yaml as {output_dir}/{stem}.yaml'
            )
            shutil.copy('./v2_inference/v2-inference-v.yaml', destination)
        elif v2:
            log.info(
                f'Saving v2-inference.yaml as {output_dir}/{stem}.yaml'
            )
            shutil.copy('./v2_inference/v2-inference.yaml', destination)
|
|
|
|
|
def _preset_model_updates(model_list, v2_value, v_parameterization_value, sdxl_value):
    """Build the 7 outputs for a preset model: flags fixed and hidden, path locked to the preset."""
    return (
        model_list,
        gr.Textbox.update(value=str(model_list), visible=False),
        gr.Button.update(visible=False),
        gr.Button.update(visible=False),
        gr.Checkbox.update(value=v2_value, visible=False),
        gr.Checkbox.update(value=v_parameterization_value, visible=False),
        gr.Checkbox.update(value=sdxl_value, visible=False),
    )


def set_pretrained_model_name_or_path_input(
    model_list, pretrained_model_name_or_path, pretrained_model_name_or_path_file, pretrained_model_name_or_path_folder, v2, v_parameterization, sdxl
):
    """Return gradio updates that sync the model controls with model_list.

    For a preset entry, the v2 / v_parameterization / sdxl checkboxes are set
    to match the preset family and hidden. The path textbox is set to the
    preset name and hidden, and the file/folder browse buttons are hidden.
    For 'custom' (and, defensively, any unrecognised value) all those
    controls are made visible and their values are left alone.

    The incoming values of every parameter except model_list are not read;
    they only fix the handler's positional interface.

    Returns:
        A 7-tuple: (model_list, pretrained_model_name_or_path,
        pretrained_model_name_or_path_file,
        pretrained_model_name_or_path_folder, v2, v_parameterization, sdxl).
    """
    selected = str(model_list)

    if selected in SDXL_MODELS:
        log.info('SDXL model selected. Setting sdxl parameters')
        return _preset_model_updates(model_list, False, False, True)

    if selected in V2_BASE_MODELS:
        log.info('SD v2 base model selected. Setting --v2 parameter')
        return _preset_model_updates(model_list, True, False, False)

    if selected in V_PARAMETERIZATION_MODELS:
        log.info(
            'SD v2 model selected. Setting --v2 and --v_parameterization parameters'
        )
        return _preset_model_updates(model_list, True, True, False)

    if selected in V1_MODELS:
        log.info(
            'SD v1.4 model selected.'
        )
        return _preset_model_updates(model_list, False, False, False)

    # 'custom' shows the manual controls. Unknown values used to fall through
    # and return None, which cannot be unpacked into the seven outputs, so
    # they are treated like 'custom' and the user can fix the selection.
    return (
        model_list,
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Button.update(visible=True),
        gr.Checkbox.update(visible=True),
        gr.Checkbox.update(visible=True),
        gr.Checkbox.update(visible=True),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_pretrained_model_name_or_path_file(
    model_list, pretrained_model_name_or_path
):
    """Let the user browse for a model file and return the chosen path.

    Args:
        model_list: Currently selected preset entry (not used; kept for
            interface compatibility).
        pretrained_model_name_or_path: Current path. It seeds the file
            dialog and is returned unchanged if the dialog is cancelled or
            skipped.

    Returns:
        The selected model path. Previously the path was computed and then
        discarded, so the function always returned None.
    """
    pretrained_model_name_or_path = get_any_file_path(
        pretrained_model_name_or_path
    )
    return pretrained_model_name_or_path
|
|
|
|
|
|
|
def get_int_or_default(kwargs, key, default_value=0):
    """Read ``kwargs[key]`` as an int, falling back to default_value.

    ints are returned as-is and floats are truncated with int(). Strings are
    parsed as integers, or as floats and then truncated (e.g. '3.0' -> 3).
    A missing key yields default_value. Values of any other type, and
    strings that cannot be parsed, are logged and replaced by default_value.
    """
    value = kwargs.get(key, default_value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value)
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return int(float(value))
        except (ValueError, OverflowError):
            log.info(f'{key} value "{value}" can not be converted to int, setting value to {default_value}')
            return default_value

    log.info(f'{key} is not an int, float or a string, setting value to {default_value}')
    return default_value
|
|
|
def get_float_or_default(kwargs, key, default_value=0.0):
    """Read ``kwargs[key]`` as a float, falling back to default_value.

    floats are returned as-is, and ints and numeric strings are converted
    with float(). A missing key yields default_value. Values of any other
    type, and strings that float() rejects, are logged and replaced by
    default_value.
    """
    value = kwargs.get(key, default_value)
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            log.info(f'{key} value "{value}" can not be converted to float, setting value to {default_value}')
            return default_value

    log.info(f'{key} is not an int, float or a string, setting value to {default_value}')
    return default_value
|
|
|
def get_str_or_default(kwargs, key, default_value=""):
    """Read ``kwargs[key]`` as a string, falling back to default_value.

    Strings are returned as-is and ints/floats are converted with str().
    Missing keys and values of any other type (e.g. None) yield
    default_value. The previous version had a duplicated ``str`` check where
    a ``float`` check was intended, so floats were silently replaced by the
    default.
    """
    value = kwargs.get(key, default_value)
    if isinstance(value, str):
        return value
    if isinstance(value, (int, float)):
        return str(value)
    return default_value
|
|
|
def run_cmd_training(**kwargs):
    """Assemble the basic training command-line flags from keyword arguments.

    Every flag carries a leading space so the result can be appended to a
    command directly. Options with falsy values are omitted, with two
    exceptions: ``seed`` is dropped only when it is exactly ``''``, and
    ``--optimizer_type`` is always emitted (from ``optimizer``, default
    "AdamW"). ``lr_warmup_steps`` is ignored, with a log line, when the
    scheduler is 'constant'.

    Returns:
        str: The concatenated flags.
    """
    option = kwargs.get
    pieces = []

    def add_quoted(flag, value):
        pieces.append(f' --{flag}="{value}"')

    for flag in ('learning_rate', 'lr_scheduler'):
        value = option(flag, '')
        if value:
            add_quoted(flag, value)

    scheduler = option('lr_scheduler', '')
    warmup_steps = option('lr_warmup_steps', '')
    if warmup_steps and scheduler == 'constant':
        log.info('Can\'t use LR warmup with LR Scheduler constant... ignoring...')
    elif warmup_steps:
        add_quoted('lr_warmup_steps', warmup_steps)

    for flag in ('train_batch_size', 'max_train_steps'):
        value = option(flag, '')
        if value:
            add_quoted(flag, value)

    epochs_between_saves = option('save_every_n_epochs')
    if epochs_between_saves:
        add_quoted('save_every_n_epochs', int(epochs_between_saves))

    for flag in ('mixed_precision', 'save_precision'):
        value = option(flag, '')
        if value:
            add_quoted(flag, value)

    seed = option('seed', '')
    if seed != '':
        add_quoted('seed', seed)

    caption_extension = option('caption_extension', '')
    if caption_extension:
        add_quoted('caption_extension', caption_extension)

    for switch in ('cache_latents', 'cache_latents_to_disk'):
        if option(switch):
            pieces.append(f' --{switch}')

    add_quoted('optimizer_type', option('optimizer', 'AdamW'))

    optimizer_args = option('optimizer_args', '')
    if optimizer_args != '':
        # Passed through unquoted so several key=value pairs stay separate.
        pieces.append(f' --optimizer_args {optimizer_args}')

    return ''.join(pieces)
|
|
|
|
|
def run_cmd_advanced_training(**kwargs):
    """Build the advanced training command-line flags from keyword arguments.

    Flags are appended in a fixed order, each with a leading space, so the
    result can be concatenated onto a command line. Numeric options are
    coerced with int()/float() (a non-numeric string therefore raises
    ValueError) and emitted only when they differ from their "off" value,
    e.g. max_token_length > 75 or clip_skip > 1. --bucket_reso_steps is always
    emitted (default 64). Boolean options become bare switches when truthy.

    Returns:
        str: The concatenated flags.
    """
    run_cmd = ''

    max_train_epochs = kwargs.get("max_train_epochs", "")
    if max_train_epochs:
        run_cmd += f' --max_train_epochs={max_train_epochs}'

    max_data_loader_n_workers = kwargs.get("max_data_loader_n_workers", "")
    if max_data_loader_n_workers:
        run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'

    # 75 is the default token length; only larger values are passed on.
    max_token_length = int(kwargs.get("max_token_length", 75))
    if max_token_length > 75:
        run_cmd += f' --max_token_length={max_token_length}'

    clip_skip = int(kwargs.get("clip_skip", 1))
    if clip_skip > 1:
        run_cmd += f' --clip_skip={clip_skip}'

    resume = kwargs.get("resume", "")
    if resume:
        run_cmd += f' --resume="{resume}"'

    keep_tokens = int(kwargs.get("keep_tokens", 0))
    if keep_tokens > 0:
        run_cmd += f' --keep_tokens="{keep_tokens}"'

    caption_dropout_every_n_epochs = int(kwargs.get("caption_dropout_every_n_epochs", 0))
    if caption_dropout_every_n_epochs > 0:
        run_cmd += f' --caption_dropout_every_n_epochs="{caption_dropout_every_n_epochs}"'

    caption_dropout_rate = float(kwargs.get("caption_dropout_rate", 0))
    if caption_dropout_rate > 0:
        run_cmd += f' --caption_dropout_rate="{caption_dropout_rate}"'

    vae_batch_size = int(kwargs.get("vae_batch_size", 0))
    if vae_batch_size > 0:
        run_cmd += f' --vae_batch_size="{vae_batch_size}"'

    # Always emitted, unlike the other numeric options.
    bucket_reso_steps = int(kwargs.get("bucket_reso_steps", 64))
    run_cmd += f' --bucket_reso_steps={bucket_reso_steps}'

    save_every_n_steps = int(kwargs.get("save_every_n_steps", 0))
    if save_every_n_steps > 0:
        run_cmd += f' --save_every_n_steps="{save_every_n_steps}"'

    save_last_n_steps = int(kwargs.get("save_last_n_steps", 0))
    if save_last_n_steps > 0:
        run_cmd += f' --save_last_n_steps="{save_last_n_steps}"'

    save_last_n_steps_state = int(kwargs.get("save_last_n_steps_state", 0))
    if save_last_n_steps_state > 0:
        run_cmd += f' --save_last_n_steps_state="{save_last_n_steps_state}"'

    # NOTE(review): int() truncates a fractional min_snr_gamma (e.g. 1.5 -> 1)
    # and rejects strings like '5.0' — confirm whole numbers are intended.
    min_snr_gamma = int(kwargs.get("min_snr_gamma", 0))
    if min_snr_gamma >= 1:
        run_cmd += f' --min_snr_gamma={min_snr_gamma}'

    min_timestep = int(kwargs.get("min_timestep", 0))
    if min_timestep > 0:
        run_cmd += f' --min_timestep={min_timestep}'

    # 1000 is treated as "unrestricted"; only smaller values are passed on.
    max_timestep = int(kwargs.get("max_timestep", 1000))
    if max_timestep < 1000:
        run_cmd += f' --max_timestep={max_timestep}'

    # Boolean switches: emitted bare when truthy.
    save_state = kwargs.get('save_state')
    if save_state:
        run_cmd += ' --save_state'

    mem_eff_attn = kwargs.get('mem_eff_attn')
    if mem_eff_attn:
        run_cmd += ' --mem_eff_attn'

    color_aug = kwargs.get('color_aug')
    if color_aug:
        run_cmd += ' --color_aug'

    flip_aug = kwargs.get('flip_aug')
    if flip_aug:
        run_cmd += ' --flip_aug'

    shuffle_caption = kwargs.get('shuffle_caption')
    if shuffle_caption:
        run_cmd += ' --shuffle_caption'

    gradient_checkpointing = kwargs.get('gradient_checkpointing')
    if gradient_checkpointing:
        run_cmd += ' --gradient_checkpointing'

    full_fp16 = kwargs.get('full_fp16')
    if full_fp16:
        run_cmd += ' --full_fp16'

    xformers = kwargs.get('xformers')
    if xformers:
        run_cmd += ' --xformers'

    persistent_data_loader_workers = kwargs.get('persistent_data_loader_workers')
    if persistent_data_loader_workers:
        run_cmd += ' --persistent_data_loader_workers'

    bucket_no_upscale = kwargs.get('bucket_no_upscale')
    if bucket_no_upscale:
        run_cmd += ' --bucket_no_upscale'

    random_crop = kwargs.get('random_crop')
    if random_crop:
        run_cmd += ' --random_crop'

    scale_v_pred_loss_like_noise_pred = kwargs.get('scale_v_pred_loss_like_noise_pred')
    if scale_v_pred_loss_like_noise_pred:
        run_cmd += ' --scale_v_pred_loss_like_noise_pred'

    # Two mutually exclusive noise modes: 'Original' emits noise_offset (and
    # adaptive_noise_scale, only when noise_offset > 0); any other type
    # emits the multires_noise_* options instead.
    noise_offset_type = kwargs.get('noise_offset_type', 'Original')
    if noise_offset_type == 'Original':
        noise_offset = float(kwargs.get("noise_offset", 0))
        if noise_offset > 0:
            run_cmd += f' --noise_offset={noise_offset}'

        adaptive_noise_scale = float(kwargs.get("adaptive_noise_scale", 0))
        if adaptive_noise_scale != 0 and noise_offset > 0:
            run_cmd += f' --adaptive_noise_scale={adaptive_noise_scale}'
    else:
        multires_noise_iterations = int(kwargs.get("multires_noise_iterations", 0))
        if multires_noise_iterations > 0:
            run_cmd += f' --multires_noise_iterations="{multires_noise_iterations}"'

        multires_noise_discount = float(kwargs.get("multires_noise_discount", 0))
        if multires_noise_discount > 0:
            run_cmd += f' --multires_noise_discount="{multires_noise_discount}"'

    # Appended verbatim, without quoting or validation.
    additional_parameters = kwargs.get("additional_parameters", "")
    if additional_parameters:
        run_cmd += f' {additional_parameters}'

    use_wandb = kwargs.get('use_wandb')
    if use_wandb:
        run_cmd += ' --log_with wandb'

    # NOTE(review): the API key becomes part of the command string; if that
    # string is logged or printed elsewhere the key would be exposed — confirm.
    wandb_api_key = kwargs.get("wandb_api_key", "")
    if wandb_api_key:
        run_cmd += f' --wandb_api_key="{wandb_api_key}"'

    return run_cmd
|
|
|
def verify_image_folder_pattern(folder_path):
    """Check that folder_path holds image subfolders named ``<number>_<text>``.

    Each direct subfolder name must start with one or more digits, an
    underscore and a word character (e.g. ``10_dog``). Problems (path is not
    a folder, non-matching subfolders, no subfolders) are reported with
    log.error, and success with log.info.

    Returns:
        NOTE(review): both false_response and true_response are True, so
        this currently returns True even when validation fails; problems
        are only logged. This looks like a deliberate soft check — confirm
        before making callers stop on failure.
    """
    false_response = True
    true_response = True

    # The backslash is escaped so the logged text stays exactly
    # "docs\image_folder_structure.md" without an invalid "\i" escape
    # (SyntaxWarning on Python 3.12+).
    doc_hint = (
        'Please follow the folder structure documentation found at '
        'docs\\image_folder_structure.md ...'
    )

    if not os.path.isdir(folder_path):
        log.error(f"The provided path '{folder_path}' is not a valid folder. {doc_hint}")
        return false_response

    pattern = r'^\d+_\w+'

    subfolders = [
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, name))
    ]

    matching_subfolders = [
        subfolder
        for subfolder in subfolders
        if re.match(pattern, os.path.basename(subfolder))
    ]

    non_matching_subfolders = set(subfolders) - set(matching_subfolders)
    if non_matching_subfolders:
        # Sorted so the message is deterministic (set order is arbitrary).
        log.error(
            "The following folders do not match the required pattern <number>_<text>: "
            f"{', '.join(sorted(non_matching_subfolders))}"
        )
        log.error(doc_hint)
        return false_response

    if not matching_subfolders:
        log.error(f"No image folders found in {folder_path}. {doc_hint}")
        return false_response

    log.info(f'Valid image folder names found in: {folder_path}')
    return true_response
|
|
|
def SaveConfigFile(parameters, file_path: str, exclusion=('file_path', 'save_as', 'headless', 'print_only')):
    """Write (name, value) pairs to file_path as indented JSON.

    Args:
        parameters: Iterable of (name, value) pairs. Entries are written
            sorted by name.
        file_path: Destination JSON file; overwritten if it exists.
        exclusion: Names to leave out. The default is a tuple to avoid the
            shared-mutable-default pitfall; any container that supports
            ``in`` (e.g. a list) still works.
    """
    variables = {
        name: value
        for name, value in sorted(parameters, key=lambda x: x[0])
        if name not in exclusion
    }

    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(variables, file, indent=2)
|
|
|
def save_to_file(content):
    """Append content plus a newline to logs/print_command.txt.

    The path is relative to the current working directory. The ``logs``
    directory is created if missing; previously open() failed with
    FileNotFoundError when it did not exist yet.
    """
    file_path = 'logs/print_command.txt'
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as file:
        file.write(content + '\n')