|
import sys, os |
|
import gradio as gr |
|
|
|
|
|
try:
    import kgen
except ImportError:
    # kgen (the TITPOP inference package) is not installed: install it at
    # start-up from GitHub. GITHUB_TOKEN, when set, is used as the credential.
    import importlib
    import subprocess

    GH_TOKEN = os.getenv("GITHUB_TOKEN")
    _KGEN_REPO = "github.com/KohakuBlueleaf/TITPOP-KGen@titpop"
    # Only embed credentials when a token exists; the old code produced
    # "https://None@github.com/..." when the variable was unset.
    git_url = f"https://{GH_TOKEN}@{_KGEN_REPO}" if GH_TOKEN else f"https://{_KGEN_REPO}"

    # Install into *this* interpreter (not whatever `pip` is first on PATH) and
    # pass an argument list so the token-bearing URL never goes through a shell.
    _install = subprocess.run(
        [sys.executable, "-m", "pip", "install", f"git+{git_url}"]
    )
    if _install.returncode != 0:
        # Deliberately do not include the URL/command: it may embed the token.
        raise RuntimeError(
            f"Installing kgen from GitHub failed (pip exit code {_install.returncode})."
        )
    # Make the freshly installed package visible to the `import kgen...`
    # statements further down in this module.
    importlib.invalidate_caches()
|
|
|
import re |
|
import random |
|
from time import time |
|
|
|
import torch |
|
from transformers import set_seed |
|
def _gpu_passthrough(func=None, *args, **kwargs):
    """No-op stand-in for ``spaces.GPU`` when ZeroGPU support is unavailable.

    Supports both decorator forms used with ``spaces.GPU``:

    * ``@GPU``                -> called with the function, returns it unchanged;
    * ``@GPU(duration=...)``  -> called without a function, returns an identity
      decorator. (The previous fallback required ``func`` positionally, so this
      form raised ``TypeError`` on Windows.)

    Any extra positional/keyword arguments (e.g. ``duration``) are ignored.
    """
    if func is None:
        return lambda fn: fn
    return func


if sys.platform == "win32":
    GPU = _gpu_passthrough
else:
    try:
        from spaces import GPU
    except ImportError:
        # Allow running outside Hugging Face Spaces (e.g. local Linux dev).
        GPU = _gpu_passthrough
|
|
|
import kgen.models as models |
|
import kgen.executor.titpop as titpop |
|
from kgen.formatter import seperate_tags, apply_format |
|
from kgen.generate import generate |
|
|
|
from diff import load_model, encode_prompts |
|
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT |
|
|
|
|
|
# Values pre-filled into the "Danbooru Tags" and "Natural Language Prompt" boxes.
DEFAULT_TAGS = (
    "1girl, king halo (umamusume), umamusume,\n"
    "ningen mame, ciloranko, ogipote, misu kasumi,\n"
    "solo, leaning forward, sky,\n"
    "masterpiece, absurdres, sensitive, newest"
)
DEFAULT_NL = "An illustration of a girl"

# Diffusers pipeline used by generate_image() (returned by diff.load_model()).
sdxl_pipe = load_model()

# Load the TITPOP prompt model onto the CUDA device.
# NOTE(review): the checkpoint subfolder is "500M-epoch3" while the intro text
# advertises 4 epochs of training; confirm which checkpoint is intended.
models.load_model("Amber-River/titpop", device="cuda", subfolder="500M-epoch3")

# Warm-up call. At this point `generate` still refers to
# kgen.generate.generate (imported above); the app's own `generate`, defined
# later in this file, rebinds the name afterwards.
generate(max_new_tokens=4)
|
|
|
|
|
def _format_rate(count, seconds):
    """Format ``count / seconds`` as a 7-wide, 2-decimal column.

    Returns a right-aligned ``N/A`` instead when ``seconds`` is not positive,
    so a zero duration never raises ``ZeroDivisionError``.
    """
    if seconds > 0:
        return f"{count / seconds:7.2f}"
    return f"{'N/A':>7}"


