Spaces:
Running
on
Zero
Running
on
Zero
initial commit
Browse files- .gitignore +176 -0
- README.md +3 -3
- app.py +185 -0
- diffusion.py +71 -0
- output.py +16 -0
- requirements.txt +5 -0
- utils.py +44 -0
- v2.py +254 -0
.gitignore
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python
|
2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
3 |
+
|
4 |
+
### Python ###
|
5 |
+
# Byte-compiled / optimized / DLL files
|
6 |
+
__pycache__/
|
7 |
+
*.py[cod]
|
8 |
+
*$py.class
|
9 |
+
|
10 |
+
# C extensions
|
11 |
+
*.so
|
12 |
+
|
13 |
+
# Distribution / packaging
|
14 |
+
.Python
|
15 |
+
build/
|
16 |
+
develop-eggs/
|
17 |
+
dist/
|
18 |
+
downloads/
|
19 |
+
eggs/
|
20 |
+
.eggs/
|
21 |
+
lib/
|
22 |
+
lib64/
|
23 |
+
parts/
|
24 |
+
sdist/
|
25 |
+
var/
|
26 |
+
wheels/
|
27 |
+
share/python-wheels/
|
28 |
+
*.egg-info/
|
29 |
+
.installed.cfg
|
30 |
+
*.egg
|
31 |
+
MANIFEST
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Unit test / coverage reports
|
44 |
+
htmlcov/
|
45 |
+
.tox/
|
46 |
+
.nox/
|
47 |
+
.coverage
|
48 |
+
.coverage.*
|
49 |
+
.cache
|
50 |
+
nosetests.xml
|
51 |
+
coverage.xml
|
52 |
+
*.cover
|
53 |
+
*.py,cover
|
54 |
+
.hypothesis/
|
55 |
+
.pytest_cache/
|
56 |
+
cover/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
db.sqlite3
|
66 |
+
db.sqlite3-journal
|
67 |
+
|
68 |
+
# Flask stuff:
|
69 |
+
instance/
|
70 |
+
.webassets-cache
|
71 |
+
|
72 |
+
# Scrapy stuff:
|
73 |
+
.scrapy
|
74 |
+
|
75 |
+
# Sphinx documentation
|
76 |
+
docs/_build/
|
77 |
+
|
78 |
+
# PyBuilder
|
79 |
+
.pybuilder/
|
80 |
+
target/
|
81 |
+
|
82 |
+
# Jupyter Notebook
|
83 |
+
.ipynb_checkpoints
|
84 |
+
|
85 |
+
# IPython
|
86 |
+
profile_default/
|
87 |
+
ipython_config.py
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
# For a library or package, you might want to ignore these files since the code is
|
91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
92 |
+
# .python-version
|
93 |
+
|
94 |
+
# pipenv
|
95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
98 |
+
# install all needed dependencies.
|
99 |
+
#Pipfile.lock
|
100 |
+
|
101 |
+
# poetry
|
102 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
103 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
104 |
+
# commonly ignored for libraries.
|
105 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
106 |
+
#poetry.lock
|
107 |
+
|
108 |
+
# pdm
|
109 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
110 |
+
#pdm.lock
|
111 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
112 |
+
# in version control.
|
113 |
+
# https://pdm.fming.dev/#use-with-ide
|
114 |
+
.pdm.toml
|
115 |
+
|
116 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
117 |
+
__pypackages__/
|
118 |
+
|
119 |
+
# Celery stuff
|
120 |
+
celerybeat-schedule
|
121 |
+
celerybeat.pid
|
122 |
+
|
123 |
+
# SageMath parsed files
|
124 |
+
*.sage.py
|
125 |
+
|
126 |
+
# Environments
|
127 |
+
.env
|
128 |
+
.venv
|
129 |
+
env/
|
130 |
+
venv/
|
131 |
+
ENV/
|
132 |
+
env.bak/
|
133 |
+
venv.bak/
|
134 |
+
|
135 |
+
# Spyder project settings
|
136 |
+
.spyderproject
|
137 |
+
.spyproject
|
138 |
+
|
139 |
+
# Rope project settings
|
140 |
+
.ropeproject
|
141 |
+
|
142 |
+
# mkdocs documentation
|
143 |
+
/site
|
144 |
+
|
145 |
+
# mypy
|
146 |
+
.mypy_cache/
|
147 |
+
.dmypy.json
|
148 |
+
dmypy.json
|
149 |
+
|
150 |
+
# Pyre type checker
|
151 |
+
.pyre/
|
152 |
+
|
153 |
+
# pytype static type analyzer
|
154 |
+
.pytype/
|
155 |
+
|
156 |
+
# Cython debug symbols
|
157 |
+
cython_debug/
|
158 |
+
|
159 |
+
# PyCharm
|
160 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
161 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
162 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
163 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
164 |
+
#.idea/
|
165 |
+
|
166 |
+
### Python Patch ###
|
167 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
168 |
+
poetry.toml
|
169 |
+
|
170 |
+
# ruff
|
171 |
+
.ruff_cache/
|
172 |
+
|
173 |
+
# LSP config files
|
174 |
+
pyrightconfig.json
|
175 |
+
|
176 |
+
# End of https://www.toptal.com/developers/gitignore/api/python
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: Danbooru Tags Transformer V2
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.28.3
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
title: Danbooru Tags Transformer V2
|
3 |
+
emoji: 📦
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.28.3
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from v2 import V2UI
|
7 |
+
from diffusion import ImageGenerator
|
8 |
+
from output import UpsamplingOutput
|
9 |
+
from utils import QUALITY_TAGS, NEGATIVE_PROMPT, IMAGE_SIZE_OPTIONS, IMAGE_SIZES
|
10 |
+
|
11 |
+
|
12 |
+
def animagine_xl_v3_1(output: UpsamplingOutput):
    """Build a comma-separated prompt string from an upsampling result.

    The character, copyright, general and upsampled tags are joined in that
    order with ", ", followed by the rating tag unless the rating is
    "<|rating:sfw|>" or "<|rating:general|>". Parts that are empty after
    stripping whitespace are left out.
    """
    # The sfw / general rating tags are never appended to the prompt.
    if output.rating_tag in ["<|rating:sfw|>", "<|rating:general|>"]:
        rating_part = ""
    else:
        rating_part = output.rating_tag

    candidates = [
        output.character_tags,
        output.copyright_tags,
        output.general_tags,
        output.upsampled_tags,
        rating_part,
    ]
    stripped = [candidate.strip() for candidate in candidates]
    return ", ".join(text for text in stripped if text != "")
