p1atdev committed on
Commit cf72c4b
1 Parent(s): 5043faf

initial commit

Files changed (8)
  1. .gitignore +176 -0
  2. README.md +3 -3
  3. app.py +185 -0
  4. diffusion.py +71 -0
  5. output.py +16 -0
  6. requirements.txt +5 -0
  7. utils.py +44 -0
  8. v2.py +254 -0
.gitignore ADDED
@@ -0,0 +1,176 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ # End of https://www.toptal.com/developers/gitignore/api/python
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Danbooru Tags Transformer V2
- emoji: 🐢
- colorFrom: purple
- colorTo: pink
+ emoji: 📦
+ colorFrom: yellow
+ colorTo: yellow
  sdk: gradio
  sdk_version: 4.28.3
  app_file: app.py
app.py ADDED
@@ -0,0 +1,185 @@
+ from typing import Callable
+ from PIL import Image
+
+ import gradio as gr
+
+ from v2 import V2UI
+ from diffusion import ImageGenerator
+ from output import UpsamplingOutput
+ from utils import QUALITY_TAGS, NEGATIVE_PROMPT, IMAGE_SIZE_OPTIONS, IMAGE_SIZES
+
+
+ def animagine_xl_v3_1(output: UpsamplingOutput):
+     return ", ".join(
+         [
+             part.strip()
+             for part in [
+                 output.character_tags,
+                 output.copyright_tags,
+                 output.general_tags,
+                 output.upsampled_tags,
+                 (
+                     output.rating_tag
+                     if output.rating_tag not in ["<|rating:sfw|>", "<|rating:general|>"]
+                     else ""
+                 ),
+             ]
+             if part.strip() != ""
+         ]
+     )
+
+
+ def elapsed_time_format(elapsed_time: float) -> str:
+     return f"Elapsed: {elapsed_time:.2f} seconds"
+
+
+ def parse_upsampling_output(
+     upsampler: Callable[..., UpsamplingOutput],
+     image_generator: Callable[..., Image.Image],
+ ):
+     def _parse_upsampling_output(
+         generate_image: bool, *args
+     ) -> tuple[str, str, Image.Image | None]:
+         output = upsampler(*args)
+
+         print(output)
+
+         if not generate_image:
+             return (
+                 animagine_xl_v3_1(output),
+                 elapsed_time_format(output.elapsed_time),
+                 None,
+             )
+
+         # generate image
+         [
+             image_size_option,
+             quality_tags,
+             negative_prompt,
+             num_inference_steps,
+             guidance_scale,
+         ] = args[
+             7:
+         ]  # remove the first 7 arguments for upsampler
+         width, height = IMAGE_SIZES[image_size_option]
+         image = image_generator(
+             ", ".join([animagine_xl_v3_1(output), quality_tags]),
+             negative_prompt,
+             height,
+             width,
+             num_inference_steps,
+             guidance_scale,
+         )
+
+         return (
+             animagine_xl_v3_1(output),
+             elapsed_time_format(output.elapsed_time),
+             image,
+         )
+
+     return _parse_upsampling_output
+
+
+ def toggle_visible_output_image(generate_image: bool):
+     return gr.update(
+         visible=generate_image,
+     )
+
+
+ def image_generation_config_ui():
+     with gr.Accordion(label="Image generation config", open=True) as accordion:
+         image_size = gr.Radio(
+             label="Image size",
+             choices=list(IMAGE_SIZE_OPTIONS.keys()),
+             value=list(IMAGE_SIZE_OPTIONS.keys())[3],  # tall
+         )
+
+         quality_tags = gr.Textbox(
+             label="Quality tags",
+             placeholder=QUALITY_TAGS["default"],
+             value=QUALITY_TAGS["default"],
+         )
+         negative_prompt = gr.Textbox(
+             label="Negative prompt",
+             placeholder=NEGATIVE_PROMPT["default"],
+             value=NEGATIVE_PROMPT["default"],
+         )
+
+         num_inference_steps = gr.Slider(
+             label="Num inference steps",
+             minimum=20,
+             maximum=30,
+             step=1,
+             value=25,
+         )
+         guidance_scale = gr.Slider(
+             label="Guidance scale",
+             minimum=0.0,
+             maximum=10.0,
+             step=0.5,
+             value=7.0,
+         )
+
+     return accordion, [
+         image_size,
+         quality_tags,
+         negative_prompt,
+         num_inference_steps,
+         guidance_scale,
+     ]
+
+
+ def main():
+
+     v2 = V2UI()
+
+     print("Loading diffusion model...")
+     image_generator = ImageGenerator()
+     print("Loaded.")
+
+     with gr.Blocks() as ui:
+         with gr.Row():
+             with gr.Column():
+                 v2.ui()
+
+                 generate_image_check = gr.Checkbox(
+                     label="Also generate image", value=True
+                 )
+
+                 accordion, image_generation_config_components = (
+                     image_generation_config_ui()
+                 )
+
+             with gr.Column():
+                 output_text = gr.TextArea(label="Output tags", interactive=False)
+
+                 elapsed_time_md = gr.Markdown(label="Elapsed time", value="")
+
+                 output_image = gr.Gallery(
+                     label="Output image",
+                     columns=1,
+                     preview=True,
+                     show_label=False,
+                     visible=True,
+                 )
+
+         v2.get_generate_btn().click(
+             parse_upsampling_output(v2.on_generate, image_generator.generate),
+             inputs=[
+                 generate_image_check,
+                 *v2.get_inputs(),
+                 *image_generation_config_components,
+             ],
+             outputs=[output_text, elapsed_time_md, output_image],
+         )
+         generate_image_check.change(
+             toggle_visible_output_image,
+             inputs=[generate_image_check],
+             outputs=[output_image],
+         )
+
+     ui.launch()
+
+
+ if __name__ == "__main__":
+     main()
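
Note on the argument wiring above: `_parse_upsampling_output` receives one flat tuple built from the `inputs` list in `main()` (the checkbox first, then the seven `V2UI` inputs, then the five image-generation controls), and `args[7:]` peels off the image-generation part. A minimal sketch of that ordering, with illustrative placeholder values that are not taken from the app:

    # Hypothetical flat argument tuple as Gradio would pass it to
    # parse_upsampling_output(upsampler, image_generator)(generate_image, *args).
    args = (
        "dart-v2-llama-100m-sft",   # 0: model name        -- read by v2.on_generate
        "vocaloid",                 # 1: copyright tags
        "hatsune miku",             # 2: character tags
        "1girl",                    # 3: general tags
        "general",                  # 4: rating option
        "long",                     # 5: length option
        "lax",                      # 6: identity option
        "832x1216",                 # 7: image size (on_generate's image_size, and start of the slice)
        "(masterpiece), best quality",  # 8: quality tags
        "nsfw, (low quality, worst quality:1.2)",  # 9: negative prompt
        25,                         # 10: num inference steps
        7.0,                        # 11: guidance scale
    )
    image_generation_args = args[7:]  # what the slice in _parse_upsampling_output yields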
diffusion.py ADDED
@@ -0,0 +1,71 @@
+ from PIL import Image
+
+ import torch
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
+     StableDiffusionXLPipeline,
+ )
+ from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
+     EulerAncestralDiscreteScheduler,
+ )
+
+ try:
+     import spaces
+ except ImportError:
+
+     class spaces:
+         def GPU(*args, **kwargs):
+             return lambda x: x
+
+
+ from utils import NEGATIVE_PROMPT
+
+
+ class ImageGenerator:
+     pipe: StableDiffusionXLPipeline
+
+     def __init__(self, model_name: str = "cagliostrolab/animagine-xl-3.1"):
+         self.pipe = StableDiffusionXLPipeline.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16,
+             custom_pipeline="lpw_stable_diffusion_xl",
+             use_safetensors=True,
+             add_watermarker=False,
+         )
+         self.pipe.scheduler = (
+             EulerAncestralDiscreteScheduler.from_pretrained(
+                 model_name,
+                 subfolder="scheduler",
+             )
+         )
+
+         # xformers
+         self.pipe.enable_xformers_memory_efficient_attention()
+
+         self.pipe.to("cuda")
+
+     @torch.no_grad()
+     @spaces.GPU(duration=30)
+     def generate(
+         self,
+         prompt: str,
+         negative_prompt: str = NEGATIVE_PROMPT["default"],  # Light v3.1
+         height: int = 1152,
+         width: int = 896,
+         num_inference_steps: int = 25,
+         guidance_scale: float = 7.0,
+     ) -> Image.Image:
+         print("prompt", prompt)
+         print("negative_prompt", negative_prompt)
+         print("height", height)
+         print("width", width)
+         print("num_inference_steps", num_inference_steps)
+         print("guidance_scale", guidance_scale)
+
+         return self.pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             height=height,
+             width=width,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+         ).images[0]
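
For context, a minimal sketch of how `ImageGenerator` could be exercised on its own, assuming a CUDA GPU and the dependencies above are installed (the output file name is made up):

    # Hypothetical standalone usage of ImageGenerator, outside the Gradio app.
    from diffusion import ImageGenerator

    generator = ImageGenerator()  # loads cagliostrolab/animagine-xl-3.1 and moves it to CUDA
    image = generator.generate(
        prompt="1girl, hatsune miku, vocaloid, masterpiece, best quality",
        num_inference_steps=25,
        guidance_scale=7.0,
    )
    image.save("sample.png")  # generate() returns a single PIL.Image.Image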
output.py ADDED
@@ -0,0 +1,16 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class UpsamplingOutput:
+     upsampled_tags: str
+
+     copyright_tags: str
+     character_tags: str
+     general_tags: str
+     rating_tag: str
+     aspect_ratio_tag: str
+     length_tag: str
+     identity_tag: str
+
+     elapsed_time: float = 0.0
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==2.2.0
+ accelerate==0.29.2
+ transformers==4.38.2
+ optimum[onnxruntime]==1.19.1
+ spaces==0.26.2
utils.py ADDED
@@ -0,0 +1,44 @@
+ # from https://huggingface.co/spaces/cagliostrolab/animagine-xl-3.1/blob/main/config.py
+ QUALITY_TAGS = {
+     "default": "(masterpiece), best quality, very aesthetic, perfect face",
+ }
+ NEGATIVE_PROMPT = {
+     "default": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
+ }
+
+
+ IMAGE_SIZE_OPTIONS = {
+     "1536x640": "<|aspect_ratio:ultra_wide|>",
+     "1216x832": "<|aspect_ratio:wide|>",
+     "1024x1024": "<|aspect_ratio:square|>",
+     "832x1216": "<|aspect_ratio:tall|>",
+     "640x1536": "<|aspect_ratio:ultra_tall|>",
+ }
+ IMAGE_SIZES = {
+     "1536x640": (1536, 640),
+     "1216x832": (1216, 832),
+     "1024x1024": (1024, 1024),
+     "832x1216": (832, 1216),
+     "640x1536": (640, 1536),
+ }
+
+ RATING_OPTIONS = {
+     "sfw": "<|rating:sfw|>",
+     "general": "<|rating:general|>",
+     "sensitive": "<|rating:sensitive|>",
+     "nsfw": "<|rating:nsfw|>",
+     "questionable": "<|rating:questionable|>",
+     "explicit": "<|rating:explicit|>",
+ }
+ LENGTH_OPTIONS = {
+     "very_short": "<|length:very_short|>",
+     "short": "<|length:short|>",
+     "medium": "<|length:medium|>",
+     "long": "<|length:long|>",
+     "very_long": "<|length:very_long|>",
+ }
+ IDENTITY_OPTIONS = {
+     "none": "<|identity:none|>",
+     "lax": "<|identity:lax|>",
+     "strict": "<|identity:strict|>",
+ }
v2.py ADDED
@@ -0,0 +1,254 @@
+ import time
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
+
+ import gradio as gr
+ from gradio.components import Component
+
+ try:
+     import spaces
+ except ImportError:
+
+     class spaces:
+         def GPU(*args, **kwargs):
+             return lambda x: x
+
+
+ from output import UpsamplingOutput
+ from utils import IMAGE_SIZE_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
+
+ ALL_MODELS = {
+     "dart-v2-llama-100m-sft": {
+         "repo": "p1atdev/dart-v2-llama-100m-sft",
+         "type": "sft",
+     },
+     "dart-v2-mistral-100m-sft": {
+         "repo": "p1atdev/dart-v2-mistral-100m-sft",
+         "type": "sft",
+     },
+     "dart-v2-mixtral-160m-sft": {
+         "repo": "p1atdev/dart-v2-mixtral-160m-sft",
+         "type": "sft",
+     },
+ }
+
+
+ def prepare_models(model_name: str):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+
+     return {
+         "tokenizer": tokenizer,
+         "model": model,
+     }
+
+
+ def normalize_tags(tokenizer: PreTrainedTokenizerBase, tags: str):
+     """Just remove unk tokens."""
+     return ", ".join(
+         tokenizer.batch_decode(
+             [
+                 token
+                 for token in tokenizer.encode_plus(
+                     tags,
+                     return_tensors="pt",
+                 ).input_ids[0]
+                 if int(token) != tokenizer.unk_token_id
+             ],
+             skip_special_tokens=True,
+         )
+     )
+
+
+ def compose_prompt(
+     copyright: str = "",
+     character: str = "",
+     general: str = "",
+     rating: str = "<|rating:sfw|>",
+     aspect_ratio: str = "<|aspect_ratio:tall|>",
+     length: str = "<|length:long|>",
+     identity: str = "<|identity:none|>",
+ ):
+     prompt = (
+         f"<|bos|>"
+         f"<copyright>{copyright.strip()}</copyright>"
+         f"<character>{character.strip()}</character>"
+         f"{rating}{aspect_ratio}{length}"
+         f"<general>{general.strip()}{identity}<|input_end|>"
+     )
+
+     return prompt
+
+
+ @torch.no_grad()
+ @spaces.GPU(duration=5)
+ def generate_tags(
+     model,
+     tokenizer: PreTrainedTokenizerBase,
+     prompt: str,
+ ):
+     print(  # debug
+         tokenizer.tokenize(
+             prompt,
+             add_special_tokens=False,
+         )
+     )
+     input_ids = tokenizer.encode_plus(prompt, return_tensors="pt").input_ids
+     output = model.generate(
+         input_ids.to(model.device),
+         do_sample=True,
+         temperature=1,
+         top_p=0.9,
+         top_k=100,
+         num_beams=1,
+         num_return_sequences=1,
+         max_length=256,
+     )
+
+     # remove input tokens
+     pure_output_ids = output[0][len(input_ids[0]) :]
+
+     return ", ".join(
+         [
+             token
+             for token in tokenizer.batch_decode(
+                 pure_output_ids, skip_special_tokens=True
+             )
+             if token.strip() != ""
+         ]
+     )
+
+
+ class V2UI:
+     model_name: str | None = None
+     model: AutoModelForCausalLM
+     tokenizer: PreTrainedTokenizerBase
+
+     input_components: list[Component] = []
+     generate_btn: gr.Button
+
+     def on_generate(
+         self,
+         model_name: str,
+         copyright_tags: str,
+         character_tags: str,
+         general_tags: str,
+         rating_option: str,
+         # aspect_ratio_option: str,
+         length_option: str,
+         identity_option: str,
+         image_size: str,  # this is from image generation config
+         *args,
+     ) -> UpsamplingOutput:
+         if self.model_name is None or self.model_name != model_name:
+             models = prepare_models(ALL_MODELS[model_name]["repo"])
+             self.model = models["model"]
+             self.tokenizer = models["tokenizer"]
+             self.model_name = model_name
+
+         # normalize tags
+         copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
+         character_tags = normalize_tags(self.tokenizer, character_tags)
+         general_tags = normalize_tags(self.tokenizer, general_tags)
+
+         rating_tag = RATING_OPTIONS[rating_option]
+         aspect_ratio_tag = IMAGE_SIZE_OPTIONS[image_size]
+         length_tag = LENGTH_OPTIONS[length_option]
+         identity_tag = IDENTITY_OPTIONS[identity_option]
+
+         prompt = compose_prompt(
+             copyright=copyright_tags,
+             character=character_tags,
+             general=general_tags,
+             rating=rating_tag,
+             aspect_ratio=aspect_ratio_tag,
+             length=length_tag,
+             identity=identity_tag,
+         )
+
+         start = time.time()
+         upsampled_tags = generate_tags(
+             self.model,
+             self.tokenizer,
+             prompt,
+         )
+         elapsed_time = time.time() - start
+
+         return UpsamplingOutput(
+             upsampled_tags=upsampled_tags,
+             copyright_tags=copyright_tags,
+             character_tags=character_tags,
+             general_tags=general_tags,
+             rating_tag=rating_tag,
+             aspect_ratio_tag=aspect_ratio_tag,
+             length_tag=length_tag,
+             identity_tag=identity_tag,
+             elapsed_time=elapsed_time,
+         )
+
+     def ui(self):
+         input_copyright = gr.Textbox(
+             label="Copyright tags",
+             placeholder="vocaloid",
+         )
+         input_character = gr.Textbox(
+             label="Character tags",
+             placeholder="hatsune miku",
+         )
+         input_general = gr.TextArea(
+             label="General tags",
+             lines=4,
+             placeholder="1girl, ...",
+             value="1girl",
+         )
+
+         input_rating = gr.Radio(
+             label="Rating",
+             choices=list(RATING_OPTIONS.keys()),
+             value="general",
+         )
+         # input_aspect_ratio = gr.Radio(
+         #     label="Aspect ratio",
+         #     choices=["ultra_wide", "wide", "square", "tall", "ultra_tall"],
+         #     value="tall",
+         # )
+         input_length = gr.Radio(
+             label="Length",
+             choices=list(LENGTH_OPTIONS.keys()),
+             value="long",
+         )
+         input_identity = gr.Radio(
+             label="Identity",
+             choices=list(IDENTITY_OPTIONS.keys()),
+             value="lax",
+         )
+
+         model_name = gr.Dropdown(
+             label="Model",
+             choices=list(ALL_MODELS.keys()),
+             value=list(ALL_MODELS.keys())[0],
+         )
+
+         self.generate_btn = gr.Button(value="Generate", variant="primary")
+
+         self.input_components = [
+             model_name,
+             input_copyright,
+             input_character,
+             input_general,
+             input_rating,
+             # input_aspect_ratio,
+             input_length,
+             input_identity,
+         ]
+
+     def get_generate_btn(self) -> gr.Button:
+         return self.generate_btn
+
+     def get_inputs(self) -> list[Component]:
+         return self.input_components
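
For reference, a small sketch of the prompt string `compose_prompt` assembles for the Dart v2 models, using the same kind of values `on_generate` passes in (the inputs are illustrative; the expected string follows directly from the f-string in the function):

    # Illustrative call to compose_prompt with hypothetical tag values.
    prompt = compose_prompt(
        copyright="vocaloid",
        character="hatsune miku",
        general="1girl",
        rating="<|rating:general|>",
        aspect_ratio="<|aspect_ratio:tall|>",
        length="<|length:long|>",
        identity="<|identity:lax|>",
    )
    # prompt ==
    # "<|bos|><copyright>vocaloid</copyright><character>hatsune miku</character>"
    # "<|rating:general|><|aspect_ratio:tall|><|length:long|>"
    # "<general>1girl<|identity:lax|><|input_end|>"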