def format_time(timing):
    """Render a TITPOP timing dict as a Markdown speed report.

    Args:
        timing: dict that always has ``"total"`` (seconds) and
            ``"generate_pass"`` (count). If ``"generated_tokens"`` is present,
            ``"input_tokens"`` must be too. If ``"total_sampling"`` is also
            present, ``"prompt_process"`` and ``"total_eval"`` are read as well;
            those three are divided by 1000 and labelled "sec" (so presumably
            they arrive in milliseconds — confirm against kgen).

    Returns:
        A Markdown string: a "Process Time" table (plus per-stage token rows
        when sampling stats exist) and, when token counts exist, a
        "Processed Tokens" list. Rates whose duration is 0 are shown as N/A.
    """
    total = timing["total"]
    generate_pass = timing["generate_pass"]

    result = (
        "\n"
        "### Process Time\n"
        f"| Total | {total:5.2f} sec / {generate_pass:5} Passes "
        f"| {_format_rate(generate_pass, total)} Passes Per Second|\n"
        "|-|-|-|\n"
    )

    if "generated_tokens" not in timing:
        return result

    total_generated_tokens = timing["generated_tokens"]
    total_input_tokens = timing["input_tokens"]

    if "total_sampling" in timing:
        sampling_time = timing["total_sampling"] / 1000
        process_time = timing["prompt_process"] / 1000
        model_time = timing["total_eval"] / 1000
        rows = (
            ("Process", process_time, total_input_tokens),
            ("Sampling", sampling_time, total_generated_tokens),
            ("Eval", model_time, total_generated_tokens),
        )
        # Appended directly under the "|-|-|-|" separator so they stay
        # rows of the same Markdown table.
        result += "".join(
            f"| {name} | {seconds:5.2f} sec / {tokens:5} Tokens "
            f"| {_format_rate(tokens, seconds)} Tokens Per Second|\n"
            for name, seconds, tokens in rows
        )

    result += (
        "\n"
        "### Processed Tokens:\n"
        f"* {total_input_tokens} Input Tokens\n"
        f"* {total_generated_tokens} Output Tokens\n"
    )
    return result
|
|
|
|
|
# Matches one round or square bracket; used to backslash-escape them.
_BRACKET_RE = re.compile(r"([()\[\]])")


def _escape_prompt_brackets(text):
    """Return *text* with every ``(``, ``)``, ``[`` and ``]`` prefixed by ``\\``."""
    return _BRACKET_RE.sub(r"\\\1", text)


@GPU(duration=10)
@torch.no_grad()
def generate(
    tags,
    nl_prompt,
    black_list,
    temp,
    output_format,
    target_length,
    top_p,
    min_p,
    top_k,
    seed,
    escape_brackets,
):
    """Stream TITPOP-refined prompts for the Gradio UI.

    Args:
        tags: comma-separated tag string from the "Danbooru Tags" box.
        nl_prompt: natural-language prompt (may be empty).
        black_list: comma-separated entries written into ``titpop.BAN_TAGS``.
        temp, top_p, min_p, top_k, seed: sampling settings forwarded to
            ``titpop.titpop_runner_generator``.
        output_format: key into ``DEFAULT_FORMAT`` selecting the template.
        target_length: forwarded as ``tag_length_target``.
        escape_brackets: if true, ``()[]`` in both outputs are backslash-escaped.

    Yields:
        ``(formatted_result, formatted_input_prompt, timing_markdown)`` for each
        step produced by the TITPOP runner.
    """
    fmt = DEFAULT_FORMAT[output_format]

    # NOTE(review): BAN_TAGS is a module attribute of kgen.executor.titpop;
    # if requests can run concurrently they presumably overwrite each other's
    # black list — verify.
    titpop.BAN_TAGS = [t.strip() for t in black_list.split(",") if t.strip()]
    sampling = dict(
        seed=seed,
        temperature=temp,
        top_p=top_p,
        min_p=min_p,
        top_k=top_k,
    )

    # The user's own prompt rendered in the chosen format (used for the
    # "original" image). The NL prompt goes into whichever slot the template has.
    user_parts = seperate_tags(tags.split(","))
    if nl_prompt and "<|extended|>" in fmt:
        user_parts["extended"] = nl_prompt
    elif nl_prompt and "<|generated|>" in fmt:
        user_parts["generated"] = nl_prompt
    shown_prompt = apply_format(user_parts, fmt)
    if escape_brackets:
        shown_prompt = _escape_prompt_brackets(shown_prompt)

    # A fresh seperate_tags() result is passed here so the dict mutated
    # above is not reused.
    meta, operations, general, parsed_nl = titpop.parse_titpop_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target=target_length,
        generate_extra_nl_prompt="<|generated|>" in fmt or not nl_prompt,
    )

    started = time()
    runner = titpop.titpop_runner_generator(
        meta, operations, general, parsed_nl, **sampling
    )
    for partial, timing in runner:
        text = apply_format(partial, fmt)
        if escape_brackets:
            text = _escape_prompt_brackets(text)
        timing["total"] = time() - started
        yield text, shown_prompt, format_time(timing)
