|
import random |
|
from time import time_ns, sleep |
|
from threading import Lock |
|
|
|
import torch |
|
|
|
import spaces |
|
import gradio as gr |
|
from transformers import set_seed |
|
|
|
from kgen import models |
|
from diff import load_model, encode_prompts |
|
from dtg import process |
|
from meta import ( |
|
DEFAULT_STYLE_LIST, |
|
MODEL_FORMAT_LIST, |
|
MODEL_DEFAULT_QUALITY_LIST, |
|
DEFAULT_NEGATIVE_PROMPT, |
|
) |
|
|
|
|
|
# Models loaded at import time. `current_sdxl_model` / `current_dtg_model`
# record which checkpoints are resident so `gen` can tell when a swap is
# needed; each id is defined once so the record cannot drift from what was
# actually loaded.
current_sdxl_model = "KBlueLeaf/Kohaku-XL-Epsilon-rev2"
current_dtg_model = models.model_list[0]

sdxl_pipe = load_model(model_id=current_sdxl_model, device="cuda")
models.load_model(current_dtg_model)
# Park the DTG text model on the CPU; `gen` moves it to the GPU only around
# tag generation.
models.text_model.cpu()
torch.cuda.empty_cache()

# `model_loading_lock` serializes model swaps in `gen`;
# `model_running_lock` guards `model_running`, the number of in-flight
# generations that a swap must wait for.
model_loading_lock = Lock()
model_running_lock = Lock()
model_running = 0
|
|
|
|
|
@spaces.GPU
def gen(
    sdxl_model: str,
    dtg_model: str,
    style: str,
    base_prompt: str,
    addon_prompt: str = "",
    seed: int = -1,
):
    """Generate one image from a DanTagGen-expanded prompt.

    Swaps in the requested SDXL / DTG models when they differ from the ones
    currently loaded, expands the prompt with DanTagGen (`process`), then
    renders a 1024x1024 image with the SDXL pipeline.

    Args:
        sdxl_model: Model id passed to `load_model`; also a key into
            MODEL_FORMAT_LIST and MODEL_DEFAULT_QUALITY_LIST.
        dtg_model: DTG model name passed to `models.load_model` (presumably
            an entry of `models.model_list`, as offered by the UI).
        style: Key into DEFAULT_STYLE_LIST.
        base_prompt: Core tags of the prompt.
        addon_prompt: Extra tags; omitted from the prompt when empty.
        seed: RNG seed; -1 (or None, e.g. from an empty number box) picks a
            random seed in [0, 2**31 - 1].

    Returns:
        Tuple of (RGB image, DTG-expanded prompt, "Cost: ...sec || Seed: ..."
        status string).
    """
    global current_dtg_model, current_sdxl_model, sdxl_pipe, model_running

    # Check for / perform model swaps and register this run in ONE critical
    # section, so no other request can replace a model between our check and
    # our increment of `model_running`.
    with model_loading_lock:
        if sdxl_model != current_sdxl_model or dtg_model != current_dtg_model:
            # Do not replace models under an in-flight generation.
            while model_running:
                sleep(0.01)
            if sdxl_model != current_sdxl_model:
                sdxl_pipe = load_model(model_id=sdxl_model, device="cuda")
                current_sdxl_model = sdxl_model
            if dtg_model != current_dtg_model:
                models.load_model(dtg_model)
                current_dtg_model = dtg_model
        with model_running_lock:
            model_running += 1

    try:
        t0 = time_ns()
        seed = -1 if seed is None else int(seed)
        if seed == -1:
            seed = random.randint(0, 2**31 - 1)

        # "base, addon, style, quality, " -- empty parts are skipped so an
        # empty addon prompt does not leave a ", , " gap.
        prompt_parts = [
            base_prompt,
            addon_prompt,
            DEFAULT_STYLE_LIST[style],
            MODEL_DEFAULT_QUALITY_LIST[sdxl_model],
        ]
        prompt = "".join(f"{part}, " for part in prompt_parts if part)

        models.text_model.cuda()
        try:
            full_prompt = process(
                prompt,
                aspect_ratio=1.0,
                seed=seed,
                tag_length="long",
                ban_tags=".*alternate.*, character doll, multiple.*, .*cosplay.*, .*name, .*text.*",
                format=MODEL_FORMAT_LIST[sdxl_model],
                temperature=1.0,
            )
        finally:
            # Move the DTG model off the GPU even if tag generation fails.
            models.text_model.cpu()
            torch.cuda.empty_cache()

        prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
            encode_prompts(sdxl_pipe, full_prompt, DEFAULT_NEGATIVE_PROMPT)
        )
        set_seed(seed)
        image = sdxl_pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_embeds2,
            negative_pooled_prompt_embeds=neg_pooled_embeds2,
            num_inference_steps=24,
            width=1024,
            height=1024,
            guidance_scale=6.0,
        ).images[0]
        torch.cuda.empty_cache()
        t1 = time_ns()
    finally:
        # Always release the run slot: a leaked count would make the next
        # model swap spin forever while holding `model_loading_lock`.
        with model_running_lock:
            model_running -= 1

    return (
        image.convert("RGB"),
        full_prompt,
        f"Cost: {(t1 - t0) / 1e9:.4}sec || Seed: {seed}",
    )