|
30 |
+
|
31 |
+
|
32 |
+
def elapsed_time_format(elapsed_time: float) -> str:
    """Render a duration as "Elapsed: <value> seconds" with two decimals."""
    seconds = format(elapsed_time, ".2f")
    return f"Elapsed: {seconds} seconds"
|
34 |
+
|
35 |
+
|
36 |
+
def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
    image_generator: Callable[..., Image.Image],
):
    """Combine an upsampler and an image generator into one event handler.

    The returned function takes ``generate_image`` followed by positional
    inputs, all of which are forwarded to ``upsampler``. It returns a tuple
    of (prompt text, formatted elapsed time, generated image or ``None``).
    """

    def _parse_upsampling_output(
        generate_image: bool, *args
    ) -> tuple[str, str, Image.Image | None]:
        output = upsampler(*args)

        print(output)

        image = None
        if generate_image:
            # The first 7 positional inputs belong to the upsampler; the
            # remaining five configure image generation.
            (
                size_key,
                quality,
                negative,
                steps,
                scale,
            ) = args[7:]
            width, height = IMAGE_SIZES[size_key]
            image = image_generator(
                ", ".join([animagine_xl_v3_1(output), quality]),
                negative,
                height,
                width,
                steps,
                scale,
            )

        # NOTE(review): main() wires this third element to a gr.Gallery;
        # confirm that component accepts a single image rather than a list.
        return (
            animagine_xl_v3_1(output),
            elapsed_time_format(output.elapsed_time),
            image,
        )

    return _parse_upsampling_output