|
|
|
|
|
def _render_sdxl(prompt_embeds, negative_prompt_embeds, pooled_embeds, neg_pooled_embeds):
    """Run ``sdxl_pipe`` once on pre-encoded prompts and return the first image.

    Fixed settings: 24 inference steps, 1024x1024 output, guidance scale 6.0.
    """
    output = sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds,
        negative_pooled_prompt_embeds=neg_pooled_embeds,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    )
    return output.images[0]


@GPU(duration=20)
@torch.no_grad()
def generate_image(
    seed,
    prompt,
    prompt2,
):
    """Render one image for ``prompt2`` and then one for ``prompt``, same seed.

    Both prompts are encoded against ``DEFAULT_NEGATIVE_PROMPT``. In this app's
    UI ``prompt2`` is the user's formatted input and ``prompt`` is the
    TITPOP result.

    Yields:
        ``(image_for_prompt2, None)`` as soon as the first image is ready, then
        ``(image_for_prompt2, image_for_prompt)``.
    """
    torch.cuda.empty_cache()

    set_seed(seed)
    pos, neg, pooled, neg_pooled = encode_prompts(
        sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT
    )
    image_for_prompt2 = _render_sdxl(pos, neg, pooled, neg_pooled)
    yield image_for_prompt2, None

    pos, neg, pooled, neg_pooled = encode_prompts(
        sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT
    )
    # Re-seed with the same value, presumably so both images start from the
    # same initial noise and differ only by prompt.
    set_seed(seed)
    image_for_prompt = _render_sdxl(pos, neg, pooled, neg_pooled)
    torch.cuda.empty_cache()
    yield image_for_prompt2, image_for_prompt
|
|
|
|
|
# Markdown shown in the "Introduction and Instructions" accordion.
INTRO_MARKDOWN = """
## TITPOP Demo
**The model used in this demo is the 500M version after 4 epochs of training (25B tokens seen)**

### What is this
TITPOP is a tool to extend, generate, or refine the input prompt for T2I models.
<br>It works with both Danbooru tags and natural language, which means you can use it with almost all existing T2I models.
<br>You can think of it as a "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2)

### How to use this demo
1. Enter your tags (optional): put the desired tags into the "Danbooru Tags" box
2. Enter your NL prompt (optional): put the desired natural language prompt into the "Natural Language Prompt" box
3. Enter your black list (optional): put the desired black list into the "Black List" box
4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ...
5. Click the "TITPOP!" button: you will see the refined prompt in the "Result" box
6. If you like the result, click the "Generate Image from Result" button
    * You will see 2 generated images: the left one is based on your prompt, the right one on the refined prompt
    * The image backend is diffusers, which has no prompt-weighting mechanism, so Escape Brackets defaults to False

### Why is the inference code private? When will it be open sourced?
1. This model/tool is still under development; it is currently an early alpha version.
2. I'm doing some research and projects based on it.
3. The 200M model is currently released under the CC-BY-NC-ND license. If you are interested, you can implement inference yourself.
4. Once the project/research is done, I will open source all these models/codes under the Apache 2.0 license.

### Notification
**TITPOP is NOT a T2I model. It is a prompt generator, i.e. a text-to-text model.
<br>The generated images come from the [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
"""


