TIPO-DEMO / app.py
KBlueLeaf's picture
Update app.py
1b4fd22 verified
raw
history blame
12.8 kB
import sys, os
import gradio as gr
## Install kgen on demand when it is not importable.
try:
    import kgen
except ImportError:
    # Local imports: only needed on this fallback path.
    import importlib
    import subprocess

    # NOTE(review): if GITHUB_TOKEN is unset the URL contains the literal
    # "None" as the credential, exactly as before -- confirm whether a clearer
    # failure is wanted in that case.
    GH_TOKEN = os.getenv("GITHUB_TOKEN")
    git_url = f"https://{GH_TOKEN}@github.com/KohakuBlueleaf/TITPOP-KGen@titpop"
    ## call pip install
    # Argument list + sys.executable: no shell parsing of the token-bearing
    # URL, and the package lands in the interpreter that is running this file.
    install = subprocess.run(
        [sys.executable, "-m", "pip", "install", f"git+{git_url}"]
    )
    if install.returncode != 0:
        # Deliberately do not echo the command: it embeds the access token.
        raise RuntimeError("pip install of the kgen package failed") from None
    # Let the import system notice the freshly installed package so the
    # `import kgen.*` statements further down can find it.
    importlib.invalidate_caches()
import re
import random
from time import time
import torch
from transformers import set_seed
def _gpu_passthrough(func=None, *args, **kwargs):
    """No-op stand-in for ``spaces.GPU``.

    Works both as a bare decorator (``@GPU``) and as a decorator factory
    (``@GPU(duration=10)``); in either form the decorated function is
    returned unchanged. Extra positional/keyword options are accepted and
    ignored.
    """
    if callable(func):
        # Bare usage: GPU(fn) -> fn
        return func
    # Factory usage: GPU(duration=...) -> decorator that returns fn as-is
    return lambda wrapped: wrapped


if sys.platform == "win32":
    # dev env in windows, @spaces.GPU will cause problem
    GPU = _gpu_passthrough
else:
    try:
        from spaces import GPU
    except ImportError:
        # `spaces` is not installed (e.g. local run outside Hugging Face):
        # fall back to the same no-op decorator used on Windows.
        GPU = _gpu_passthrough
import kgen.models as models
import kgen.executor.titpop as titpop
from kgen.formatter import seperate_tags, apply_format
from kgen.generate import generate
from diff import load_model, encode_prompts
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
# Build the image pipeline once at import time using `load_model` from the
# local `diff` module; `generate_image` below calls it as `sdxl_pipe(...)`.
sdxl_pipe = load_model()
# Load the TITPOP checkpoint through `kgen.models`, targeting the CUDA device.
models.load_model(
    "Amber-River/titpop",
    device="cuda",
    subfolder="500M-epoch3",
)
# Calls `generate` imported from `kgen.generate` with a 4-token budget --
# presumably a warm-up pass; TODO confirm.
# NOTE(review): the name `generate` is rebound by the `def generate` further
# down in this file, so this call must stay above that definition.
generate(max_new_tokens=4)
# Initial value of the "Danbooru Tags" input box in the UI below.
DEFAULT_TAGS = """
1girl, king halo (umamusume), umamusume,
ningen mame, ciloranko, ogipote, misu kasumi,
solo, leaning forward, sky,
masterpiece, absurdres, sensitive, newest
""".strip()
# Initial value of the "Natural Language Prompt" input box in the UI below.
DEFAULT_NL = """
An illustration of a girl
""".strip()
def _rate(count, seconds):
    """Return ``count / seconds``, or ``0.0`` when ``seconds`` is zero."""
    return count / seconds if seconds else 0.0


