File size: 3,081 Bytes
13e5846
b411b2e
13e5846
 
10d19d1
b411b2e
13e5846
 
 
 
 
10d19d1
13e5846
 
 
 
 
 
 
 
 
 
 
 
10d19d1
13e5846
 
 
 
 
 
10d19d1
13e5846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10d19d1
13e5846
 
 
10d19d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13e5846
 
 
10d19d1
 
 
 
 
 
 
 
 
 
 
 
13e5846
 
10d19d1
13e5846
b411b2e
10d19d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer
from huggingface_hub import InferenceClient

# Dart v2 (Danbooru tags transformer) SFT checkpoint, run through ONNX Runtime.
# https://huggingface.co/collections/p1atdev/dart-v2-danbooru-tags-transformer-v2-66291115701b6fe773399b0a
model_id = "p1atdev/dart-v2-sft"

# Causal LM that generate_tags() uses to extend the user's tag list.
model = ORTModelForCausalLM.from_pretrained(model_id)

# Two tokenizers from the same checkpoint: the plain one encodes the prompt and
# decodes generations; the add_prefix_space variant is used only by
# get_tokens_as_list() to encode individual words.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(
    model_id,
    add_prefix_space=True,
)

# Remote text-to-image client; only called when the user ticks "Generate Image".
txt2imgclient = InferenceClient()


# https://huggingface.co/docs/transformers/v4.44.2/en/internal/generation_utils#transformers.NoBadWordsLogitsProcessor
def get_tokens_as_list(word_list):
    """Tokenize each word on its own and return one list of token ids per word.

    Every word is encoded separately with ``tokenizer_with_prefix_space`` and
    without special tokens. This produces a list-of-lists, the shape the
    commented-out ``bad_words_ids`` argument in ``generate_tags`` would pass to
    ``model.generate``.

    Args:
        word_list: Iterable of words/tags to tokenize.

    Returns:
        A list with one ``list[int]`` of token ids per input word, in order.
    """
    return [
        tokenizer_with_prefix_space([single_word], add_special_tokens=False).input_ids[0]
        for single_word in word_list
    ]


def generate_tags(general_tags: str, generate_image: bool = False):
    """Extend a comma-separated list of tags with the Dart model.

    This is a generator so Gradio can stream results. It first yields the tag
    string with no image. When ``generate_image`` is true, it yields a second
    time with the image rendered from those tags.

    Args:
        general_tags: Comma-separated seed tags typed by the user.
        generate_image: If true, also send the resulting tags (with a fixed
            quality prefix) to the text-to-image inference endpoint.

    Yields:
        ``(output_tags, image)`` tuples. ``output_tags`` is a ", "-joined
        string. ``image`` is ``None`` on the first yield and the value
        returned by ``txt2imgclient.text_to_image`` on the second.
    """
    # Normalize the user input: trim every tag and drop empty entries.
    # Filtering on the *stripped* value matters: with the previous `if tag`
    # check, a whitespace-only entry (e.g. "1girl, , black hair") survived
    # the filter, was stripped to "" and produced ",," in the prompt.
    stripped_tags = (tag.strip() for tag in general_tags.split(","))
    general_tags = ",".join(tag for tag in stripped_tags if tag)

    # https://huggingface.co/p1atdev/dart-v2-sft#prompt-format
    prompt = (
        "<|bos|>"
        # "<copyright></copyright>"
        # "<character></character>"
        "<|rating:general|><|aspect_ratio:tall|><|length:medium|>"
        f"<general>{general_tags}<|identity:none|><|input_end|>"
    )

    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    # bad_words_ids = get_tokens_as_list(word_list=[""])

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            top_k=100,
            max_new_tokens=128,
            num_beams=1,
            # bad_words_ids=bad_words_ids,
        )

    # outputs[0] is the single returned sequence (prompt ids followed by the
    # generated ids). Iterating that 1-D tensor in batch_decode yields scalar
    # ids, so each token is decoded separately. Special tokens decode to
    # empty strings and are dropped by the strip() filter.
    output_tags = ", ".join(
        tag
        for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
        if tag.strip() != ""
    )

    # Stream the tags immediately; the image (if requested) can take a while.
    yield (output_tags, None)

    if generate_image:
        txt2img_prompt = f"score_9, score_8_up, score_7_up, {output_tags}"
        img = txt2imgclient.text_to_image(
            prompt=txt2img_prompt,
            negative_prompt="score_6, score_5, score_4, rating_explicit, child, loli, shota",
            num_inference_steps=25,
            height=1152,
            width=896,
            model="John6666/wai-real-mix-v8-sdxl",
            scheduler="EulerAncestralDiscreteScheduler",
        )

        yield (output_tags, img)


# Gradio UI: a tag text area plus an opt-in image toggle feed generate_tags(),
# whose (tags, image) yields fill the two output components below.
tag_input = gr.TextArea("1girl, black hair", lines=4)
image_toggle = gr.Checkbox(
    False,
    label="Generate Image",
    info="Generating image using InferenceClient (really slow) with output_tags as prompt",
)
tags_output = gr.Textbox(label="output_tags", show_copy_button=True)
image_output = gr.Image(label="generated_image", format="jpeg", type="pil")

demo = gr.Interface(
    fn=generate_tags,
    inputs=[tag_input, image_toggle],
    outputs=[tags_output, image_output],
    clear_btn=None,
    analytics_enabled=False,
    concurrency_limit=64,
)

# Enable the request queue, which lets the generator's yields stream to the
# UI, then start the server.
demo.queue()
demo.launch()