|
81 |
+
|
82 |
+
|
83 |
+
def toggle_visible_output_image(generate_image: bool):
    """Return a Gradio update whose ``visible`` flag mirrors ``generate_image``."""
    visibility_update = gr.update(visible=generate_image)
    return visibility_update
|
87 |
+
|
88 |
+
|
89 |
+
def image_generation_config_ui():
    """Create the "Image generation config" accordion and its controls.

    Returns the accordion together with a list of its controls in the
    order: image size, quality tags, negative prompt, number of inference
    steps, guidance scale.
    """
    size_labels = list(IMAGE_SIZE_OPTIONS.keys())
    default_quality = QUALITY_TAGS["default"]
    default_negative = NEGATIVE_PROMPT["default"]

    with gr.Accordion(label="Image generation config", open=True) as accordion:
        image_size = gr.Radio(
            label="Image size",
            choices=size_labels,
            value=size_labels[3],  # tall
        )

        quality_tags = gr.Textbox(
            label="Quality tags",
            placeholder=default_quality,
            value=default_quality,
        )
        negative_prompt = gr.Textbox(
            label="Negative prompt",
            placeholder=default_negative,
            value=default_negative,
        )

        num_inference_steps = gr.Slider(
            label="Num inference steps",
            minimum=20,
            maximum=30,
            step=1,
            value=25,
        )
        guidance_scale = gr.Slider(
            label="Guidance scale",
            minimum=0.0,
            maximum=10.0,
            step=0.5,
            value=7.0,
        )

    controls = [
        image_size,
        quality_tags,
        negative_prompt,
        num_inference_steps,
        guidance_scale,
    ]
    return accordion, controls
|
130 |
+
|
131 |
+
|
132 |
+
def main():
    """Build the Gradio UI, wire up its event handlers and launch it."""
    v2 = V2UI()

    print("Loading diffusion model...")
    image_generator = ImageGenerator()
    print("Loaded.")

    # Handler for the generate button: upsample the tags, then optionally
    # render an image from the resulting prompt.
    on_generate_click = parse_upsampling_output(
        v2.on_generate, image_generator.generate
    )

    with gr.Blocks() as ui:
        with gr.Row():
            with gr.Column():
                v2.ui()

                generate_image_check = gr.Checkbox(
                    label="Also generate image", value=True
                )

                accordion, image_generation_config_components = (
                    image_generation_config_ui()
                )

            with gr.Column():
                output_text = gr.TextArea(label="Output tags", interactive=False)

                elapsed_time_md = gr.Markdown(label="Elapsed time", value="")

                output_image = gr.Gallery(
                    label="Output image",
                    columns=1,
                    preview=True,
                    show_label=False,
                    visible=True,
                )

        # Input order matters: the checkbox first, then the upsampler inputs,
        # then the image generation settings.
        click_inputs = [
            generate_image_check,
            *v2.get_inputs(),
            *image_generation_config_components,
        ]
        v2.get_generate_btn().click(
            on_generate_click,
            inputs=click_inputs,
            outputs=[output_text, elapsed_time_md, output_image],
        )
        # Show or hide the gallery together with the checkbox state.
        generate_image_check.change(
            toggle_visible_output_image,
            inputs=[generate_image_check],
            outputs=[output_image],
        )

    ui.launch()


if __name__ == "__main__":
    main()
|
diffusion.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
|
5 |
+
StableDiffusionXLPipeline,
|
6 |
+
)
|
7 |
+
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
|
8 |
+
EulerAncestralDiscreteScheduler,
|
9 |
+
)
|
10 |
+
|
11 |
+
try:
    import spaces
except ImportError:
    # The `spaces` package is not importable: provide a stand-in whose
    # GPU(...) returns a decorator that leaves the function unchanged.

    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            def _identity(fn):
                return fn

            return _identity