def format_time(timing):
    """Render a timing dictionary as a Markdown report.

    Args:
        timing: Mapping that must contain ``"total"`` (seconds) and
            ``"generate_pass"``. When ``"generated_tokens"`` is present,
            ``"input_tokens"`` is read as well and a token summary is
            appended. When ``"total_sampling"`` is also present,
            ``"prompt_process"`` and ``"total_eval"`` are read too; all three
            are divided by 1000 before being shown as seconds.

    Returns:
        A Markdown string with a "Process Time" table and, when token counts
        are available, a "Processed Tokens" list.

    Raises:
        KeyError: If a key required by the sections above is missing.
    """
    total = timing["total"]
    generate_pass = timing["generate_pass"]

    # The report is assembled line by line; the leading "" reproduces the
    # blank line that precedes each heading.
    lines = [
        "",
        "### Process Time",
        f"| Total | {total:5.2f} sec / {generate_pass:5} Passes | {_rate(generate_pass, total):7.2f} Passes Per Second|",
        "|-|-|-|",
    ]

    if "generated_tokens" in timing:
        total_generated_tokens = timing["generated_tokens"]
        total_input_tokens = timing["input_tokens"]

        if "total_sampling" in timing:
            sampling_time = timing["total_sampling"] / 1000
            process_time = timing["prompt_process"] / 1000
            model_time = timing["total_eval"] / 1000
            # _rate() guards each throughput figure against a zero duration,
            # which would otherwise raise ZeroDivisionError.
            lines += [
                f"| Process | {process_time:5.2f} sec / {total_input_tokens:5} Tokens | {_rate(total_input_tokens, process_time):7.2f} Tokens Per Second|",
                f"| Sampling | {sampling_time:5.2f} sec / {total_generated_tokens:5} Tokens | {_rate(total_generated_tokens, sampling_time):7.2f} Tokens Per Second|",
                f"| Eval | {model_time:5.2f} sec / {total_generated_tokens:5} Tokens | {_rate(total_generated_tokens, model_time):7.2f} Tokens Per Second|",
            ]

        lines += [
            "",
            "### Processed Tokens:",
            f"* {total_input_tokens:} Input Tokens",
            f"* {total_generated_tokens:} Output Tokens",
        ]

    return "\n".join(lines) + "\n"
@GPU(duration=10)
@torch.no_grad()
def generate(
    tags,
    nl_prompt,
    black_list,
    temp,
    output_format,
    target_length,
    top_p,
    min_p,
    top_k,
    seed,
    escape_brackets,
):
    """Stream TITPOP results for the given tags / natural-language prompt.

    The comma-separated ``black_list`` is stored on ``titpop.BAN_TAGS``, the
    request is parsed with ``titpop.parse_titpop_request`` and every item
    produced by ``titpop.titpop_runner_generator`` is formatted with the
    template selected by ``output_format``.

    Yields:
        ``(formatted_result, formatted_input_prompt, timing_markdown)`` for
        each item produced by the runner.
    """
    template = DEFAULT_FORMAT[output_format]

    # Comma-separated blacklist; surrounding whitespace and blanks are dropped.
    banned = []
    for entry in black_list.split(","):
        cleaned = entry.strip()
        if cleaned:
            banned.append(cleaned)
    titpop.BAN_TAGS = banned

    def _maybe_escape(text):
        # Backslash-escape round and square brackets only when requested.
        if not escape_brackets:
            return text
        return re.sub(r"([()\[\]])", r"\\\1", text)

    sampler_kwargs = dict(
        seed=seed,
        temperature=temp,
        top_p=top_p,
        min_p=min_p,
        top_k=top_k,
    )

    # Formatted copy of the user's own input, returned next to every result.
    preview_parts = seperate_tags(tags.split(","))
    if nl_prompt:
        # First placeholder found in the template receives the NL prompt.
        for marker, slot in (
            ("<|extended|>", "extended"),
            ("<|generated|>", "generated"),
        ):
            if marker in template:
                preview_parts[slot] = nl_prompt
                break
    input_prompt = _maybe_escape(apply_format(preview_parts, template))

    # Evaluated against the caller's nl_prompt, before it is rebound below.
    wants_generated_nl = "<|generated|>" in template or not nl_prompt
    # The tags are split and separated a second time here rather than reusing
    # `preview_parts`, which may have been modified above.
    meta, operations, general, nl_prompt = titpop.parse_titpop_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target=target_length,
        generate_extra_nl_prompt=wants_generated_nl,
    )

    started_at = time()
    runner = titpop.titpop_runner_generator(
        meta, operations, general, nl_prompt, **sampler_kwargs
    )
    for partial, timing in runner:
        formatted = _maybe_escape(apply_format(partial, template))
        # Wall-clock seconds since the runner was started.
        timing["total"] = time() - started_at
        yield formatted, input_prompt, format_time(timing)
