|
from typing import Callable |
|
from PIL import Image |
|
|
|
import gradio as gr |
|
|
|
from v2 import V2UI |
|
from diffusion import ImageGenerator, image_generation_config_ui |
|
from output import UpsamplingOutput |
|
from utils import ( |
|
PEOPLE_TAGS, |
|
gradio_copy_text, |
|
COPY_ACTION_JS, |
|
) |
|
|
|
|
|
# Rating label -> rating tag text appended at the end of the Animagine XL prompt.
# The safe ratings ("sfw", "general") map to the empty string so that no rating
# tag is emitted for them; "questionable" is folded into "nsfw", and "explicit"
# expands to two tags in a single comma-separated string.
NORMALIZE_RATING_TAG = dict(
    sfw="",
    general="",
    sensitive="sensitive",
    nsfw="nsfw",
    questionable="nsfw",
    explicit="nsfw, explicit",
)
|
|
|
|
|
def example(
    copyright: str,
    character: str,
    general: str,
    rating: str,
    aspect_ratio: str,
    length: str,
    identity: str,
    image_size: str,
):
    """Build a single example row for ``gr.Examples``.

    The row is a fresh list containing the arguments in signature order
    (copyright, character, general, rating, aspect_ratio, length, identity,
    image_size). That order has to line up with the ``inputs`` list given to
    ``gr.Examples`` when the UI is built.
    """
    fields = (
        copyright,
        character,
        general,
        rating,
        aspect_ratio,
        length,
        identity,
        image_size,
    )
    return list(fields)
|
|
|
|
|
# Example rows shown under the demo. Each spec is expanded through `example`,
# so every row is an 8-element list in example()'s parameter order.
GRADIO_EXAMPLES = [
    example(**spec)
    for spec in (
        dict(
            copyright="original",
            character="",
            general="1girl, solo, upper body, :d",
            rating="general",
            aspect_ratio="tall",
            length="long",
            identity="none",
            image_size="768x1344",
        ),
        dict(
            copyright="original",
            character="",
            general="1girl, solo, blue theme, limited palette",
            rating="sfw",
            aspect_ratio="ultra_wide",
            length="long",
            identity="lax",
            image_size="1536x640",
        ),
        dict(
            copyright="",
            character="",
            general="4girls",
            rating="sfw",
            aspect_ratio="tall",
            length="very_long",
            identity="lax",
            image_size="768x1344",
        ),
        dict(
            copyright="original",
            character="",
            general="1girl, solo, upper body, looking at viewer, profile picture",
            rating="sfw",
            aspect_ratio="square",
            length="medium",
            identity="none",
            image_size="1024x1024",
        ),
        dict(
            copyright="original",
            character="",
            general="1girl, solo, chibi, upper body, looking at viewer, simple background, limited palette, square",
            rating="sfw",
            aspect_ratio="square",
            length="medium",
            identity="none",
            image_size="1024x1024",
        ),
        dict(
            copyright="original",
            character="",
            general="1girl, full body, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, solo, yellow shirt, simple background, green background",
            rating="sfw",
            aspect_ratio="tall",
            length="very_long",
            identity="strict",
            image_size="768x1344",
        ),
        dict(
            copyright="",
            character="",
            general="no humans, scenery, spring (season)",
            rating="general",
            aspect_ratio="ultra_wide",
            length="medium",
            identity="lax",
            image_size="1536x640",
        ),
        dict(
            copyright="",
            character="",
            general="no humans, cyberpunk, city, cityscape, building, neon lights, pixel art",
            rating="general",
            aspect_ratio="ultra_wide",
            length="medium",
            identity="lax",
            image_size="1536x640",
        ),
        dict(
            copyright="sousou no frieren",
            character="frieren",
            general="1girl, solo",
            rating="general",
            aspect_ratio="tall",
            length="long",
            identity="lax",
            image_size="768x1344",
        ),
        dict(
            copyright="honkai: star rail",
            character="firefly (honkai: star rail)",
            general="1girl, solo",
            rating="sfw",
            aspect_ratio="tall",
            length="medium",
            identity="lax",
            image_size="768x1344",
        ),
        dict(
            copyright="honkai: star rail",
            character="silver wolf (honkai: star rail)",
            general="1girl, solo, annoyed",
            rating="sfw",
            aspect_ratio="tall",
            length="long",
            identity="lax",
            image_size="768x1344",
        ),
        dict(
            copyright="chuunibyou demo koi ga shitai!",
            character="takanashi rikka",
            general="1girl, solo",
            rating="sfw",
            aspect_ratio="ultra_tall",
            length="medium",
            identity="lax",
            image_size="640x1536",
        ),
    )
]
|
|
|
|
|
def animagine_xl_v3_1(output: UpsamplingOutput):
    """Assemble an Animagine XL v3.1 prompt string from an upsampling result.

    The comma-separated ``output.general_tags`` are split and stripped. Tags
    found in ``PEOPLE_TAGS`` (presumably subject-count tags such as "1girl";
    confirm against utils) are moved to the front. The prompt is then built
    in this order:

        people tags, character tags, copyright tags, remaining general tags,
        upsampled tags, normalized rating tag

    Every part is whitespace-stripped, empty parts are dropped, and the rest
    are joined with ", ".

    Raises:
        KeyError: if ``output.rating_tag`` is not a key of NORMALIZE_RATING_TAG.
    """
    general = [tag.strip() for tag in output.general_tags.split(",")]
    people = [tag for tag in general if tag in PEOPLE_TAGS]
    others = [tag for tag in general if tag not in PEOPLE_TAGS]

    ordered_parts = (
        *people,
        output.character_tags,
        output.copyright_tags,
        *others,
        output.upsampled_tags,
        NORMALIZE_RATING_TAG[output.rating_tag],
    )
    stripped = (piece.strip() for piece in ordered_parts)
    # Empty strings are falsy, so this drops blank parts such as the "" rating
    # tag for safe ratings or an empty character/copyright field.
    return ", ".join(piece for piece in stripped if piece)
