|
from pathlib import Path |
|
import time |
|
import os |
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
from optimum.onnxruntime import ORTModelForCausalLM |
|
|
|
|
|
import gradio as gr |
|
|
|
# --- Runtime configuration (overridable through environment variables) ---

# Hugging Face repository id of the model and tokenizer to load.
MODEL_NAME = os.environ.get("MODEL_NAME", "p1atdev/dart-v1-sft")

# Optional Hub access token; stays None when the variable is not set.
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")

# Display name of the inference backend; used as a key of MODEL_BACKEND_MAP.
MODEL_BACKEND = os.environ.get("MODEL_BACKEND", "ONNX (quantized)")

# Narrow the types for static checkers.
assert isinstance(MODEL_NAME, str)
assert isinstance(MODEL_BACKEND, str)
|
|
|
# Tokenizer shared by every backend. `trust_remote_code=True` allows tokenizer
# code shipped with the model repository to be executed.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    token=HF_READ_TOKEN,
)

# Every backend is loaded eagerly at import time, keyed by an internal name.
# NOTE: the "ort_qantized" spelling is kept because MODEL_BACKEND_MAP (below)
# refers to it.
model = {
    "default": AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        token=HF_READ_TOKEN,
    ),
    "ort": ORTModelForCausalLM.from_pretrained(
        MODEL_NAME,
    ),
    "ort_qantized": ORTModelForCausalLM.from_pretrained(
        MODEL_NAME,
        file_name="model_quantized.onnx",
    ),
}

# Maps the user-facing backend name (the MODEL_BACKEND value) to a key of `model`.
MODEL_BACKEND_MAP = {
    "Default": "default",
    "ONNX (normal)": "ort",
    "ONNX (quantized)": "ort_qantized",
}

# Best effort: move the PyTorch model to the GPU when one is usable.
# `except Exception` (instead of a bare `except`) keeps KeyboardInterrupt and
# SystemExit propagating while still tolerating any CUDA-related failure.
try:
    model["default"].to("cuda")
except Exception:
    print("No GPU")

# Best effort: compile the PyTorch model; fall back to eager mode on failure.
try:
    model["default"] = torch.compile(model["default"])
except Exception:
    print("torch.compile is not supported")
|
|
|
# Literal special-token strings. They are converted to vocabulary ids further
# down in this module and filtered out of the decoded output in
# `decode_animagine`.

# Sequence start / end.
BOS = "<|bos|>"
EOS = "<|eos|>"

# Opening / closing delimiters of the rating, copyright, character and
# general sections.
RATING_BOS = "<rating>"
RATING_EOS = "</rating>"
COPYRIGHT_BOS = "<copyright>"
COPYRIGHT_EOS = "</copyright>"
CHARACTER_BOS = "<character>"
CHARACTER_EOS = "</character>"
GENERAL_BOS = "<general>"
GENERAL_EOS = "</general>"

# NOTE(review): presumably marks the end of the user-supplied condition tags;
# confirm against the model's chat template.
INPUT_END = "<|input_end|>"

# Length-control tokens; selected through LENGTH_TAGS.
LENGTH_VERY_SHORT = "<|very_short|>"
LENGTH_SHORT = "<|short|>"
LENGTH_LONG = "<|long|>"
LENGTH_VERY_LONG = "<|very_long|>"
|
|
|
# Vocabulary ids of the section delimiters, looked up one token at a time in
# the same order as the names on the left-hand side.
(
    RATING_BOS_ID,
    RATING_EOS_ID,
    COPYRIGHT_BOS_ID,
    COPYRIGHT_EOS_ID,
    CHARACTER_BOS_ID,
    CHARACTER_EOS_ID,
    GENERAL_BOS_ID,
    GENERAL_EOS_ID,
) = (
    tokenizer.convert_tokens_to_ids(delimiter)
    for delimiter in (
        RATING_BOS,
        RATING_EOS,
        COPYRIGHT_BOS,
        COPYRIGHT_EOS,
        CHARACTER_BOS,
        CHARACTER_EOS,
        GENERAL_BOS,
        GENERAL_EOS,
    )
)