@GPU(duration=20)
@torch.no_grad()
def generate_image(
    seed,
    prompt,
    prompt2,
):
    """Render two images with ``sdxl_pipe`` using the same seed.

    The first image is produced from ``prompt2`` and the second from
    ``prompt``; both use ``DEFAULT_NEGATIVE_PROMPT`` as the negative prompt.

    Yields:
        ``(image_from_prompt2, None)`` once the first image is ready, then
        ``(image_from_prompt2, image_from_prompt)`` when both are done.
    """
    # Sampler settings shared by both renders.
    sampler_settings = dict(
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    )

    def _render(cond, uncond, pooled, uncond_pooled):
        # Run the pipeline on pre-computed embeddings and take the first image.
        output = sdxl_pipe(
            prompt_embeds=cond,
            negative_prompt_embeds=uncond,
            pooled_prompt_embeds=pooled,
            negative_pooled_prompt_embeds=uncond_pooled,
            **sampler_settings,
        )
        return output.images[0]

    torch.cuda.empty_cache()

    # First pass: seed, then encode and render `prompt2`.
    set_seed(seed)
    first_image = _render(
        *encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
    )
    yield first_image, None

    # Second pass: encode `prompt`, re-seed, then render.
    second_embeds = encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
    set_seed(seed)
    second_image = _render(*second_embeds)

    torch.cuda.empty_cache()
    yield first_image, second_image
