# NOTE(review): the original file began with stray viewer lines
# ("Spaces:" / "Runtime error" x2) — paste artifacts, not Python code.
import glob
import json
import logging
import math
import os.path
from typing import Optional, Union

from gchar.games.base import Character
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hcpdiff.train_ac import Trainer
from hcpdiff.train_ac_single import TrainerSingleCard
from hcpdiff.utils import load_config_with_cli

from .embedding import create_embedding, _DEFAULT_TRAIN_MODEL
from ..dataset import load_dataset_for_character, save_recommended_tags
from ..utils import data_to_cli_args, get_ch_name

_DEFAULT_TRAIN_CFG = 'cfgs/train/examples/lora_anime_character.yaml'
def _min_training_steps(dataset_size: int, unit: int = 20): | |
steps = 4000.9 + (720.9319 - 4000.9) / (1 + (dataset_size / 297.2281) ** 0.6543184) | |
return int(round(steps / unit)) * unit | |
def train_plora(
        source: Union[str, Character], name: Optional[str] = None,
        epochs: int = 13, min_steps: Optional[int] = None,
        save_for_times: int = 15, no_min_steps: bool = False,
        batch_size: int = 4, pretrained_model: str = _DEFAULT_TRAIN_MODEL,
        workdir: Optional[str] = None, emb_n_words: int = 4, emb_init_text: str = '*[0.017, 1]',
        unet_rank: float = 8, text_encoder_rank: float = 4,
        cfg_file: str = _DEFAULT_TRAIN_CFG, single_card: bool = True,
        dataset_type: str = 'stage3-1200', use_ratio: bool = True,
):
    """Train a pivotal-tuning LoRA (embedding + LoRA) for one character.

    Loads the character dataset, derives a step/epoch/checkpoint budget,
    writes dataset metadata and recommended tags to ``workdir``, creates a
    fresh textual-inversion embedding in a temporary directory, then launches
    an HCP-Diffusion training run configured via CLI-style overrides.

    :param source: ``Character`` or custom source accepted by
        ``load_dataset_for_character``.
    :param name: Concept name; required for custom sources, otherwise
        defaults to ``get_ch_name(ch)``.
    :param epochs: Requested passes over the dataset; may effectively grow
        so the minimum-step heuristics are satisfied.
    :param min_steps: Optional explicit lower bound on training steps.
    :param save_for_times: Approximate number of checkpoints to produce.
    :param no_min_steps: Disable the dataset-size-based minimum-step heuristic.
    :param batch_size: Batch size for ``data.dataset1``.
    :param pretrained_model: Base model for embedding creation and training.
    :param workdir: Output directory; defaults to ``runs/<name>``.
    :param emb_n_words: Token count of the created embedding.
    :param emb_init_text: Embedding initialization text
        (presumably an hcpdiff init-text expression — TODO confirm).
    :param unet_rank: LoRA rank for the UNet.
    :param text_encoder_rank: LoRA rank for the text encoder.
    :param cfg_file: Base hcpdiff training config file.
    :param single_card: Use ``TrainerSingleCard`` instead of ``Trainer``.
    :param dataset_type: Dataset variant to load (e.g. ``'stage3-1200'``).
    :param use_ratio: Use a ratio bucket rather than a fixed-size bucket.
    :raises ValueError: If ``source`` is a custom source and ``name`` is not given.
    """
    with load_dataset_for_character(source, dataset_type) as (ch, ds_dir):
        if ch is None:
            # Custom (non-Character) source: there is no canonical name to fall back on.
            if name is None:
                raise ValueError(f'Name should be specified when using custom source - {source!r}.')
        else:
            name = name or get_ch_name(ch)

        dataset_size = len(glob.glob(os.path.join(ds_dir, '*.png')))
        logging.info(f'{plural_word(dataset_size, "image")} found in dataset.')

        # Step budget: requested epochs, raised to the heuristic minimum and
        # to the caller-supplied minimum when given.
        actual_steps = epochs * dataset_size
        if not no_min_steps:
            actual_steps = max(actual_steps, _min_training_steps(dataset_size, 20))
        if min_steps is not None:
            actual_steps = max(actual_steps, min_steps)

        # Round the checkpoint interval up to a multiple of 20, then stretch the
        # total so the final step lands exactly on a checkpoint boundary.
        save_per_steps = max(int(math.ceil(actual_steps / save_for_times / 20) * 20), 20)
        steps = int(math.ceil(actual_steps / save_per_steps) * save_per_steps)
        epochs = int(math.ceil(steps / dataset_size))
        logging.info(f'Training for {plural_word(steps, "step")}, {plural_word(epochs, "epoch")}, '
                     f'save per {plural_word(save_per_steps, "step")} ...')

        workdir = workdir or os.path.join('runs', name)
        os.makedirs(workdir, exist_ok=True)
        save_recommended_tags(ds_dir, name, workdir)
        with open(os.path.join(workdir, 'meta.json'), 'w', encoding='utf-8') as f:
            json.dump({
                'dataset': {
                    'size': dataset_size,
                    'type': dataset_type,
                },
            }, f, indent=4, sort_keys=True, ensure_ascii=False)

        # The embedding lives in a temp directory for the duration of training;
        # hcpdiff reads it back through ``tokenizer_pt.emb_dir``.
        with TemporaryDirectory() as embs_dir:
            logging.info(f'Creating embeddings {name!r} at {embs_dir!r}, '
                         f'n_words: {emb_n_words!r}, init_text: {emb_init_text!r}, '
                         f'pretrained_model: {pretrained_model!r}.')
            create_embedding(
                name, emb_n_words, emb_init_text,
                replace=True,
                pretrained_model=pretrained_model,
                embs_dir=embs_dir,
            )

            cli_args = data_to_cli_args({
                'train': {
                    'train_steps': steps,
                    'save_step': save_per_steps,
                    'scheduler': {
                        'num_training_steps': steps,
                    }
                },
                'model': {
                    'pretrained_model_name_or_path': pretrained_model,
                },
                'character_name': name,
                'dataset_dir': ds_dir,
                'exp_dir': workdir,
                'unet_rank': unet_rank,
                'text_encoder_rank': text_encoder_rank,
                'tokenizer_pt': {
                    'emb_dir': embs_dir,
                },
                'data': {
                    'dataset1': {
                        'batch_size': batch_size,
                        'bucket': {
                            '_target_': 'hcpdiff.data.bucket.RatioBucket.from_files',
                            'target_area': '${times:512,512}',
                            'num_bucket': 5,
                        } if use_ratio else {
                            '_target_': 'hcpdiff.data.bucket.SizeBucket.from_files',
                            'target_area': '---',
                            'num_bucket': 1,
                        }
                    },
                },
            })
            conf = load_config_with_cli(cfg_file, args_list=cli_args)  # skip --cfg
            logging.info(f'Training with {cfg_file!r}, args: {cli_args!r} ...')
            if single_card:
                logging.info('Training with single card ...')
                trainer = TrainerSingleCard(conf)
            else:
                logging.info('Training with non-single cards ...')
                trainer = Trainer(conf)
            trainer.train()