# A single-token lookup must yield a single int (also narrows the types for
# static checkers).
assert isinstance(RATING_BOS_ID, int)
assert isinstance(RATING_EOS_ID, int)
assert isinstance(COPYRIGHT_BOS_ID, int)
assert isinstance(COPYRIGHT_EOS_ID, int)
assert isinstance(CHARACTER_BOS_ID, int)
assert isinstance(CHARACTER_EOS_ID, int)
assert isinstance(GENERAL_BOS_ID, int)
assert isinstance(GENERAL_EOS_ID, int)

# Every structural token; `decode_animagine` strips these from its output.
SPECIAL_TAGS = [
    BOS,
    EOS,
    RATING_BOS,
    RATING_EOS,
    COPYRIGHT_BOS,
    COPYRIGHT_EOS,
    CHARACTER_BOS,
    CHARACTER_EOS,
    GENERAL_BOS,
    GENERAL_EOS,
    INPUT_END,
    LENGTH_VERY_SHORT,
    LENGTH_SHORT,
    LENGTH_LONG,
    LENGTH_VERY_LONG,
]

SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
assert isinstance(SPECIAL_TAG_IDS, list)
# None of the special tokens may resolve to the tokenizer's unknown-token id.
assert tokenizer.unk_token_id not in SPECIAL_TAG_IDS
|
|
|
# Rating name -> rating tag string; every tag is the name prefixed with "rating:".
RATING_TAGS = {
    rating: f"rating:{rating}"
    for rating in (
        "sfw",
        "nsfw",
        "general",
        "sensitive",
        "questionable",
        "explicit",
    )
}

# Rating name -> vocabulary id of the corresponding rating tag.
RATING_TAG_IDS = {
    rating: tokenizer.convert_tokens_to_ids(tag)
    for rating, tag in RATING_TAGS.items()
}

# UI label of a length preset -> length-control token.
LENGTH_TAGS = dict(
    [
        ("very short", LENGTH_VERY_SHORT),
        ("short", LENGTH_SHORT),
        ("long", LENGTH_LONG),
        ("very long", LENGTH_VERY_LONG),
    ]
)
|
|
|
|
|
def load_tags(path: str | Path): |
|
if isinstance(path, str): |
|
path = Path(path) |
|
|
|
with open(path, "r", encoding="utf-8") as file: |
|
lines = [line.strip() for line in file.readlines() if line.strip() != ""] |
|
|
|
return lines |
|
|
|
|
|
# Tag vocabularies bundled with the app; the paths are relative to the current
# working directory.
_TAGS_DIR = "./tags"

COPYRIGHT_TAGS_LIST: list[str] = load_tags(f"{_TAGS_DIR}/copyright.txt")
CHARACTER_TAGS_LIST: list[str] = load_tags(f"{_TAGS_DIR}/character.txt")
PEOPLE_TAGS_LIST: list[str] = load_tags(f"{_TAGS_DIR}/people.txt")

# Vocabulary ids of the people tags; `split_people_tokens_part` tests
# membership against this list.
PEOPLE_TAG_IDS_LIST = tokenizer.convert_tokens_to_ids(PEOPLE_TAGS_LIST)

