# Source captured from a Hugging Face Space page; the Space status shown at capture time was "Runtime error".
import random | |
from time import time_ns | |
import torch | |
import spaces | |
import gradio as gr | |
from transformers import set_seed | |
from kgen import models | |
from diff import load_model, encode_prompts | |
from dtg import process | |
from meta import ( | |
DEFAULT_STYLE_LIST, | |
MODEL_FORMAT_LIST, | |
MODEL_DEFAULT_QUALITY_LIST, | |
DEFAULT_NEGATIVE_PROMPT, | |
) | |
# Default checkpoints, loaded once at import time. Each name is bound once
# and reused, so the loaded model and the recorded "current" model cannot
# drift apart. gen() compares incoming selections against these names and
# reloads only when they change.
current_sdxl_model = "KBlueLeaf/Kohaku-XL-Epsilon"
current_dtg_model = models.model_list[0]

sdxl_pipe = load_model(model_id=current_sdxl_model, device="cuda")
models.load_model(current_dtg_model)
models.text_model.cuda()
@spaces.GPU  # required on ZeroGPU hardware; documented as effect-free elsewhere
def gen(
    sdxl_model: str,
    dtg_model: str,
    style: str,
    base_prompt: str,
    addon_prompt: str = "",
):
    """Generate one image from a DanTagGen-refined prompt.

    Swaps in the requested SDXL / DTG checkpoints if they differ from the
    ones currently loaded, expands the prompt with ``process`` (DanTagGen),
    then renders a 1024x1024 image with the SDXL pipeline.

    Args:
        sdxl_model: SDXL model id; passed to ``load_model`` and used as a key
            into ``MODEL_FORMAT_LIST`` and ``MODEL_DEFAULT_QUALITY_LIST``.
        dtg_model: Entry of ``models.model_list`` used for tag generation.
        style: Key into ``DEFAULT_STYLE_LIST``.
        base_prompt: Core tags, always included.
        addon_prompt: Optional extra tags; skipped when empty.

    Returns:
        ``(image, full_prompt, cost_text)``: the RGB-converted image, the
        prompt produced by ``process``, and a wall-clock timing string.
        Model swapping happens before the timer starts and is not counted.
    """
    global current_dtg_model, current_sdxl_model, sdxl_pipe

    if sdxl_model != current_sdxl_model:
        # Drop our reference to the old pipeline and release cached CUDA
        # memory before loading the replacement, so the old pipeline can be
        # freed first. The name is cleared too: if loading fails, the next
        # call retries instead of calling a missing pipeline.
        sdxl_pipe = None
        current_sdxl_model = None
        torch.cuda.empty_cache()
        sdxl_pipe = load_model(model_id=sdxl_model, device="cuda")
        current_sdxl_model = sdxl_model

    if dtg_model != current_dtg_model:
        models.load_model(dtg_model)
        models.text_model.cuda()
        current_dtg_model = dtg_model

    t0 = time_ns()
    seed = random.randint(0, 2**31 - 1)

    # Skip empty parts (e.g. a blank addon prompt) so the prompt never
    # contains ", , " runs; keep the trailing ", " of the original format.
    prompt_parts = (
        base_prompt,
        addon_prompt,
        DEFAULT_STYLE_LIST[style],
        MODEL_DEFAULT_QUALITY_LIST[sdxl_model],
    )
    prompt = ", ".join(part for part in prompt_parts if part and part.strip()) + ", "

    full_prompt = process(
        prompt,
        aspect_ratio=1.0,
        seed=seed,
        tag_length="short",
        ban_tags=".*alternate.*, character doll, multiple.*, .*cosplay.*, .*name, .*text.*",
        format=MODEL_FORMAT_LIST[sdxl_model],
        temperature=1.2,
    )
    torch.cuda.empty_cache()

    prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
        encode_prompts(sdxl_pipe, full_prompt, DEFAULT_NEGATIVE_PROMPT)
    )
    # Re-seed right before sampling so the same seed reproduces the image.
    set_seed(seed)
    result = sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds2,
        negative_pooled_prompt_embeds=neg_pooled_embeds2,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]
    torch.cuda.empty_cache()

    t1 = time_ns()
    return result.convert("RGB"), full_prompt, f"Cost: {(t1 - t0) / 1e9:.4}sec"
if __name__ == "__main__":
    # Build and launch the Gradio UI. The component order in `inputs`
    # below must match gen()'s positional parameters.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""# This Cute Dragon Girl Doesn't Exist""")
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(
                """
### What is this:
"This Cute Dragon Girl Doesn't Exist" is a Demo for KGen System(DanTagGen) with SDXL anime models.
It is aimed to show how the DanTagGen can be used to "refine/upsample" simple prompt to help the T2I model.
Since I already have some application and demo on DanTagGen.
This demo is designed to be more "simple" than before.
Just one click, and get the result with high quality and high diversity.
### How to use it:
click "Next" button until you get the dragon girl you like.
### Resources:
- My anime model: [Kohaku XL Epsilon](https://huggingface.co/KBlueLeaf/Kohaku-XL-Epsilon)
- DanTagGen: [DanTagGen](https://huggingface.co/KBlueLeaf/DanTagGen-beta)
- DanTagGen extension: [z-a1111-sd-webui-dtg](https://github.com/KohakuBlueleaf/z-a1111-sd-webui-dtg)
"""
            )
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    # Dropdown choices are documented as a sequence, so pass
                    # the dict keys as an explicit list.
                    sdxl_model = gr.Dropdown(
                        list(MODEL_FORMAT_LIST),
                        label="SDXL Model",
                        value=list(MODEL_FORMAT_LIST)[0],
                    )
                    dtg_model = gr.Dropdown(
                        models.model_list,
                        label="DTG Model",
                        value=models.model_list[0],
                    )
                # Fixed subject of the demo; shown but not editable.
                base_prompt = gr.Textbox(
                    label="Base prompt",
                    lines=1,
                    value="1girl, solo, dragon girl, dragon wings, dragon horns, dragon tail",
                    interactive=False,
                )
                with gr.Row():
                    addon_prompt = gr.Textbox(
                        label="Addon prompt",
                        lines=1,
                        value="cowboy shot",
                    )
                    style = gr.Dropdown(
                        list(DEFAULT_STYLE_LIST),
                        label="Style",
                        value=list(DEFAULT_STYLE_LIST)[0],
                    )
                submit = gr.Button("Next", variant="primary")
                dtg_output = gr.TextArea(
                    label="DTG output", lines=9, show_copy_button=True
                )
                cost_time = gr.Markdown()
            with gr.Column(scale=4):
                result = gr.Image(label="Result", type="numpy", interactive=False)

        submit.click(
            fn=gen,
            inputs=[sdxl_model, dtg_model, style, base_prompt, addon_prompt],
            outputs=[result, dtg_output, cost_time],
        )

    demo.launch()