from __future__ import annotations

import datetime
import os
import pathlib
import shlex
import shutil
import subprocess

import gradio as gr
import PIL.Image
import slugify
import torch

from constants import UploadTarget


def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad an image with black borders so it becomes a centered square."""
    w, h = image.size
    if w == h:
        return image
    elif w > h:
        new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
        new_image.paste(image, (0, (w - h) // 2))
        return new_image
    else:
        new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image
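
# Example: a 640x480 (W x H) input yields a 640x640 canvas with the original
# pasted at y = (640 - 480) // 2 = 80; square images are returned unchanged.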


class Trainer:
    def prepare_dataset(self, instance_images: list, resolution: int,
                        instance_data_dir: pathlib.Path) -> None:
        """Square-pad, resize, and re-encode the uploaded images as JPEGs."""
        shutil.rmtree(instance_data_dir, ignore_errors=True)
        instance_data_dir.mkdir(parents=True)
        for i, temp_path in enumerate(instance_images):
            # Gradio passes temp-file wrappers; `.name` holds the file path.
            image = PIL.Image.open(temp_path.name)
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            image = image.convert('RGB')
            out_path = instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

    def run(
        self,
        instance_images: list | None,
        instance_prompt: str,
        output_model_name: str,
        overwrite_existing_model: bool,
        validation_prompt: str,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        learning_rate: float,
        gradient_accumulation: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        checkpointing_steps: int,
        use_wandb: bool,
        validation_epochs: int,
        upload_to_hub: bool,
        use_private_repo: bool,
        delete_existing_repo: bool,
        upload_to: str,
    ) -> str:
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if instance_images is None:
            raise gr.Error('You need to upload images.')
        if not instance_prompt:
            raise gr.Error('The instance prompt is missing.')
        if not validation_prompt:
            raise gr.Error('The validation prompt is missing.')

        resolution = int(resolution_s)

        # Fall back to a timestamp-based name when none is provided.
        if not output_model_name:
            output_model_name = datetime.datetime.now().strftime(
                '%Y-%m-%d-%H-%M-%S')
        output_model_name = slugify.slugify(output_model_name)

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / 'experiments' / output_model_name
        if overwrite_existing_model or upload_to_hub:
            shutil.rmtree(output_dir, ignore_errors=True)
        if not upload_to_hub:
            # Raises FileExistsError if the model already exists locally
            # and overwriting was not requested.
            output_dir.mkdir(parents=True)

        instance_data_dir = repo_dir / 'training_data' / output_model_name
        self.prepare_dataset(instance_images, resolution, instance_data_dir)

        command = f'''
        accelerate launch train_dreambooth_lora.py \
          --pretrained_model_name_or_path={base_model} \
          --instance_data_dir={instance_data_dir} \
          --output_dir={output_dir} \
          --instance_prompt="{instance_prompt}" \
          --resolution={resolution} \
          --train_batch_size=1 \
          --gradient_accumulation_steps={gradient_accumulation} \
          --learning_rate={learning_rate} \
          --lr_scheduler=constant \
          --lr_warmup_steps=0 \
          --max_train_steps={n_steps} \
          --checkpointing_steps={checkpointing_steps} \
          --validation_prompt="{validation_prompt}" \
          --validation_epochs={validation_epochs} \
          --seed={seed}
        '''
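        # In this (non-raw) string literal, each trailing backslash escapes
        # the newline that follows, so `command` collapses to a single logical
        # line and shlex.split() below sees only space-separated tokens.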
        if fp16:
            command += ' --mixed_precision fp16'
        if use_8bit_adam:
            command += ' --use_8bit_adam'
        if use_wandb:
            command += ' --report_to wandb'
        if upload_to_hub:
            hf_token = os.getenv('HF_TOKEN')
            if not hf_token:
                raise gr.Error('The HF_TOKEN environment variable is not set.')
            command += f' --push_to_hub --hub_token {hf_token}'
            if use_private_repo:
                command += ' --private_repo'
            if delete_existing_repo:
                command += ' --delete_existing_repo'
            if upload_to == UploadTarget.LORA_LIBRARY.value:
                command += ' --upload_to_lora_library'

        result = subprocess.run(shlex.split(command))
        if result.returncode != 0:
            raise gr.Error('Training failed. Check the logs for details.')

        with open(output_dir / 'train.sh', 'w') as f:
            # Record the exact command, collapsed to one line, for
            # reproducibility. When uploading to the Hub this includes the
            # --hub_token value, so treat the file as sensitive.
            command_s = ' '.join(command.split())
            f.write(command_s)

        return 'Training completed!'
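
# Minimal usage sketch. All values below (file names, prompts, base model)
# are hypothetical examples, not defaults from this repo; running it requires
# a CUDA GPU and train_dreambooth_lora.py, per the checks in run(). Gradio
# upload objects are stubbed with a namedtuple exposing `.name`, which is all
# prepare_dataset relies on.
if __name__ == '__main__':
    import collections

    FakeUpload = collections.namedtuple('FakeUpload', ['name'])
    # Hypothetical local image paths.
    images = [FakeUpload('sample0.jpg'), FakeUpload('sample1.jpg')]

    trainer = Trainer()
    message = trainer.run(
        instance_images=images,
        instance_prompt='a photo of sks dog',
        output_model_name='my-lora-test',
        overwrite_existing_model=True,
        validation_prompt='a photo of sks dog in a bucket',
        base_model='runwayml/stable-diffusion-v1-5',  # assumed base model
        resolution_s='512',
        n_steps=1000,
        learning_rate=1e-4,
        gradient_accumulation=1,
        seed=0,
        fp16=True,
        use_8bit_adam=True,
        checkpointing_steps=100,
        use_wandb=False,
        validation_epochs=100,
        upload_to_hub=False,
        use_private_repo=False,
        delete_existing_repo=False,
        upload_to=UploadTarget.LORA_LIBRARY.value,
    )
    print(message)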