assert isinstance(PEOPLE_TAG_IDS_LIST, list)
|
|
|
|
|
@torch.no_grad()
def generate(
    input_text: str,
    model_backend: str,
    max_new_tokens: int = 128,
    min_new_tokens: int = 0,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_p: float = 1,
    top_k: int = 20,
    num_beams: int = 1,
    bad_words_ids: list[int] | None = None,
    cfg_scale: float = 1.5,
    negative_input_text: str | None = None,
) -> list[int]:
    """Run the selected backend's `generate` on a prompt and return token ids.

    Args:
        input_text: Prompt text; tokenized with the module-level tokenizer.
        model_backend: Display name of the backend, a key of MODEL_BACKEND_MAP.
        max_new_tokens, min_new_tokens, do_sample, temperature, top_p, top_k,
            num_beams: Forwarded unchanged to the model's `generate`.
        bad_words_ids: Flat list of token ids to ban. Each id is wrapped into
            its own single-element list before being forwarded.
        cfg_scale: Forwarded as `guidance_scale`.
        negative_input_text: Optional negative prompt; when given it is
            tokenized and forwarded as `negative_prompt_ids`.

    Returns:
        The first sequence returned by the model's `generate`, as a plain list
        of token ids.

    Raises:
        KeyError: If `model_backend` is not a key of MODEL_BACKEND_MAP.
    """
    backend = model[MODEL_BACKEND_MAP[model_backend]]
    target_device = backend.device

    def _encode(text: str):
        # Tokenize to a PyTorch tensor placed on the backend's device.
        return tokenizer(text, return_tensors="pt").input_ids.to(target_device)

    inputs = _encode(input_text)

    negative_inputs = None
    if negative_input_text is not None:
        negative_inputs = _encode(negative_input_text)

    # One single-token sequence per banned id.
    banned_sequences = None
    if bad_words_ids is not None:
        banned_sequences = [[token] for token in bad_words_ids]

    sequences = backend.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        bad_words_ids=banned_sequences,
        negative_prompt_ids=negative_inputs,
        guidance_scale=cfg_scale,
        # An n-gram size of 1 forbids repeating any single token.
        no_repeat_ngram_size=1,
    )

    return sequences[0].tolist()
|
|
|
|
|
def decode_normal(token_ids: list[int], skip_special_tokens: bool = True):
    """Decode token ids to text with the module-level tokenizer.

    Args:
        token_ids: Token ids to decode.
        skip_special_tokens: Forwarded to `tokenizer.decode`.

    Returns:
        Whatever `tokenizer.decode` returns for the given ids.
    """
    text = tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    return text
|
|
|
|
|
def decode_general_only(token_ids: list[int]):
    """Decode everything from the first general-section marker on, sorted.

    The ids from the first occurrence of GENERAL_BOS_ID to the end are decoded
    with special tokens skipped, split on ", ", sorted and re-joined with ", ".

    Raises:
        ValueError: If GENERAL_BOS_ID does not occur in `token_ids`.
    """
    general_start = token_ids.index(GENERAL_BOS_ID)
    text = tokenizer.decode(token_ids[general_start:], skip_special_tokens=True)

    tags = text.split(", ")
    tags.sort()
    return ", ".join(tags)
|
|
|
|
|
def split_people_tokens_part(token_ids: list[int]) -> tuple[list[int], list[int]]:
    """Partition token ids into people-tag ids and all other ids.

    Args:
        token_ids: Token ids to partition.

    Returns:
        A pair `(people_tokens, other_tokens)`. Each list keeps the relative
        order (and duplicates) of `token_ids`. An id counts as a people tag
        when it occurs in PEOPLE_TAG_IDS_LIST.
    """
    # Build the lookup set once so each membership test is O(1) instead of a
    # linear scan of the list for every token.
    people_ids = set(PEOPLE_TAG_IDS_LIST)

    people_tokens: list[int] = []
    other_tokens: list[int] = []

    for token in token_ids:
        if token in people_ids:
            people_tokens.append(token)
        else:
            other_tokens.append(token)

    return people_tokens, other_tokens