|
|
|
|
|
if __name__ == "__main__":
    # The choice tables are dicts looked up by key inside `gen`; offer their
    # keys explicitly as dropdown choices.
    sdxl_choices = list(MODEL_FORMAT_LIST)
    style_choices = list(DEFAULT_STYLE_LIST)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""# This Cute Dragon Girl Doesn't Exist""")
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(
                """
                ### What is this:
                "This Cute Dragon Girl Doesn't Exist" is a Demo for KGen System(DanTagGen) with SDXL anime models.
                It is aimed to show how the DanTagGen can be used to "refine/upsample" simple prompt to help the T2I model.

                Since I already have some application and demo on DanTagGen.
                This demo is designed to be more "simple" than before.

                Just one click, and get the result with high quality and high diversity.

                ### How to use it:
                click "Next" button until you get the dragon girl you like.

                ### Resources:
                - My anime model: [Kohaku XL Epsilon](https://huggingface.co/KBlueLeaf/Kohaku-XL-Epsilon)
                - DanTagGen: [DanTagGen](https://huggingface.co/KBlueLeaf/DanTagGen-beta)
                - DanTagGen extension: [z-a1111-sd-webui-dtg](https://github.com/KohakuBlueleaf/z-a1111-sd-webui-dtg)
                """
            )
        with gr.Row():
            # Left column: model/prompt controls and DTG text output.
            with gr.Column(scale=3):
                with gr.Row():
                    sdxl_model = gr.Dropdown(
                        sdxl_choices,
                        label="SDXL Model",
                        value=sdxl_choices[0],
                    )
                    dtg_model = gr.Dropdown(
                        models.model_list,
                        label="DTG Model",
                        value=models.model_list[0],
                    )
                with gr.Row():
                    base_prompt = gr.Textbox(
                        label="Base prompt",
                        lines=1,
                        value="1girl, solo, dragon girl, dragon wings, dragon horns, dragon tail",
                        interactive=False,
                    )
                    addon_prompt = gr.Textbox(
                        label="Addon prompt",
                        lines=1,
                        value="cowboy shot",
                    )
                with gr.Row():
                    seed = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        minimum=-1,
                        maximum=2**31 - 1,
                        precision=0,
                    )
                    style = gr.Dropdown(
                        style_choices,
                        label="Style",
                        value=style_choices[3],
                    )
                submit = gr.Button("Next", variant="primary")
                dtg_output = gr.TextArea(
                    label="DTG output", lines=9, show_copy_button=True
                )
                cost_time = gr.Markdown()
            # Right column: generated image.
            with gr.Column(scale=4):
                result = gr.Image(label="Result", type="numpy", interactive=False)

        # Input order must match gen()'s positional parameters.
        submit.click(
            fn=gen,
            inputs=[sdxl_model, dtg_model, style, base_prompt, addon_prompt, seed],
            outputs=[result, dtg_output, cost_time],
        )

    demo.launch()
|
|