kohya_ss / kohya_gui /wd14_caption_gui.py
zengxi123's picture
Upload folder using huggingface_hub
fb83c5b verified
import gradio as gr
import subprocess
from .common_gui import (
get_folder_path,
add_pre_postfix,
scriptdir,
list_dirs,
get_executable_path, setup_environment,
)
from .class_gui_config import KohyaSSGUIConfig
import os
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
# Module-level memory for the "Use onnx" checkbox of the WD14 tab: the repo_id
# change handler stores the checkbox value here before forcing the checkbox on
# for a "-v3" repo, and reads it back when a non-v3 repo is selected.
old_onnx_value = True
def caption_images(
    train_data_dir: str,
    caption_extension: str,
    batch_size: int,
    general_threshold: float,
    character_threshold: float,
    repo_id: str,
    recursive: bool,
    max_data_loader_n_workers: int,
    debug: bool,
    undesired_tags: str,
    frequency_tags: bool,
    always_first_tags: str,
    onnx: bool,
    append_tags: bool,
    force_download: bool,
    caption_separator: str,
    tag_replacement: str,
    character_tag_expand: bool,
    use_rating_tags: bool,
    use_rating_tags_as_last_tag: bool,
    remove_underscore: bool,
    thresh: float,
) -> None:
    """Caption the images of a folder by running the WD14 tagger script.

    Builds an ``accelerate launch .../finetune/tag_images_by_wd14_tagger.py``
    command line from the given options, runs it as a subprocess and, when the
    subprocess succeeds, calls ``add_pre_postfix`` to prepend
    ``always_first_tags`` to the caption files.

    Args:
        train_data_dir: Folder containing the images; nothing is done when it
            is empty or ``None``.
        caption_extension: Extension of the caption files; nothing is done
            when it is empty or ``None``.
        batch_size: Forwarded as ``--batch_size`` (converted to ``int``).
        general_threshold: Forwarded as ``--general_threshold`` unless it
            equals 0.35.
        character_threshold: Forwarded as ``--character_threshold`` unless it
            equals 0.35.
        repo_id: Hugging Face repo id of the tagger model (``--repo_id``).
        recursive: Adds ``--recursive``; also passed to ``add_pre_postfix``.
        max_data_loader_n_workers: Forwarded as
            ``--max_data_loader_n_workers`` (converted to ``int``).
        debug: Adds ``--debug``.
        undesired_tags: Forwarded as ``--undesired_tags`` when non-empty.
        frequency_tags: Adds ``--frequency_tags``.
        always_first_tags: Prefix handed to ``add_pre_postfix`` after tagging.
        onnx: Adds ``--onnx``.
        append_tags: Adds ``--append_tags``.
        force_download: Adds ``--force_download``; forced on when the local
            model folder for ``repo_id`` does not exist.
        caption_separator: Forwarded as ``--caption_separator``.
        tag_replacement: Forwarded as ``--tag_replacement`` when non-empty.
        character_tag_expand: Adds ``--character_tag_expand``.
        use_rating_tags: Adds ``--use_rating_tags``.
        use_rating_tags_as_last_tag: Adds ``--use_rating_tags_as_last_tag``.
        remove_underscore: Adds ``--remove_underscore``.
        thresh: Forwarded as ``--thresh`` unless it equals 0.35.
    """
    # Truthiness instead of `== ""` so that a cleared GUI field delivering
    # None is rejected here rather than failing further down.
    if not train_data_dir:
        log.info("Image folder is missing...")
        return

    if not caption_extension:
        log.info("Please provide an extension for the caption files.")
        return

    # The model folder name is the repo id with "/" replaced by "_"; when that
    # folder is not present under ./wd14_tagger_model a download is requested.
    repo_id_converted = repo_id.replace("/", "_")
    if not os.path.exists(f"./wd14_tagger_model/{repo_id_converted}"):
        force_download = True

    log.info(f"Captioning files in {train_data_dir}...")
    run_cmd = [
        rf'{get_executable_path("accelerate")}',
        "launch",
        rf"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py",
    ]

    # always_first_tags is not forwarded as --always_first_tags; it is applied
    # after tagging through add_pre_postfix (see the end of this function).

    if append_tags:
        run_cmd.append("--append_tags")
    run_cmd.extend(["--batch_size", str(int(batch_size))])
    run_cmd.extend(["--caption_extension", caption_extension])
    run_cmd.extend(["--caption_separator", caption_separator])
    if character_tag_expand:
        run_cmd.append("--character_tag_expand")
    # 0.35 is the initial value of the three threshold sliders in the GUI; a
    # threshold left at that value is not passed on the command line.
    if character_threshold != 0.35:
        run_cmd.extend(["--character_threshold", str(character_threshold)])
    if debug:
        run_cmd.append("--debug")
    if force_download:
        run_cmd.append("--force_download")
    if frequency_tags:
        run_cmd.append("--frequency_tags")
    if general_threshold != 0.35:
        run_cmd.extend(["--general_threshold", str(general_threshold)])
    run_cmd.extend(
        ["--max_data_loader_n_workers", str(int(max_data_loader_n_workers))]
    )
    if onnx:
        run_cmd.append("--onnx")
    if recursive:
        run_cmd.append("--recursive")
    if remove_underscore:
        run_cmd.append("--remove_underscore")
    run_cmd.extend(["--repo_id", repo_id])
    # Truthiness also skips None, which must never end up in run_cmd.
    if tag_replacement:
        run_cmd.extend(["--tag_replacement", tag_replacement])
    if thresh != 0.35:
        run_cmd.extend(["--thresh", str(thresh)])
    if undesired_tags:
        run_cmd.extend(["--undesired_tags", undesired_tags])
    if use_rating_tags:
        run_cmd.append("--use_rating_tags")
    if use_rating_tags_as_last_tag:
        run_cmd.append("--use_rating_tags_as_last_tag")

    # The image folder is the last argument of the command line.
    run_cmd.append(rf"{train_data_dir}")

    env = setup_environment()

    # Space-joined copy of the argument list, for display in the log only; the
    # subprocess receives the list itself.
    command_to_run = " ".join(run_cmd)
    log.info(f"Executing command: {command_to_run}")

    completed = subprocess.run(run_cmd, env=env)
    if completed.returncode != 0:
        # Do not touch caption files or report success after a failed run.
        log.error(
            f"WD14 tagger exited with return code {completed.returncode}; "
            "skipping caption prefix step."
        )
        return

    # Prepend always_first_tags to the caption files.
    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_extension,
        prefix=always_first_tags,
        recursive=recursive,
    )

    log.info("...captioning done")
