|
import gradio as gr |
|
from easygui import msgbox |
|
import subprocess |
|
from .common_gui import get_folder_path, add_pre_postfix |
|
import os |
|
|
|
from library.custom_logging import setup_logging |
|
|
|
|
|
# Module-level logger used by the functions in this file.
log = setup_logging()
|
|
|
|
|
def caption_images(
    train_data_dir,
    caption_extension,
    batch_size,
    general_threshold,
    character_threshold,
    replace_underscores,
    model,
    recursive,
    max_data_loader_n_workers,
    debug,
    undesired_tags,
    frequency_tags,
    prefix,
    postfix,
):
    """Caption the images in a folder with the WD14 tagger, then add prefix/postfix.

    Launches ``./finetune/tag_images_by_wd14_tagger.py`` via
    ``accelerate launch``. If it exits successfully, calls
    ``add_pre_postfix`` on the caption files with the given extension.

    Args:
        train_data_dir: Folder containing the images to caption (passed as
            the script's positional argument).
        caption_extension: Caption file extension, e.g. ``.txt``.
        batch_size: Passed as ``--batch_size`` after ``int()`` conversion.
        general_threshold: Passed as ``--general_threshold``.
        character_threshold: Passed as ``--character_threshold``.
        replace_underscores: If truthy, ``--remove_underscore`` is added.
        model: Passed as ``--model``.
        recursive: If truthy, ``--recursive`` is added.
        max_data_loader_n_workers: Passed as ``--max_data_loader_n_workers``
            after ``int()`` conversion.
        debug: If truthy, ``--debug`` is added.
        undesired_tags: Optional comma-separated tags; passed as
            ``--undesired_tags`` only when non-empty.
        frequency_tags: If truthy, ``--frequency_tags`` is added.
        prefix: Prefix handed to ``add_pre_postfix``.
        postfix: Postfix handed to ``add_pre_postfix``.

    Returns:
        None. Missing required inputs are reported with ``msgbox``. A tagger
        that cannot be started or exits non-zero is reported via the log, and
        the prefix/postfix step is skipped.
    """
    import shlex  # only used to render the command readably in the log

    if not train_data_dir:
        msgbox('Image folder is missing...')
        return

    if not caption_extension:
        msgbox('Please provide an extension for the caption files.')
        return

    log.info(f'Captioning files in {train_data_dir}...')

    # Build an argv list and run it without a shell. Folder paths and tag
    # lists typed into the GUI may contain spaces, quotes or shell
    # metacharacters. Previously they were wrapped in bare double quotes and
    # executed via os.system (/bin/sh on POSIX), which broke the command or
    # let the shell interpret them.
    run_cmd = [
        'accelerate',
        'launch',
        './finetune/tag_images_by_wd14_tagger.py',
        # int(): the Gradio Number widget may deliver a float.
        f'--batch_size={int(batch_size)}',
        f'--general_threshold={general_threshold}',
        f'--character_threshold={character_threshold}',
        f'--caption_extension={caption_extension}',
        # NOTE(review): confirm the tagger script defines a `--model` option;
        # argparse silently accepts an unambiguous prefix of a longer option.
        f'--model={model}',
        f'--max_data_loader_n_workers={int(max_data_loader_n_workers)}',
    ]

    # Optional switches, appended only when enabled in the GUI.
    if recursive:
        run_cmd.append('--recursive')
    if debug:
        run_cmd.append('--debug')
    if replace_underscores:
        run_cmd.append('--remove_underscore')
    if frequency_tags:
        run_cmd.append('--frequency_tags')

    if undesired_tags:
        run_cmd.append(f'--undesired_tags={undesired_tags}')

    # The image folder is the positional argument and goes last.
    run_cmd.append(train_data_dir)

    log.info(' '.join(shlex.quote(arg) for arg in run_cmd))

    try:
        result = subprocess.run(run_cmd)
    except FileNotFoundError:
        log.error('Could not start `accelerate`; is it installed and on PATH?')
        return

    if result.returncode != 0:
        log.error(
            f'WD14 tagger exited with code {result.returncode}; '
            'skipping prefix/postfix step.'
        )
        return

    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_extension,
        prefix=prefix,
        postfix=postfix,
    )

    log.info('...captioning done')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_wd14_caption_gui_tab(headless=False):
    """Build the "WD14 Captioning" tab and wire its button to ``caption_images``.

    The tab collects:
    - the image folder;
    - caption extension, undesired tags and prefix/postfix;
    - tagger switches;
    - model and thresholds;
    - loader settings.

    Clicking "Caption images" passes these values to ``caption_images``. The
    ``inputs`` list mirrors that function's parameter order.

    Args:
        headless: When True, the folder-picker button is hidden.
    """
    with gr.Tab('WD14 Captioning'):
        gr.Markdown(
            'This utility will use WD14 to caption files for each image in a folder.'
        )

        with gr.Row():
            train_data_dir = gr.Textbox(
                label='Image folder to caption',
                placeholder='Directory containing the images to caption',
                interactive=True,
            )
            button_train_data_dir_input = gr.Button(
                '📂', elem_id='open_folder_small', visible=(not headless)
            )
            # The picker writes the chosen folder into the textbox.
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )

            caption_extension = gr.Textbox(
                label='Caption file extension',
                placeholder='Extension for caption file. eg: .caption, .txt',
                value='.txt',
                interactive=True,
            )

            undesired_tags = gr.Textbox(
                label='Undesired tags',
                placeholder='(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.',
                interactive=True,
            )

        with gr.Row():
            prefix = gr.Textbox(
                label='Prefix to add to WD14 caption',
                placeholder='(Optional)',
                interactive=True,
            )

            postfix = gr.Textbox(
                label='Postfix to add to WD14 caption',
                placeholder='(Optional)',
                interactive=True,
            )

        with gr.Row():
            replace_underscores = gr.Checkbox(
                label='Replace underscores in filenames with spaces',
                value=True,
                interactive=True,
            )
            recursive = gr.Checkbox(
                label='Recursive',
                value=False,
                info='Tag subfolders images as well',
            )

            debug = gr.Checkbox(
                label='Verbose logging',
                value=True,
                info='Debug while tagging, it will print your image file with general tags and character tags.',
            )
            frequency_tags = gr.Checkbox(
                label='Show tags frequency',
                value=True,
                info='Show frequency of tags for images.',
            )

        with gr.Row():
            model = gr.Dropdown(
                label='Model',
                choices=[
                    'SmilingWolf/wd-v1-4-convnext-tagger-v2',
                    'SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
                    'SmilingWolf/wd-v1-4-vit-tagger-v2',
                    'SmilingWolf/wd-v1-4-swinv2-tagger-v2',
                ],
                value='SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
            )

            general_threshold = gr.Slider(
                value=0.35,
                label='General threshold',
                info='Adjust `general_threshold` for pruning tags (less tags, less flexible)',
                minimum=0,
                maximum=1,
                step=0.05,
            )
            character_threshold = gr.Slider(
                value=0.35,
                label='Character threshold',
                info='useful if you want to train with character',
                minimum=0,
                maximum=1,
                step=0.05,
            )

        with gr.Row():
            batch_size = gr.Number(
                value=8, label='Batch size', interactive=True
            )

            max_data_loader_n_workers = gr.Number(
                value=2, label='Max dataloader workers', interactive=True
            )

        caption_button = gr.Button('Caption images')

        # Order must match caption_images' positional parameters.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                general_threshold,
                character_threshold,
                replace_underscores,
                model,
                recursive,
                max_data_loader_n_workers,
                debug,
                undesired_tags,
                frequency_tags,
                prefix,
                postfix,
            ],
            show_progress=False,
        )
|
|