|
|
|
|
|
def elapsed_time_format(elapsed_time: float) -> str:
    """Return a display label for a duration in seconds, rounded to two decimals."""
    seconds = format(elapsed_time, ".2f")
    return f"Elapsed: {seconds} seconds"
|
|
|
|
|
def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    """Adapt an upsampler callable into a Gradio event handler.

    The returned handler passes its positional arguments straight to
    ``upsampler``, prints the raw result, and returns a 4-tuple for the
    event's outputs:

        (Animagine XL v3.1 prompt, "Elapsed: ..." label,
         gr.update(interactive=True), gr.update(interactive=True))

    The two updates enable whichever components the caller wires to the last
    two outputs.
    """

    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        result = upsampler(*args)
        print(result)  # echo the raw upsampling result to stdout

        prompt = animagine_xl_v3_1(result)
        elapsed_label = elapsed_time_format(result.elapsed_time)
        # Build a separate update object for each target component.
        enable_first, enable_second = (
            gr.update(interactive=True) for _ in range(2)
        )
        return prompt, elapsed_label, enable_first, enable_second

    return _parse_upsampling_output
|
|
|
|
|
def description_ui():
    """Render the demo header: the title plus links to the models it uses."""
    header_markdown = """
        # Danbooru Tags Transformer V2 Demo

        Models:
        - [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft) (Mixtral architecture)
        - [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft) (Mistral architecture)
        - [Animagine XL v3.1](https://huggingface.co/cagliostrolab/animagine-xl-3.1) (Image generation model)
        """
    gr.Markdown(header_markdown)
|
|
|
|
|
def main():
    """Build the Gradio demo, wire its event handlers, and launch it.

    The UI has a header, a two-column row (V2UI tag-generation inputs on the
    left; the generate button, output tags, copy button, elapsed time, image
    generation controls and result gallery on the right) and an examples
    table. Generating tags enables the copy and image-generation buttons.
    """
    v2 = V2UI()

    # The image generator is constructed once, before the UI is built
    # (the log messages indicate this is where the diffusion model loads).
    print("Loading diffusion model...")
    image_generator = ImageGenerator()
    print("Loaded.")

    with gr.Blocks() as ui:
        description_ui()

        with gr.Row():
            with gr.Column():
                # Left column: the tag-generation input widgets owned by V2UI.
                v2.ui()

            with gr.Column():
                generate_btn = gr.Button(value="Generate tags", variant="primary")

                with gr.Group():
                    output_text = gr.TextArea(label="Output tags", interactive=False)
                    # Disabled until tags have been generated (see generate_btn.click).
                    copy_btn = gr.Button(
                        value="Copy to clipboard",
                        interactive=False,
                    )

                elapsed_time_md = gr.Markdown(label="Elapsed time", value="")

                # Disabled until tags have been generated (see generate_btn.click).
                generate_image_btn = gr.Button(
                    value="Generate image with this prompt!",
                    interactive=False,
                )

                # `accordion` is not used below; only the config components are.
                accordion, image_generation_config_components = (
                    image_generation_config_ui()
                )

                output_image = gr.Gallery(
                    label="Generated image",
                    show_label=True,
                    columns=1,
                    preview=True,
                    visible=True,
                )

        # Each GRADIO_EXAMPLES row has 8 fields (copyright, character, general,
        # rating, aspect_ratio, length, identity, image_size), mapped in order
        # onto the inputs below.
        # NOTE(review): assumes v2.get_inputs()[1:8] are the first seven of those
        # fields in that order, and image_generation_config_components[0] is the
        # image-size control; confirm against V2UI / image_generation_config_ui.
        gr.Examples(
            examples=GRADIO_EXAMPLES,
            inputs=[
                *v2.get_inputs()[1:8],
                image_generation_config_components[0],
            ],
        )

        # The wrapped handler returns (prompt, elapsed label, update, update).
        # The two gr.update(interactive=True) values enable copy_btn and
        # generate_image_btn.
        generate_btn.click(
            parse_upsampling_output(v2.on_generate),
            inputs=[
                *v2.get_inputs(),
            ],
            outputs=[output_text, elapsed_time_md, copy_btn, generate_image_btn],
        )
        # Attaches client-side JS (COPY_ACTION_JS) to this click; presumably
        # that script performs the actual clipboard copy. Confirm in utils.
        copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
        # Generates from the current prompt text plus every image-generation
        # config value, and shows the result in the gallery.
        generate_image_btn.click(
            image_generator.generate,
            inputs=[output_text, *image_generation_config_components],
            outputs=[output_image],
        )

    ui.launch()
|
|
|
|
|
if __name__ == "__main__":
    # Build and launch the demo only when executed as a script, not on import.
    main()
|
|