|
|
|
|
|
def decode_animagine(token_ids: list[int]) -> str:
    """Decode generated ids with the sections rearranged.

    The sequence is cut at the first RATING_EOS_ID, COPYRIGHT_EOS_ID,
    CHARACTER_EOS_ID and GENERAL_EOS_ID (searched in that order, each search
    continuing where the previous one stopped). The pieces are then decoded in
    the order: people tags, character, copyright, remaining general tags,
    rating. The "sfw" rating id and every id in SPECIAL_TAG_IDS are dropped,
    and "rating:nsfw" is rewritten to "nsfw" in the decoded text.

    Args:
        token_ids: Token ids containing all four closing section markers.

    Returns:
        The decoded, rearranged tag string.

    Raises:
        ValueError: If one of the closing section markers is missing.
    """

    def get_part(eos_token_id: int, remains_part: list[int]):
        """Split at the first `eos_token_id`: (ids before it, ids from it on)."""
        try:
            end = remains_part.index(eos_token_id)
        except ValueError:
            raise ValueError(
                "The provided EOS token was not found in the token_ids."
            ) from None
        # The marker itself stays at the head of the remainder. It is a
        # special token, so it is filtered out below.
        return remains_part[:end], remains_part[end:]

    rating_part, remains = get_part(RATING_EOS_ID, token_ids)
    copyright_part, remains = get_part(COPYRIGHT_EOS_ID, remains)
    character_part, remains = get_part(CHARACTER_EOS_ID, remains)
    general_part, _ = get_part(GENERAL_EOS_ID, remains)

    # Pull the people tags out of the general section so they can go first.
    people_part, other_general_part = split_people_tokens_part(general_part)

    # The "sfw" parent rating is not emitted.
    sfw_id = RATING_TAG_IDS["sfw"]
    rating_part = [token for token in rating_part if token != sfw_id]

    # Set lookup instead of scanning the SPECIAL_TAG_IDS list per token.
    special_ids = set(SPECIAL_TAG_IDS)
    rearranged_tokens = [
        token
        for token in (
            people_part
            + character_part
            + copyright_part
            + other_general_part
            + rating_part
        )
        if token not in special_ids
    ]

    decoded = tokenizer.decode(rearranged_tokens, skip_special_tokens=True)

    # Emit the bare "nsfw" tag instead of the prefixed rating tag.
    return decoded.replace("rating:nsfw", "nsfw")
|
|
|
|
|
def prepare_rating_tags(rating: str):
    """Return "<parent rating tag>, <rating tag>" for a rating name.

    The parent is the "sfw" tag when the rating's tag equals the "general" or
    "sensitive" tag, and the "nsfw" tag for every other rating.

    Raises:
        KeyError: If `rating` is not a key of RATING_TAGS.
    """
    tag = RATING_TAGS[rating]
    safe_tags = (RATING_TAGS["general"], RATING_TAGS["sensitive"])
    parent = RATING_TAGS["sfw"] if tag in safe_tags else RATING_TAGS["nsfw"]
    return f"{parent}, {tag}"