|
18 |
+
|
19 |
+
|
20 |
+
from utils import NEGATIVE_PROMPT
|
21 |
+
|
22 |
+
|
23 |
+
class ImageGenerator:
    """Text-to-image generator built on a Stable Diffusion XL pipeline.

    The pipeline is loaded in float16 with the "lpw_stable_diffusion_xl"
    custom pipeline, given an Euler Ancestral scheduler loaded from the
    model's "scheduler" subfolder, and moved to CUDA.
    """

    pipe: StableDiffusionXLPipeline

    def __init__(self, model_name: str = "cagliostrolab/animagine-xl-3.1"):
        """Load the pipeline and scheduler for ``model_name`` onto the GPU."""
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            custom_pipeline="lpw_stable_diffusion_xl",
            use_safetensors=True,
            add_watermarker=False,
        )
        # Assign to `scheduler` so the pipeline actually uses the Euler
        # Ancestral scheduler; any other attribute name leaves the default
        # scheduler in place.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
            model_name,
            subfolder="scheduler",
        )

        # xformers
        self.pipe.enable_xformers_memory_efficient_attention()

        self.pipe.to("cuda")

    @torch.no_grad()
    @spaces.GPU(duration=30)
    def generate(
        self,
        prompt: str,
        negative_prompt: str = NEGATIVE_PROMPT["default"],  # Light v3.1
        height: int = 1152,
        width: int = 896,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.0,
    ) -> Image.Image:
        """Run the pipeline once and return the first generated image.

        All arguments are logged to stdout and then forwarded unchanged to
        the pipeline as keyword arguments.
        """
        print("prompt", prompt)
        print("negative_prompt", negative_prompt)
        print("height", height)
        print("width", width)
        print("num_inference_steps", num_inference_steps)
        print("guidance_scale", guidance_scale)

        return self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
|
output.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
class UpsamplingOutput:
    """Result of one tag-upsampling run: the tags involved plus timing."""

    # Tags produced by the upsampler.
    upsampled_tags: str

    # Input tags, as handed to the upsampler.
    copyright_tags: str
    character_tags: str
    general_tags: str

    # Control tags used when composing the prompt.
    rating_tag: str
    aspect_ratio_tag: str
    length_tag: str
    identity_tag: str

    # Seconds spent generating; defaults to zero when not measured.
    elapsed_time: float = 0.0
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.2.0
|
2 |
+
accelerate==0.29.2
|
3 |
+
transformers==4.38.2
|
4 |
+
optimum[onnxruntime]==1.19.1
|
5 |
+
spaces==0.26.2
|
utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://huggingface.co/spaces/cagliostrolab/animagine-xl-3.1/blob/main/config.py
QUALITY_TAGS = dict(
    default="(masterpiece), best quality, very aesthetic, perfect face",
)
NEGATIVE_PROMPT = dict(
    default="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
)


def _tag_table(prefix, names):
    """Map each name to its "<|prefix:name|>" control tag, keeping order."""
    return {name: f"<|{prefix}:{name}|>" for name in names}


# Image size label -> aspect-ratio control tag.
IMAGE_SIZE_OPTIONS = {
    "1536x640": "<|aspect_ratio:ultra_wide|>",
    "1216x832": "<|aspect_ratio:wide|>",
    "1024x1024": "<|aspect_ratio:square|>",
    "832x1216": "<|aspect_ratio:tall|>",
    "640x1536": "<|aspect_ratio:ultra_tall|>",
}
# Same labels -> the two integers parsed from the "AxB" label.
IMAGE_SIZES = {
    label: tuple(int(side) for side in label.split("x"))
    for label in IMAGE_SIZE_OPTIONS
}

RATING_OPTIONS = _tag_table(
    "rating",
    ("sfw", "general", "sensitive", "nsfw", "questionable", "explicit"),
)
LENGTH_OPTIONS = _tag_table(
    "length",
    ("very_short", "short", "medium", "long", "very_long"),
)
IDENTITY_OPTIONS = _tag_table(
    "identity",
    ("none", "lax", "strict"),
)
|
v2.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
from gradio.components import Component
|
8 |
+
|
9 |
+
try:
    import spaces