if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(INTRO_MARKDOWN)
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=3):
                        tags_input = gr.TextArea(
                            label="Danbooru Tags",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_TAGS,
                            placeholder="Enter danbooru tags here",
                        )
                        nl_prompt_input = gr.Textbox(
                            label="Natural Language Prompt",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_NL,
                            placeholder="Enter Natural Language Prompt here",
                        )
                        black_list = gr.TextArea(
                            label="Black List (separated by comma)",
                            lines=4,
                            interactive=True,
                            value="monochrome",
                            placeholder="Enter tag/nl black list here",
                        )
                    with gr.Column(scale=2):
                        output_format = gr.Dropdown(
                            label="Output Format",
                            choices=list(DEFAULT_FORMAT.keys()),
                            value="Both, tag first (recommend)",
                        )
                        target_length = gr.Dropdown(
                            label="Target Length",
                            choices=["very_short", "short", "long", "very_long"],
                            value="long",
                        )
                        temp = gr.Slider(
                            label="Temp",
                            minimum=0.0,
                            maximum=1.5,
                            value=0.5,
                            step=0.05,
                        )
                        top_p = gr.Slider(
                            label="Top P",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                        )
                        min_p = gr.Slider(
                            label="Min P",
                            minimum=0.0,
                            maximum=0.2,
                            value=0.05,
                            step=0.01,
                        )
                        top_k = gr.Slider(
                            label="Top K", minimum=0, maximum=120, value=60, step=1
                        )
                        with gr.Row():
                            seed = gr.Number(
                                label="Seed",
                                minimum=0,
                                maximum=2147483647,
                                value=20090220,
                                step=1,
                                # precision=0 makes Gradio deliver an int, which is
                                # what set_seed()/the sampler expect for a seed.
                                precision=0,
                            )
                            escape_brackets = gr.Checkbox(
                                label="Escape Brackets", value=False
                            )
                submit = gr.Button("TITPOP!", variant="primary")
                with gr.Accordion("Speed statistics", open=False):
                    cost_time = gr.Markdown()
            with gr.Column(scale=5):
                result = gr.TextArea(
                    label="Result", lines=8, show_copy_button=True, interactive=False
                )
                # Hidden: carries the user's formatted prompt to the image step.
                input_prompt = gr.Textbox(
                    label="Input Prompt", lines=1, interactive=False, visible=False
                )
                gen_img = gr.Button(
                    "Generate Image from Result", variant="primary", interactive=False
                )
                with gr.Row():
                    with gr.Column():
                        img1 = gr.Image(label="Original Prompt", interactive=False)
                    with gr.Column():
                        img2 = gr.Image(label="Generated Prompt", interactive=False)

        def generate_wrapper(*args):
            """Stream generate() output; keep the image button disabled meanwhile.

            The image button is re-enabled only if at least one result was
            produced (previously an empty stream raised UnboundLocalError).
            """
            last = ("", "", "")
            yield *last, gr.update(interactive=False)
            produced = False
            for last in generate(*args):
                produced = True
                yield *last, gr.update(interactive=False)
            yield *last, gr.update(interactive=produced)

        submit.click(
            generate_wrapper,
            [
                tags_input,
                nl_prompt_input,
                black_list,
                temp,
                output_format,
                target_length,
                top_p,
                min_p,
                top_k,
                seed,
                escape_brackets,
            ],
            [
                result,
                input_prompt,
                cost_time,
                gen_img,
            ],
            queue=True,
        )

        def generate_image_wrapper(seed_value, refined_prompt, user_prompt):
            """Stream both images; re-enable the TITPOP button when finished."""
            original_image = refined_image = None
            for original_image, refined_image in generate_image(
                seed_value, refined_prompt, user_prompt
            ):
                yield original_image, refined_image, gr.update(interactive=False)
            yield original_image, refined_image, gr.update(interactive=True)

        gen_img.click(
            generate_image_wrapper,
            [seed, result, input_prompt],
            [img1, img2, submit],
            queue=True,
        )
        # Disable the TITPOP button immediately (unqueued) while images render.
        gen_img.click(
            lambda *args: gr.update(interactive=False),
            None,
            [submit],
            queue=False,
        )

    demo.launch()
|
|