###
# Gradio UI
###
def gradio_wd14_caption_gui_tab(
    headless=False,
    default_train_dir=None,
    config: KohyaSSGUIConfig = {},
):
    """Build the "WD14 Captioning" Gradio tab and wire its event handlers.

    Args:
        headless: When true, the folder-picker button is hidden.
        default_train_dir: Folder whose sub-directories seed the image folder
            dropdown; defaults to ``<scriptdir>/data`` when ``None``.
        config: Source of the widgets' initial values, read through
            ``config.get("wd14_caption.<name>", <fallback>)``. The ``{}``
            default is only ever read with ``.get`` here, never mutated.
    """
    from .common_gui import create_refresh_button

    default_train_dir = (
        default_train_dir
        if default_train_dir is not None
        else os.path.join(scriptdir, "data")
    )
    # Last path handed to list_train_dirs; the refresh button re-lists it.
    current_train_dir = default_train_dir

    # True while the "Use onnx" checkbox is forced on by a "-v3" repo, so that
    # the user's own choice is saved once and not overwritten by the forced
    # value on later repo changes.
    onnx_locked = False

    def list_train_dirs(path):
        """Remember ``path`` as the current folder and list its directories."""
        nonlocal current_train_dir
        current_train_dir = path
        return list(list_dirs(path))

    # NOTE(review): widget nesting below is reconstructed from the `with`
    # statements; in particular `undesired_tags` is assumed to sit outside the
    # preceding Row - confirm against the intended layout.
    with gr.Tab("WD14 Captioning"):
        gr.Markdown(
            "This utility will use WD14 to caption files for each images in a folder."
        )

        # Input Settings
        with gr.Group(), gr.Row():
            train_data_dir = gr.Dropdown(
                label="Image folder to caption (containing the images to caption)",
                choices=[config.get("wd14_caption.train_data_dir", "")]
                + list_train_dirs(default_train_dir),
                value=config.get("wd14_caption.train_data_dir", ""),
                interactive=True,
                allow_custom_value=True,
            )
            create_refresh_button(
                train_data_dir,
                lambda: None,
                lambda: {"choices": list_train_dirs(current_train_dir)},
                "open_folder_small",
            )
            button_train_data_dir_input = gr.Button(
                "📂",
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )

            repo_id = gr.Dropdown(
                label="Repo ID",
                choices=[
                    "SmilingWolf/wd-v1-4-convnext-tagger-v2",
                    "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-vit-tagger-v2",
                    "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-moat-tagger-v2",
                    "SmilingWolf/wd-swinv2-tagger-v3",
                    "SmilingWolf/wd-vit-tagger-v3",
                    "SmilingWolf/wd-convnext-tagger-v3",
                ],
                value=config.get(
                    "wd14_caption.repo_id", "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
                ),
                # Help text belongs in `info`; `show_label` is a boolean flag.
                info="Repo id for wd14 tagger on Hugging Face",
            )

            force_download = gr.Checkbox(
                label="Force model re-download",
                value=config.get("wd14_caption.force_download", False),
                info="Useful to force model re download when switching to onnx",
            )

        with gr.Row():
            caption_extension = gr.Dropdown(
                label="Caption file extension",
                choices=[".cap", ".caption", ".txt"],
                # Read from config like every other widget; ".txt" stays the
                # fallback.
                value=config.get("wd14_caption.caption_extension", ".txt"),
                interactive=True,
                allow_custom_value=True,
            )

            caption_separator = gr.Textbox(
                label="Caption Separator",
                value=config.get("wd14_caption.caption_separator", ", "),
                interactive=True,
            )

        with gr.Row():
            tag_replacement = gr.Textbox(
                label="Tag replacement",
                # The backslash is written as `\\` so the literal contains a
                # single backslash without relying on an invalid escape.
                info="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\\`. e.g. `tag1,tag2;tag3,tag4`",
                value=config.get("wd14_caption.tag_replacement", ""),
                interactive=True,
            )

            character_tag_expand = gr.Checkbox(
                label="Character tag expand",
                info="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`",
                value=config.get("wd14_caption.character_tag_expand", False),
                interactive=True,
            )

        undesired_tags = gr.Textbox(
            label="Undesired tags",
            placeholder="(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.",
            interactive=True,
            value=config.get("wd14_caption.undesired_tags", ""),
        )

        with gr.Row():
            always_first_tags = gr.Textbox(
                label="Prefix to add to WD14 caption",
                info="comma-separated list of tags to always put at the beginning, e.g.: 1girl, 1boy, ",
                placeholder="(Optional)",
                interactive=True,
                value=config.get("wd14_caption.always_first_tags", ""),
            )

        with gr.Row():
            onnx = gr.Checkbox(
                label="Use onnx",
                value=config.get("wd14_caption.onnx", True),
                interactive=True,
                info="https://github.com/onnx/onnx",
            )
            append_tags = gr.Checkbox(
                label="Append TAGs",
                value=config.get("wd14_caption.append_tags", False),
                interactive=True,
                info="This option appends the tags to the existing tags, instead of replacing them.",
            )

            use_rating_tags = gr.Checkbox(
                label="Use rating tags",
                value=config.get("wd14_caption.use_rating_tags", False),
                interactive=True,
                info="Adds rating tags as the first tag",
            )

            use_rating_tags_as_last_tag = gr.Checkbox(
                label="Use rating tags as last tag",
                value=config.get("wd14_caption.use_rating_tags_as_last_tag", False),
                interactive=True,
                info="Adds rating tags as the last tag",
            )

        with gr.Row():
            recursive = gr.Checkbox(
                label="Recursive",
                value=config.get("wd14_caption.recursive", False),
                info="Tag subfolders images as well",
            )
            remove_underscore = gr.Checkbox(
                label="Remove underscore",
                value=config.get("wd14_caption.remove_underscore", True),
                info="replace underscores with spaces in the output tags",
            )

            debug = gr.Checkbox(
                label="Debug",
                value=config.get("wd14_caption.debug", True),
                info="Debug mode",
            )
            frequency_tags = gr.Checkbox(
                label="Show tags frequency",
                value=config.get("wd14_caption.frequency_tags", True),
                info="Show frequency of tags for images.",
            )

        with gr.Row():
            thresh = gr.Slider(
                value=config.get("wd14_caption.thresh", 0.35),
                label="Threshold",
                info="threshold of confidence to add a tag",
                minimum=0,
                maximum=1,
                step=0.05,
            )

            general_threshold = gr.Slider(
                value=config.get("wd14_caption.general_threshold", 0.35),
                label="General threshold",
                info="Adjust `general_threshold` for pruning tags (less tags, less flexible)",
                minimum=0,
                maximum=1,
                step=0.05,
            )
            character_threshold = gr.Slider(
                value=config.get("wd14_caption.character_threshold", 0.35),
                label="Character threshold",
                minimum=0,
                maximum=1,
                step=0.05,
            )

        # Advanced Settings
        with gr.Row():
            batch_size = gr.Number(
                value=config.get("wd14_caption.batch_size", 1),
                label="Batch size",
                interactive=True,
            )

            max_data_loader_n_workers = gr.Number(
                value=config.get("wd14_caption.max_data_loader_n_workers", 2),
                label="Max dataloader workers",
                interactive=True,
            )

        def repo_id_changes(repo_id, onnx):
            """Force "Use onnx" on for "-v3" repos and restore it afterwards.

            Selecting a repo id containing "-v3" checks the checkbox and makes
            it read-only; the value it had before being forced is kept in the
            module-level ``old_onnx_value``. Selecting any other repo makes
            the checkbox editable again, restoring the saved value if it was
            forced and otherwise leaving the current value untouched.
            """
            global old_onnx_value
            nonlocal onnx_locked

            if "-v3" in repo_id:
                if not onnx_locked:
                    # Save only the user's own value, not the forced True seen
                    # when switching from one "-v3" repo to another.
                    old_onnx_value = onnx
                    onnx_locked = True
                return gr.Checkbox(value=True, interactive=False)

            # Restore the saved value only when leaving the forced state; a
            # switch between two non-v3 repos must not reset the checkbox.
            restored_value = old_onnx_value if onnx_locked else onnx
            onnx_locked = False
            return gr.Checkbox(value=restored_value, interactive=True)

        repo_id.change(repo_id_changes, inputs=[repo_id, onnx], outputs=[onnx])

        caption_button = gr.Button("Caption images")

        # The order of `inputs` must match the positional parameters of
        # caption_images.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                general_threshold,
                character_threshold,
                repo_id,
                recursive,
                max_data_loader_n_workers,
                debug,
                undesired_tags,
                frequency_tags,
                always_first_tags,
                onnx,
                append_tags,
                force_download,
                caption_separator,
                tag_replacement,
                character_tag_expand,
                use_rating_tags,
                use_rating_tags_as_last_tag,
                remove_underscore,
                thresh,
            ],
            show_progress=False,
        )

        # Re-list the sub-directories of whatever folder is typed or picked.
        train_data_dir.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
            inputs=train_data_dir,
            outputs=train_data_dir,
            show_progress=False,
        )