except ImportError:
    # The `spaces` package is not importable: provide a stand-in whose
    # GPU(...) returns a decorator that leaves the function unchanged.

    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            def _identity(fn):
                return fn

            return _identity
|
16 |
+
|
17 |
+
|
18 |
+
from output import UpsamplingOutput
|
19 |
+
from utils import IMAGE_SIZE_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
|
20 |
+
|
21 |
+
# Selectable models: display name -> repo passed to prepare_models() and
# the model type.
ALL_MODELS = {
    name: {
        "repo": f"p1atdev/{name}",
        "type": "sft",
    }
    for name in (
        "dart-v2-llama-100m-sft",
        "dart-v2-mistral-100m-sft",
        "dart-v2-mixtral-160m-sft",
    )
}
|
35 |
+
|
36 |
+
|
37 |
+
def prepare_models(model_name: str):
    """Load the tokenizer and causal language model for ``model_name``.

    The model is loaded with ``torch.bfloat16`` weights and
    ``device_map="auto"``. Returns a dict with the keys "tokenizer" and
    "model".
    """
    loaded = {"tokenizer": AutoTokenizer.from_pretrained(model_name)}
    loaded["model"] = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return loaded
|
49 |
+
|
50 |
+
|
51 |
+
def normalize_tags(tokenizer: PreTrainedTokenizerBase, tags: str):
    """Round-trip ``tags`` through the tokenizer, dropping unknown tokens.

    The text is encoded, every token id equal to the tokenizer's unknown
    token id is discarded, and the remaining ids are decoded (special
    tokens skipped) and joined with ", ".
    """
    encoded = tokenizer.encode_plus(tags, return_tensors="pt")
    known_ids = [
        token_id
        for token_id in encoded.input_ids[0]
        if int(token_id) != tokenizer.unk_token_id
    ]
    decoded = tokenizer.batch_decode(known_ids, skip_special_tokens=True)
    return ", ".join(decoded)
|
66 |
+
|
67 |
+
|
68 |
+
def compose_prompt(
    copyright: str = "",
    character: str = "",
    general: str = "",
    rating: str = "<|rating:sfw|>",
    aspect_ratio: str = "<|aspect_ratio:tall|>",
    length: str = "<|length:long|>",
    identity: str = "<|identity:none|>",
):
    """Assemble the model prompt from the tag groups and control tags.

    The copyright, character and general tags are stripped of surrounding
    whitespace; the control tags are inserted as given.
    """
    template = (
        "<|bos|>"
        "<copyright>{copyright}</copyright>"
        "<character>{character}</character>"
        "{rating}{aspect_ratio}{length}"
        "<general>{general}{identity}<|input_end|>"
    )
    return template.format(
        copyright=copyright.strip(),
        character=character.strip(),
        general=general.strip(),
        rating=rating,
        aspect_ratio=aspect_ratio,
        length=length,
        identity=identity,
    )
|
86 |
+
|
87 |
+
|
88 |
+
@torch.no_grad()
@spaces.GPU(duration=5)
def generate_tags(
    model,
    tokenizer: PreTrainedTokenizerBase,
    prompt: str,
):
    """Sample a continuation of ``prompt`` and return it as ", "-joined tags.

    Only the token ids generated after the prompt are decoded; decoded
    entries that are blank after stripping are dropped.
    """
    print(  # debug
        tokenizer.tokenize(
            prompt,
            add_special_tokens=False,
        )
    )
    input_ids = tokenizer.encode_plus(prompt, return_tensors="pt").input_ids
    generated = model.generate(
        input_ids.to(model.device),
        do_sample=True,
        temperature=1,
        top_p=0.9,
        top_k=100,
        num_beams=1,
        num_return_sequences=1,
        max_length=256,
    )

    # Skip the prompt tokens so only newly generated ids remain.
    prompt_length = len(input_ids[0])
    new_token_ids = generated[0][prompt_length:]

    pieces = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
    return ", ".join(piece for piece in pieces if piece.strip() != "")