if __name__ == "__main__":
    # NOTE(review): the nesting of the layout below was reconstructed from a
    # whitespace-stripped copy of this file. In particular the placement of
    # the seed row, the submit button and the statistics accordion inside the
    # settings column is a best guess -- confirm against the intended layout.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(
                """
## TITPOP Demo
**The model for demo is 500M version with 4epoch training (25B token seen)**
### What is this
TITPOP is a tool to extend, generate, refine the input prompt for T2I models.
<br>It can work on both Danbooru tags and Natural Language. Which means you can use it on almost all the existed T2I models.
<br>You can take it as "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2)
### How to use this demo
1. Enter your tags(optional): put the desired tags into "Danbooru Tags" box
2. Enter your NL Prompt(optional): put the desired natural language prompt into "Natural Language Prompt" box
3. Enter your black list(optional): put the desired black list into "black list" box
4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ...
5. Click "TITPOP" button: you will see refined prompt on "result" box
6. If you like the result, click "Generate Image From Result" button
* You will see 2 generated images, left one is based on your prompt, right one is based on refined prompt
* The backend is diffusers, there are no weighting mechanism, so Escape Brackets is default to False
### Why inference code is private? When will it be open sourced?
1. This model/tool is still under development, currently is early Alpha version.
2. I'm doing some research and projects based on this.
3. The 200M model is released under CC-BY-NC-ND License currently. If you have interest, you can implement inference by yourself.
4. Once the project/research are done, I will open source all these models/codes with Apache2 license.
### Notification
**TITPOP is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
<br>The generated image is come from [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
"""
            )
        with gr.Row():
            # Left half: prompt inputs and generation settings.
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=3):
                        tags_input = gr.TextArea(
                            label="Danbooru Tags",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_TAGS,
                            placeholder="Enter danbooru tags here",
                        )
                        nl_prompt_input = gr.Textbox(
                            label="Natural Language Prompt",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_NL,
                            placeholder="Enter Natural Language Prompt here",
                        )
                        black_list = gr.TextArea(
                            label="Black List (separated by comma)",
                            lines=4,
                            interactive=True,
                            value="monochrome",
                            placeholder="Enter tag/nl black list here",
                        )
                    with gr.Column(scale=2):
                        output_format = gr.Dropdown(
                            label="Output Format",
                            choices=list(DEFAULT_FORMAT.keys()),
                            value="Both, tag first (recommend)",
                        )
                        target_length = gr.Dropdown(
                            label="Target Length",
                            choices=["very_short", "short", "long", "very_long"],
                            value="long",
                        )
                        temp = gr.Slider(
                            label="Temp",
                            minimum=0.0,
                            maximum=1.5,
                            value=0.5,
                            step=0.05,
                        )
                        top_p = gr.Slider(
                            label="Top P",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                        )
                        min_p = gr.Slider(
                            label="Min P",
                            minimum=0.0,
                            maximum=0.2,
                            value=0.05,
                            step=0.01,
                        )
                        top_k = gr.Slider(
                            label="Top K", minimum=0, maximum=120, value=60, step=1
                        )
                        with gr.Row():
                            seed = gr.Number(
                                label="Seed",
                                minimum=0,
                                maximum=2147483647,
                                value=20090220,
                                step=1,
                            )
                            escape_brackets = gr.Checkbox(
                                label="Escape Brackets", value=False
                            )
                        submit = gr.Button("TITPOP!", variant="primary")
                        with gr.Accordion("Speed statistics", open=False):
                            cost_time = gr.Markdown()
            # Right half: generated prompt and the two comparison images.
            with gr.Column(scale=5):
                result = gr.TextArea(
                    label="Result", lines=8, show_copy_button=True, interactive=False
                )
                # Hidden holder for the formatted user prompt; it is fed to
                # generate_image as the "original" prompt.
                input_prompt = gr.Textbox(
                    label="Input Prompt", lines=1, interactive=False, visible=False
                )
                gen_img = gr.Button(
                    "Generate Image from Result",
                    variant="primary",
                    interactive=False,
                )
                with gr.Row():
                    with gr.Column():
                        img1 = gr.Image(label="Original Prompt", interactive=False)
                    with gr.Column():
                        img2 = gr.Image(label="Generated Prompt", interactive=False)

        def generate_wrapper(*args):
            """Stream `generate` output while keeping the image button locked.

            The image button is re-enabled only after at least one result has
            been produced.
            """
            # Clear the three text outputs and lock the image button.
            yield "", "", "", gr.update(interactive=False)
            latest = None
            for latest in generate(*args):
                yield *latest, gr.update(interactive=False)
            # Guard: if nothing was produced there is no result to unlock for,
            # and no loop value to re-emit.
            if latest is not None:
                yield *latest, gr.update(interactive=True)

        submit.click(
            generate_wrapper,
            [
                tags_input,
                nl_prompt_input,
                black_list,
                temp,
                output_format,
                target_length,
                top_p,
                min_p,
                top_k,
                seed,
                escape_brackets,
            ],
            [
                result,
                input_prompt,
                cost_time,
                gen_img,
            ],
            queue=True,
        )

        def generate_image_wrapper(seed, result, input_prompt):
            """Stream `generate_image` output and unlock the submit button at the end."""
            # Distinct local names so the image components `img1` / `img2`
            # defined above are not shadowed.
            first_image, second_image = None, None
            for first_image, second_image in generate_image(
                seed, result, input_prompt
            ):
                yield first_image, second_image, gr.update(interactive=False)
            yield first_image, second_image, gr.update(interactive=True)

        gen_img.click(
            generate_image_wrapper,
            [seed, result, input_prompt],
            [img1, img2, submit],
            queue=True,
        )
        # Second handler on the same click: lock the submit button right away,
        # without going through the queue.
        gen_img.click(
            lambda *args: gr.update(interactive=False),
            None,
            [submit],
            queue=False,
        )
    demo.launch()