hysts's picture
hysts HF staff
Update to Diffusers
db1e5fb
raw history blame
No virus
4.58 kB
from __future__ import annotations
import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import gradio as gr
import PIL.Image
import slugify
import torch
from constants import UploadTarget
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad ``image`` with black so that it becomes square.

    The shorter dimension is extended to the length of the longer one and the
    original picture is centered along the extended axis.

    Args:
        image: The picture to pad. Any Pillow mode is accepted.

    Returns:
        ``image`` itself when it is already square; otherwise a new square
        image of the same mode with ``image`` pasted in the middle.
    """
    w, h = image.size
    if w == h:
        return image

    side = max(w, h)
    # Use a colour *name* rather than the tuple (0, 0, 0): Pillow translates
    # the name into the right value for the target mode, whereas a 3-tuple is
    # rejected with TypeError for single-band modes such as 'L' or '1'.
    new_image = PIL.Image.new(image.mode, (side, side), 'black')
    # One of the two offsets is always 0 (the axis that is already `side`
    # long); the other centers the picture along the padded axis.
    new_image.paste(image, ((side - w) // 2, (side - h) // 2))
    return new_image
class Trainer:
    """Prepare training data and launch ``train_dreambooth_lora.py``."""

    def prepare_dataset(self, instance_images: list, resolution: int,
                        instance_data_dir: pathlib.Path) -> None:
        """Rebuild ``instance_data_dir`` from the uploaded images.

        Any existing directory is removed first. Each upload is padded to a
        square with ``pad_image``, resized to ``resolution`` x ``resolution``,
        converted to RGB and stored as ``000.jpg``, ``001.jpg``, ...

        Args:
            instance_images: Uploaded file objects; the on-disk path of each
                one is read from its ``.name`` attribute.
            resolution: Side length, in pixels, of the saved images.
            instance_data_dir: Directory that is wiped and recreated.
        """
        shutil.rmtree(instance_data_dir, ignore_errors=True)
        instance_data_dir.mkdir(parents=True)
        for i, temp_path in enumerate(instance_images):
            # The context manager closes the source file once the processed
            # copy has been written, instead of leaking one handle per upload.
            with PIL.Image.open(temp_path.name) as source:
                image = pad_image(source)
                image = image.resize((resolution, resolution))
                image = image.convert('RGB')
                out_path = instance_data_dir / f'{i:03d}.jpg'
                image.save(out_path, format='JPEG', quality=100)

    def run(
        self,
        instance_images: list | None,
        instance_prompt: str,
        output_model_name: str,
        overwrite_existing_model: bool,
        validation_prompt: str,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        learning_rate: float,
        gradient_accumulation: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        checkpointing_steps: int,
        use_wandb: bool,
        validation_epochs: int,
        upload_to_hub: bool,
        use_private_repo: bool,
        delete_existing_repo: bool,
        upload_to: str,
    ) -> str:
        """Validate the inputs, build the dataset and run the training script.

        Args:
            instance_images: Uploaded training images (see
                ``prepare_dataset``).
            instance_prompt: Passed as ``--instance_prompt``.
            output_model_name: Name of the experiment; slugified and used as
                the sub-directory name under ``experiments/`` and
                ``training_data/``. A timestamp is used when empty.
            overwrite_existing_model: Remove an existing output directory
                before training.
            validation_prompt: Passed as ``--validation_prompt``.
            base_model: Passed as ``--pretrained_model_name_or_path``.
            resolution_s: Training resolution as a string; converted with
                ``int``.
            n_steps: Passed as ``--max_train_steps``.
            learning_rate: Passed as ``--learning_rate``.
            gradient_accumulation: Passed as ``--gradient_accumulation_steps``.
            seed: Passed as ``--seed``.
            fp16: Add ``--mixed_precision fp16``.
            use_8bit_adam: Add ``--use_8bit_adam``.
            checkpointing_steps: Passed as ``--checkpointing_steps``.
            use_wandb: Add ``--report_to wandb``.
            validation_epochs: Passed as ``--validation_epochs``.
            upload_to_hub: Add ``--push_to_hub`` together with the token read
                from the ``HF_TOKEN`` environment variable.
            use_private_repo: Add ``--private_repo`` (only with
                ``upload_to_hub``).
            delete_existing_repo: Add ``--delete_existing_repo`` (only with
                ``upload_to_hub``).
            upload_to: When equal to ``UploadTarget.LORA_LIBRARY.value``, add
                ``--upload_to_lora_library`` (only with ``upload_to_hub``).

        Returns:
            The message ``'Training completed!'``.

        Raises:
            gr.Error: If CUDA is unavailable, an input is missing,
                ``HF_TOKEN`` is required but unset, or the training process
                exits with a non-zero status.
            FileExistsError: If the output directory already exists and is
                neither overwritten nor uploaded to the Hub.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        # `not` also rejects an empty list, which would otherwise start a
        # training run on an empty dataset.
        if not instance_images:
            raise gr.Error('You need to upload images.')
        if not instance_prompt:
            raise gr.Error('The instance prompt is missing.')
        if not validation_prompt:
            raise gr.Error('The validation prompt is missing.')

        hf_token = None
        if upload_to_hub:
            hf_token = os.getenv('HF_TOKEN')
            # Fail before anything is deleted rather than passing the literal
            # string "None" to the training script as a token.
            if not hf_token:
                raise gr.Error(
                    'The HF_TOKEN environment variable is not set.')

        resolution = int(resolution_s)

        if not output_model_name:
            output_model_name = datetime.datetime.now().strftime(
                '%Y-%m-%d-%H-%M-%S')
        output_model_name = slugify.slugify(output_model_name)

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / 'experiments' / output_model_name
        if overwrite_existing_model or upload_to_hub:
            shutil.rmtree(output_dir, ignore_errors=True)
        # The directory is only pre-created for local runs; when uploading,
        # it is left absent for the training script to create.
        if not upload_to_hub:
            output_dir.mkdir(parents=True)

        instance_data_dir = repo_dir / 'training_data' / output_model_name
        self.prepare_dataset(instance_images, resolution, instance_data_dir)

        # Build the argument vector directly. Formatting user text into a
        # command string and re-splitting it would let a quote or space in a
        # prompt/path break the command or inject extra arguments.
        argv = [
            'accelerate',
            'launch',
            'train_dreambooth_lora.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--instance_data_dir={instance_data_dir}',
            f'--output_dir={output_dir}',
            f'--instance_prompt={instance_prompt}',
            f'--resolution={resolution}',
            '--train_batch_size=1',
            f'--gradient_accumulation_steps={gradient_accumulation}',
            f'--learning_rate={learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=0',
            f'--max_train_steps={n_steps}',
            f'--checkpointing_steps={checkpointing_steps}',
            f'--validation_prompt={validation_prompt}',
            f'--validation_epochs={validation_epochs}',
            f'--seed={seed}',
        ]
        if fp16:
            argv += ['--mixed_precision', 'fp16']
        if use_8bit_adam:
            argv.append('--use_8bit_adam')
        if use_wandb:
            argv += ['--report_to', 'wandb']

        hub_token_pos = None
        if upload_to_hub:
            argv += ['--push_to_hub', '--hub_token']
            hub_token_pos = len(argv)
            argv.append(hf_token)
            if use_private_repo:
                argv.append('--private_repo')
            if delete_existing_repo:
                argv.append('--delete_existing_repo')
            if upload_to == UploadTarget.LORA_LIBRARY.value:
                argv.append('--upload_to_lora_library')

        result = subprocess.run(argv)

        # Record the command next to the results. The secret token is replaced
        # by a reference to the environment variable so it is never written
        # to disk in plaintext.
        script_args = [shlex.quote(arg) for arg in argv]
        if hub_token_pos is not None:
            script_args[hub_token_pos] = '"$HF_TOKEN"'
        if output_dir.is_dir():
            with open(output_dir / 'train.sh', 'w', encoding='utf-8') as f:
                f.write(' '.join(script_args))

        if result.returncode != 0:
            raise gr.Error(
                f'Training failed with exit code {result.returncode}.')
        return 'Training completed!'