|
|
|
|
|
def handle_inputs(
    rating_tags: str,
    copyright_tags_list: list[str],
    character_tags_list: list[str],
    general_tags: str,
    ban_tags: str,
    do_cfg: bool = False,
    cfg_scale: float = 1.5,
    negative_tags: str = "",
    total_token_length: str = "long",
    max_new_tokens: int = 128,
    min_new_tokens: int = 0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = 20,
    num_beams: int = 1,
):
    """Build the prompt from the UI values, generate tags and format the outputs.

    Args:
        rating_tags: Rating name, a key of RATING_TAGS.
        copyright_tags_list: Selected copyright tags.
        character_tags_list: Selected character tags.
        general_tags: Comma-separated general tags used as the condition.
        ban_tags: Comma-separated tags that must not be generated.
        do_cfg: Whether classifier-free guidance is enabled.
        cfg_scale: Guidance scale; only applied when `do_cfg` is true.
        negative_tags: Comma-separated negative tags. They form the negative
            prompt when `do_cfg` is true and are always added to the ban list.
        total_token_length: Length preset, a key of LENGTH_TAGS.
        max_new_tokens, min_new_tokens, temperature, top_p, top_k, num_beams:
            Forwarded to `generate`.

    Returns:
        [
            output_tags_natural,
            output_tags_general_only,
            output_tags_animagine,
            input_prompt_raw,
            output_tags_raw,
            elapsed_time,
            output_tags_natural_copy_btn,
            output_tags_general_only_copy_btn,
            output_tags_animagine_copy_btn
        ]
        The last three entries are `gr.update(visible=True)` for the copy buttons.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    copyright_tags = ", ".join(copyright_tags_list)
    character_tags = ", ".join(character_tags_list)

    token_length_tag = LENGTH_TAGS[total_token_length]
    rating_prompt = prepare_rating_tags(rating_tags)

    prompt: str = tokenizer.apply_chat_template(
        {
            "rating": rating_prompt,
            "copyright": copyright_tags,
            "character": character_tags,
            "general": general_tags,
            "length": token_length_tag,
        },
        tokenize=False,
    )

    # The negative prompt is only needed when guidance is enabled.
    negative_prompt: str | None = None
    if do_cfg:
        negative_prompt = tokenizer.apply_chat_template(
            {
                "rating": rating_prompt,
                "copyright": "",
                "character": "",
                "general": negative_tags,
                "length": token_length_tag,
            },
            tokenize=False,
        )

    # Join only the non-blank fields so an empty ban list does not produce a
    # leading ", " in front of the negative tags.
    ban_text = ", ".join(
        part for part in (ban_tags, negative_tags) if part.strip() != ""
    )
    bad_words_ids = tokenizer.encode_plus(ban_text).input_ids

    # With "Do CFG" unchecked no guidance must be applied. The generation API
    # treats any guidance scale other than 1 as a request for classifier-free
    # guidance even without a negative prompt, so pass the neutral value
    # instead of the (hidden) slider value.
    effective_cfg_scale = cfg_scale if do_cfg else 1.0

    generated_ids = generate(
        prompt,
        model_backend=MODEL_BACKEND,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        bad_words_ids=bad_words_ids if len(bad_words_ids) > 0 else None,
        cfg_scale=effective_cfg_scale,
        negative_input_text=negative_prompt,
    )

    decoded_normal = decode_normal(generated_ids, skip_special_tokens=True)
    decoded_general_only = decode_general_only(generated_ids)
    decoded_animagine = decode_animagine(generated_ids)
    decoded_raw = decode_normal(generated_ids, skip_special_tokens=False)

    end_time = time.perf_counter()
    elapsed_time = f"Elapsed: {(end_time - start_time) * 1000:.2f} ms"

    # Reveal the copy buttons once there is something to copy.
    set_visible = gr.update(visible=True)

    return [
        decoded_normal,
        decoded_general_only,
        decoded_animagine,
        prompt,
        decoded_raw,
        elapsed_time,
        set_visible,
        set_visible,
        set_visible,
    ]
|
|
|
|
|
|
|
def copy_text(_text: None) -> None:
    """Show a "Copied!" notification in the UI.

    The argument is ignored. At the call sites in this file the function is
    registered together with `js=COPY_ACTION_JS`, and it is that JavaScript
    snippet which writes to the clipboard.
    """
    gr.Info("Copied!")
|
|
|
|
|
# JavaScript passed as the `js=` handler of the copy buttons: it writes the
# textbox value to the clipboard unless the value is blank.
COPY_ACTION_JS = """\
(inputs, _outputs) => {
  // inputs is the string value of the input_text
  if (inputs.trim() !== "") {
    navigator.clipboard.writeText(inputs);
  }
}"""
|
|
|
|
|
def demo():
    """Build the Gradio interface, wire up its events and launch it.

    One column holds the inputs (rating, copyright/character tag selection,
    general and ban tags, the generate button and the advanced generation
    settings); the other holds the three formatted outputs with their copy
    buttons plus a metadata accordion. `generate_btn` calls `handle_inputs`;
    the order of its `inputs`/`outputs` lists matches that function's
    parameters and return value.
    """
    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# Danbooru Tags Transformer Demo

Collection: [Dart (Danbooru Tags Transformer)](https://huggingface.co/collections/p1atdev/dart-danbooru-tags-transformer-65d687604ff57dc62ae40945)

Models:

- [p1atdev/dart-v1-sft](https://huggingface.co/p1atdev/dart-v1-sft)
- [p1atdev/dart-v1-base](https://huggingface.co/p1atdev/dart-v1-base)

"""
        )

        with gr.Row():
            # ---- Input column ----
            with gr.Column():
                with gr.Group():
                    rating_dropdown = gr.Dropdown(
                        label="Rating",
                        choices=[
                            "general",
                            "sensitive",
                            "questionable",
                            "explicit",
                        ],
                        value="general",
                    )

                with gr.Group():
                    copyright_tags_mode_dropdown = gr.Dropdown(
                        label="Copyright tags mode",
                        choices=[
                            "None",
                            "Original",
                            "Custom",
                        ],
                        value="None",
                        interactive=True,
                    )
                    copyright_tags_dropdown = gr.Dropdown(
                        label="Copyright tags",
                        choices=COPYRIGHT_TAGS_LIST,
                        value=[],
                        multiselect=True,
                        visible=False,
                    )

                    def on_change_copyright_tags_mode(mode: str):
                        # Only "Custom" exposes the tag picker; the other modes
                        # force its value.
                        kwargs: dict = {"visible": mode == "Custom"}
                        if mode == "Original":
                            kwargs["value"] = ["original"]
                        elif mode == "None":
                            kwargs["value"] = []

                        return gr.update(**kwargs)

                with gr.Group():
                    character_tags_mode_dropdown = gr.Dropdown(
                        label="Character tags mode",
                        choices=[
                            "None",
                            "Custom",
                        ],
                        value="None",
                        interactive=True,
                    )
                    character_tags_dropdown = gr.Dropdown(
                        label="Character tags",
                        choices=CHARACTER_TAGS_LIST,
                        value=[],
                        multiselect=True,
                        visible=False,
                    )

                    def on_change_character_tags_mode(mode: str):
                        # "None" clears the selection and hides the picker.
                        kwargs: dict = {"visible": mode == "Custom"}
                        if mode == "None":
                            kwargs["value"] = []

                        return gr.update(**kwargs)

                with gr.Group():
                    general_tags_textbox = gr.Textbox(
                        label="General tags (the condition to generate tags)",
                        value="",
                        placeholder="1girl, ...",
                        lines=4,
                    )

                    ban_tags_textbox = gr.Textbox(
                        label="Ban tags (tags in this field never appear in generation)",
                        value="",
                        placeholder="official alternate costume, english text,...",
                        lines=2,
                    )

                generate_btn = gr.Button("Generate", variant="primary")

                with gr.Accordion(label="Generation config (advanced)", open=False):
                    with gr.Group():
                        do_cfg_check = gr.Checkbox(
                            label="Do CFG (Classifier Free Guidance)",
                            value=False,
                        )
                        cfg_scale_slider = gr.Slider(
                            label="CFG scale",
                            maximum=3.0,
                            minimum=0.1,
                            step=0.1,
                            value=1.5,
                            visible=False,
                        )
                        negative_tags_textbox = gr.Textbox(
                            label="Negative prompt",
                            placeholder="simple background, ...",
                            value="",
                            lines=2,
                            visible=False,
                        )

                        def on_change_do_cfg_check(do_cfg: bool):
                            # The CFG controls are only shown while CFG is on.
                            kwargs: dict = {"visible": do_cfg}
                            return gr.update(**kwargs), gr.update(**kwargs)

                        do_cfg_check.change(
                            on_change_do_cfg_check,
                            inputs=[do_cfg_check],
                            outputs=[cfg_scale_slider, negative_tags_textbox],
                        )

                    with gr.Group():
                        total_token_length_radio = gr.Radio(
                            label="Total token length",
                            choices=list(LENGTH_TAGS.keys()),
                            value="long",
                        )

                    with gr.Group():
                        max_new_tokens_slider = gr.Slider(
                            label="Max new tokens",
                            maximum=256,
                            minimum=1,
                            step=1,
                            value=128,
                        )
                        min_new_tokens_slider = gr.Slider(
                            label="Min new tokens",
                            maximum=255,
                            minimum=0,
                            step=1,
                            value=0,
                        )
                        temperature_slider = gr.Slider(
                            label="Temperature (larger is more random)",
                            maximum=1.0,
                            # Sampling requires a strictly positive
                            # temperature; 0.0 makes generation fail.
                            minimum=0.1,
                            step=0.1,
                            value=1.0,
                        )
                        top_p_slider = gr.Slider(
                            label="Top p (larger is more random)",
                            maximum=1.0,
                            minimum=0.0,
                            step=0.1,
                            value=1.0,
                        )
                        top_k_slider = gr.Slider(
                            label="Top k (larger is more random)",
                            maximum=500,
                            minimum=1,
                            step=1,
                            value=100,
                        )
                        num_beams_slider = gr.Slider(
                            label="Number of beams (smaller is more random)",
                            maximum=10,
                            minimum=1,
                            step=1,
                            value=1,
                        )

            # ---- Output column ----
            with gr.Column():
                with gr.Group():
                    output_tags_natural = gr.Textbox(
                        label="Generation result",
                        interactive=False,
                    )
                    # Copy buttons stay hidden until handle_inputs reveals them.
                    output_tags_natural_copy_btn = gr.Button("Copy", visible=False)
                    output_tags_natural_copy_btn.click(
                        fn=copy_text,
                        inputs=[output_tags_natural],
                        js=COPY_ACTION_JS,
                    )

                with gr.Group():
                    output_tags_general_only = gr.Textbox(
                        label="General tags only (sorted)",
                        interactive=False,
                    )
                    output_tags_general_only_copy_btn = gr.Button("Copy", visible=False)
                    output_tags_general_only_copy_btn.click(
                        fn=copy_text,
                        inputs=[output_tags_general_only],
                        js=COPY_ACTION_JS,
                    )

                with gr.Group():
                    output_tags_animagine = gr.Textbox(
                        label="Output tags (AnimagineXL v3 style order)",
                        interactive=False,
                    )
                    output_tags_animagine_copy_btn = gr.Button("Copy", visible=False)
                    output_tags_animagine_copy_btn.click(
                        fn=copy_text,
                        inputs=[output_tags_animagine],
                        js=COPY_ACTION_JS,
                    )

                with gr.Accordion(label="Metadata", open=False):
                    _model_backend_md = gr.Markdown(
                        f"Model backend: {MODEL_BACKEND}",
                    )
                    input_prompt_raw = gr.Textbox(
                        label="Input prompt (raw)",
                        interactive=False,
                        lines=4,
                    )

                    output_tags_raw = gr.Textbox(
                        label="Output tags (raw)",
                        interactive=False,
                        lines=4,
                    )

                    elapsed_time_md = gr.Markdown(value="Waiting to generate...")

        copyright_tags_mode_dropdown.change(
            on_change_copyright_tags_mode,
            inputs=[copyright_tags_mode_dropdown],
            outputs=[copyright_tags_dropdown],
        )
        character_tags_mode_dropdown.change(
            on_change_character_tags_mode,
            inputs=[character_tags_mode_dropdown],
            outputs=[character_tags_dropdown],
        )

        # `inputs` follows the parameter order of handle_inputs and `outputs`
        # the order of its returned list.
        generate_btn.click(
            handle_inputs,
            inputs=[
                rating_dropdown,
                copyright_tags_dropdown,
                character_tags_dropdown,
                general_tags_textbox,
                ban_tags_textbox,
                do_cfg_check,
                cfg_scale_slider,
                negative_tags_textbox,
                total_token_length_radio,
                max_new_tokens_slider,
                min_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                num_beams_slider,
            ],
            outputs=[
                output_tags_natural,
                output_tags_general_only,
                output_tags_animagine,
                input_prompt_raw,
                output_tags_raw,
                elapsed_time_md,
                output_tags_natural_copy_btn,
                output_tags_general_only_copy_btn,
                output_tags_animagine_copy_btn,
            ],
        )

        # Each example fills the general tags and ban tags textboxes.
        gr.Examples(
            examples=[
                ["1girl, solo, from side", ""],
                ["1girl, solo, abstract, from above", ""],
                ["2girls, yuri", "1boy"],
                ["no humans, scenery, summer, day", ""],
            ],
            inputs=[
                general_tags_textbox,
                ban_tags_textbox,
            ],
        )

    ui.launch()
|
|
|
|
|
# Build and launch the UI only when the file is executed as a script.
if __name__ == "__main__":
    demo()
|
|