|
125 |
+
|
126 |
+
|
127 |
+
class V2UI:
    """Gradio input panel and generation callback for the tag upsampler.

    ``ui()`` creates the input components and the generate button;
    ``on_generate()`` is the callback that loads the selected model on
    demand and produces an ``UpsamplingOutput``.
    """

    # Name of the currently loaded model; None until the first generation.
    model_name: str | None = None
    model: AutoModelForCausalLM
    tokenizer: PreTrainedTokenizerBase

    input_components: list[Component]
    generate_btn: gr.Button

    def __init__(self) -> None:
        # Per-instance list so instances never share one mutable
        # class-level list; ui() replaces it with the real components.
        self.input_components = []

    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_option: str,
        length_option: str,
        identity_option: str,
        image_size: str,  # this is from image generation config
        *args,
    ) -> UpsamplingOutput:
        """Upsample the given tags with the selected model.

        The model and tokenizer are (re)loaded only when ``model_name``
        differs from the one already loaded. The aspect-ratio tag is
        derived from ``image_size`` rather than from a dedicated input.
        Extra positional arguments are accepted and ignored.

        Raises:
            KeyError: if ``model_name``, ``rating_option``, ``image_size``,
                ``length_option`` or ``identity_option`` is not a key of
                its lookup table.
        """
        if self.model_name is None or self.model_name != model_name:
            models = prepare_models(ALL_MODELS[model_name]["repo"])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            self.model_name = model_name

        # normalize tags
        copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
        character_tags = normalize_tags(self.tokenizer, character_tags)
        general_tags = normalize_tags(self.tokenizer, general_tags)

        rating_tag = RATING_OPTIONS[rating_option]
        aspect_ratio_tag = IMAGE_SIZE_OPTIONS[image_size]
        length_tag = LENGTH_OPTIONS[length_option]
        identity_tag = IDENTITY_OPTIONS[identity_option]

        prompt = compose_prompt(
            copyright=copyright_tags,
            character=character_tags,
            general=general_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start = time.perf_counter()
        upsampled_tags = generate_tags(
            self.model,
            self.tokenizer,
            prompt,
        )
        elapsed_time = time.perf_counter() - start

        return UpsamplingOutput(
            upsampled_tags=upsampled_tags,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=elapsed_time,
        )

    def ui(self):
        """Create the input components and the generate button.

        Must be called inside an active Gradio layout context. Stores the
        button in ``generate_btn`` and the inputs in ``input_components``,
        ordered to match the leading parameters of ``on_generate``.
        """
        input_copyright = gr.Textbox(
            label="Copyright tags",
            placeholder="vocaloid",
        )
        input_character = gr.Textbox(
            label="Character tags",
            placeholder="hatsune miku",
        )
        input_general = gr.TextArea(
            label="General tags",
            lines=4,
            placeholder="1girl, ...",
            value="1girl",
        )

        input_rating = gr.Radio(
            label="Rating",
            choices=list(RATING_OPTIONS.keys()),
            value="general",
        )
        # There is no aspect-ratio input here: on_generate derives the
        # aspect-ratio tag from the image size it receives.
        input_length = gr.Radio(
            label="Length",
            choices=list(LENGTH_OPTIONS.keys()),
            value="long",
        )
        input_identity = gr.Radio(
            label="Identity",
            choices=list(IDENTITY_OPTIONS.keys()),
            value="lax",
        )

        model_name = gr.Dropdown(
            label="Model",
            choices=list(ALL_MODELS.keys()),
            value=list(ALL_MODELS.keys())[0],
        )

        self.generate_btn = gr.Button(value="Generate", variant="primary")

        self.input_components = [
            model_name,
            input_copyright,
            input_character,
            input_general,
            input_rating,
            input_length,
            input_identity,
        ]

    def get_generate_btn(self) -> gr.Button:
        """Return the generate button created by ``ui()``."""
        return self.generate_btn

    def get_inputs(self) -> list[Component]:
        """Return the input components created by ``ui()``."""
        return self.input_components
|