Commit faf3806 (verified) by oimoyu
Parent: fc4323f

Upload 9 files

.gitattributes CHANGED
@@ -1,3 +1,4 @@
  /web/assets/** linguist-generated
  /web/** linguist-vendored
  comfy/text_encoders/t5_pile_tokenizer/tokenizer.model filter=lfs diff=lfs merge=lfs -text
+ custom_nodes/comfyui-WD14-Tagger/models/wd-v1-4-moat-tagger-v2.onnx filter=lfs diff=lfs merge=lfs -text
custom_nodes/comfyui-WD14-Tagger/README.md ADDED
@@ -0,0 +1,49 @@
+ # ComfyUI WD 1.4 Tagger
+
+ A [ComfyUI](https://github.com/comfyanonymous/ComfyUI) extension allowing the interrogation of booru tags from images.
+
+ Based on [SmilingWolf/wd-v1-4-tags](https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags) and [toriato/stable-diffusion-webui-wd14-tagger](https://github.com/toriato/stable-diffusion-webui-wd14-tagger).
+ All models created by [SmilingWolf](https://huggingface.co/SmilingWolf).
+
+ ## Installation
+ 1. `git clone https://github.com/pythongosssss/ComfyUI-WD14-Tagger` into the `custom_nodes` folder
+    - e.g. `custom_nodes\ComfyUI-WD14-Tagger`
+ 2. Open a Command Prompt/Terminal/etc.
+ 3. Change to the `custom_nodes\ComfyUI-WD14-Tagger` folder you just created
+    - e.g. `cd C:\ComfyUI_windows_portable\ComfyUI\custom_nodes\ComfyUI-WD14-Tagger`, or wherever you have it installed
+ 4. Install the Python packages
+    - **Windows standalone installation** (embedded Python):
+      `../../../python_embeded/python.exe -s -m pip install -r requirements.txt`
+    - **Manual/non-Windows installation**:
+      `pip install -r requirements.txt`
+
+ ## Usage
+ Add the node via `image` -> `WD14Tagger|pysssss`.
+ ![image](https://github.com/pythongosssss/ComfyUI-WD14-Tagger/assets/125205205/ee6756ae-73f6-4e9f-a3da-eb87a056eb87)
+ Models are automatically downloaded at runtime if missing.
+ ![image](https://github.com/pythongosssss/ComfyUI-WD14-Tagger/assets/125205205/cc09ae71-1a38-44da-afec-90f470a4b47d)
+ Supports tagging and outputting multiple batched inputs.
+ - **model**: The interrogation model to use. You can try them out here: [WaifuDiffusion v1.4 Tags](https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags). The newest model (as of writing) is `MOAT` and the most popular is `ConvNextV2`.
+ - **threshold**: The minimum score for a tag to be considered valid
+ - **character_threshold**: The minimum score for a character tag to be considered valid
+ - **exclude_tags**: A comma-separated list of tags that should not be included in the results
+
+ Quick interrogation of images is also available on any node that displays an image, e.g. a `LoadImage`, `SaveImage`, or `PreviewImage` node.
+ Simply right-click the node (or, if it displays multiple images, the image you want to interrogate) and select `WD14 Tagger` from the menu.
+ ![image](https://github.com/pythongosssss/ComfyUI-WD14-Tagger/assets/125205205/11733899-6163-49f6-a22b-8dd86d910de6)
+
+ Settings used for this are in the `settings` section of `pysssss.json`.
+
+ ### Offline Use
+ The simplest approach is to use it online: interrogate an image once and the model is downloaded and cached. However, if you want to download the models manually:
+ - Create a `models` folder (in the same folder as `wd14tagger.py`)
+ - Use the model URLs from the list in `pysssss.json`
+ - Download `model.onnx` and name it after the model, e.g. `wd-v1-4-convnext-tagger-v2.onnx`
+ - Download `selected_tags.csv` and name it after the model, e.g. `wd-v1-4-convnext-tagger-v2.csv`
+
+ ## Requirements
+ `onnxruntime` (recommended; interrogation is still fast on CPU, and it is included in requirements.txt)
+ or `onnxruntime-gpu` (enables GPU inference; many people have issues with this, so no support is provided if you try it)
+
+ ## Changelog
+ - 2023-05-14 - Moved to own repo, added model downloading, support for multiple inputs
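
As a companion to the Offline Use section of the README above, here is a minimal sketch (not part of this commit) of scripting those manual steps. The endpoint, model name, and target folder are illustrative assumptions taken from `pysssss.json` and this commit's directory layout.

```python
# Hypothetical offline-download helper: fetch one model's ONNX weights and tag
# list into the extension's models folder, mirroring the URLs that
# download_model() in wd14tagger.py builds at runtime.
import os
import urllib.request

HF_ENDPOINT = "https://huggingface.co"   # or a mirror, matching HF_ENDPOINT in pysssss.json
MODEL = "wd-v1-4-convnext-tagger-v2"     # any key from the "models" map in pysssss.json
MODELS_DIR = os.path.join("custom_nodes", "comfyui-WD14-Tagger", "models")

os.makedirs(MODELS_DIR, exist_ok=True)
base = f"{HF_ENDPOINT}/SmilingWolf/{MODEL}/resolve/main"

# Both files must share the model name so the extension can pair them up.
urllib.request.urlretrieve(f"{base}/model.onnx", os.path.join(MODELS_DIR, f"{MODEL}.onnx"))
urllib.request.urlretrieve(f"{base}/selected_tags.csv", os.path.join(MODELS_DIR, f"{MODEL}.csv"))
```

Any other key from the `models` map in `pysssss.json` works the same way.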
custom_nodes/comfyui-WD14-Tagger/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .pysssss import init
+
+ if init(check_imports=["onnxruntime"]):
+     from .wd14tagger import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+     WEB_DIRECTORY = "./web"
+     __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
custom_nodes/comfyui-WD14-Tagger/models/wd-v1-4-moat-tagger-v2.csv ADDED
The diff for this file is too large to render. See raw diff
 
custom_nodes/comfyui-WD14-Tagger/models/wd-v1-4-moat-tagger-v2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b8cef913be4c9e8d93f9f903e74271416502ce0b4b04df0ff1e2f00df488aa03
+ size 326197340
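
Since the .onnx weights are stored via Git LFS, the pointer above records only the object's SHA-256 (`oid`) and byte size. A small sketch (the local path is an assumption based on this commit's layout) for verifying a manually downloaded copy against that pointer:

```python
# Sketch: check a downloaded wd-v1-4-moat-tagger-v2.onnx against the LFS
# pointer above (oid = SHA-256 of the file contents, plus its size in bytes).
import hashlib
import os

EXPECTED_OID = "b8cef913be4c9e8d93f9f903e74271416502ce0b4b04df0ff1e2f00df488aa03"
EXPECTED_SIZE = 326197340
path = "custom_nodes/comfyui-WD14-Tagger/models/wd-v1-4-moat-tagger-v2.onnx"

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)

print("size ok:  ", os.path.getsize(path) == EXPECTED_SIZE)
print("sha256 ok:", h.hexdigest() == EXPECTED_OID)
```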
custom_nodes/comfyui-WD14-Tagger/pyproject.toml ADDED
@@ -0,0 +1,14 @@
+ [project]
+ name = "comfyui-wd14-tagger"
+ description = "A ComfyUI extension allowing the interrogation of booru tags from images."
+ version = "1.0.1"
+ license = { file = "LICENSE" }
+ dependencies = ["onnxruntime"]
+
+ [project.urls]
+ Repository = "https://github.com/pythongosssss/ComfyUI-WD14-Tagger"
+
+ [tool.comfy]
+ PublisherId = "pythongosssss"
+ DisplayName = "ComfyUI-WD14-Tagger"
+ Icon = ""
custom_nodes/comfyui-WD14-Tagger/pysssss.json ADDED
@@ -0,0 +1,25 @@
+ {
+     "name": "WD14Tagger",
+     "logging": false,
+     "settings": {
+         "model": "wd-v1-4-moat-tagger-v2",
+         "threshold": 0.35,
+         "character_threshold": 0.85,
+         "exclude_tags": "",
+         "ortProviders": ["CUDAExecutionProvider", "CPUExecutionProvider"],
+         "HF_ENDPOINT": "https://huggingface.co"
+     },
+     "models": {
+         "wd-eva02-large-tagger-v3": "{HF_ENDPOINT}/SmilingWolf/wd-eva02-large-tagger-v3",
+         "wd-vit-tagger-v3": "{HF_ENDPOINT}/SmilingWolf/wd-vit-tagger-v3",
+         "wd-swinv2-tagger-v3": "{HF_ENDPOINT}/SmilingWolf/wd-swinv2-tagger-v3",
+         "wd-convnext-tagger-v3": "{HF_ENDPOINT}/SmilingWolf/wd-convnext-tagger-v3",
+         "wd-v1-4-moat-tagger-v2": "{HF_ENDPOINT}/SmilingWolf/wd-v1-4-moat-tagger-v2",
+         "wd-v1-4-convnextv2-tagger-v2": "{HF_ENDPOINT}/SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
+         "wd-v1-4-convnext-tagger-v2": "{HF_ENDPOINT}/SmilingWolf/wd-v1-4-convnext-tagger-v2",
+         "wd-v1-4-convnext-tagger": "{HF_ENDPOINT}/SmilingWolf/wd-v1-4-convnext-tagger",
+         "wd-v1-4-vit-tagger-v2": "{HF_ENDPOINT}/SmilingWolf/wd-v1-4-vit-tagger-v2",
+         "wd-v1-4-swinv2-tagger-v2": "{HF_ENDPOINT}/SmilingWolf/wd-v1-4-swinv2-tagger-v2",
+         "wd-v1-4-vit-tagger": "{HF_ENDPOINT}/SmilingWolf/wd-v1-4-vit-tagger"
+     }
+ }
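
Each entry in the `models` map carries an `{HF_ENDPOINT}` placeholder that `download_model()` in `wd14tagger.py` (later in this commit) replaces with the configured or environment-supplied endpoint before appending `/resolve/main/`. A minimal sketch of that expansion, assuming the config is read from the path it has in this commit:

```python
# Sketch of how a pysssss.json model entry becomes a concrete download URL,
# mirroring download_model() in wd14tagger.py below.
import json
import os

with open("custom_nodes/comfyui-WD14-Tagger/pysssss.json") as f:
    config = json.load(f)

hf_endpoint = os.getenv("HF_ENDPOINT", config["settings"]["HF_ENDPOINT"]).rstrip("/")
model = "wd-v1-4-moat-tagger-v2"
base = config["models"][model].replace("{HF_ENDPOINT}", hf_endpoint) + "/resolve/main/"

print(base + "model.onnx")          # .../SmilingWolf/wd-v1-4-moat-tagger-v2/resolve/main/model.onnx
print(base + "selected_tags.csv")
```

Setting the `HF_ENDPOINT` environment variable to a mirror therefore redirects every model download without editing the config file.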
custom_nodes/comfyui-WD14-Tagger/pysssss.py ADDED
@@ -0,0 +1,242 @@
+ import asyncio
+ import os
+ import json
+ import shutil
+ import inspect
+ import aiohttp
+ from server import PromptServer
+ from tqdm import tqdm
+
+ config = None
+
+
+ def is_logging_enabled():
+     config = get_extension_config()
+     if "logging" not in config:
+         return False
+     return config["logging"]
+
+
+ def log(message, type=None, always=False):
+     if not always and not is_logging_enabled():
+         return
+
+     if type is not None:
+         message = f"[{type}] {message}"
+
+     name = get_extension_config()["name"]
+
+     print(f"(pysssss:{name}) {message}")
+
+
+ def get_ext_dir(subpath=None, mkdir=False):
+     dir = os.path.dirname(__file__)
+     if subpath is not None:
+         dir = os.path.join(dir, subpath)
+
+     dir = os.path.abspath(dir)
+
+     if mkdir and not os.path.exists(dir):
+         os.makedirs(dir)
+     return dir
+
+
+ def get_comfy_dir(subpath=None):
+     dir = os.path.dirname(inspect.getfile(PromptServer))
+     if subpath is not None:
+         dir = os.path.join(dir, subpath)
+
+     dir = os.path.abspath(dir)
+
+     return dir
+
+
+ def get_web_ext_dir():
+     config = get_extension_config()
+     name = config["name"]
+     dir = get_comfy_dir("web/extensions/pysssss")
+     if not os.path.exists(dir):
+         os.makedirs(dir)
+     dir += "/" + name
+     return dir
+
+
+ def get_extension_config(reload=False):
+     global config
+     if reload == False and config is not None:
+         return config
+
+     config_path = get_ext_dir("pysssss.user.json")
+     if not os.path.exists(config_path):
+         config_path = get_ext_dir("pysssss.json")
+
+     if not os.path.exists(config_path):
+         log("Missing pysssss.json and pysssss.user.json, this extension may not work correctly. Please reinstall the extension.",
+             type="ERROR", always=True)
+         print(f"Extension path: {get_ext_dir()}")
+         return {"name": "Unknown", "version": -1}
+     with open(config_path, "r") as f:
+         config = json.loads(f.read())
+     return config
+
+
+ def link_js(src, dst):
+     src = os.path.abspath(src)
+     dst = os.path.abspath(dst)
+     if os.name == "nt":
+         try:
+             import _winapi
+             _winapi.CreateJunction(src, dst)
+             return True
+         except:
+             pass
+     try:
+         os.symlink(src, dst)
+         return True
+     except:
+         import logging
+         logging.exception('')
+         return False
+
+
+ def is_junction(path):
+     if os.name != "nt":
+         return False
+     try:
+         return bool(os.readlink(path))
+     except OSError:
+         return False
+
+
+ def install_js():
+     src_dir = get_ext_dir("web/js")
+     if not os.path.exists(src_dir):
+         log("No JS")
+         return
+
+     should_install = should_install_js()
+     if should_install:
+         log("it looks like you're running an old version of ComfyUI that requires manual setup of web files, it is recommended you update your installation.", "warning", True)
+     dst_dir = get_web_ext_dir()
+     linked = os.path.islink(dst_dir) or is_junction(dst_dir)
+     if linked or os.path.exists(dst_dir):
+         if linked:
+             if should_install:
+                 log("JS already linked")
+             else:
+                 os.unlink(dst_dir)
+                 log("JS unlinked, PromptServer will serve extension")
+         elif not should_install:
+             shutil.rmtree(dst_dir)
+             log("JS deleted, PromptServer will serve extension")
+         return
+
+     if not should_install:
+         log("JS skipped, PromptServer will serve extension")
+         return
+
+     if link_js(src_dir, dst_dir):
+         log("JS linked")
+         return
+
+     log("Copying JS files")
+     shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)
+
+
+ def should_install_js():
+     return not hasattr(PromptServer.instance, "supports") or "custom_nodes_from_web" not in PromptServer.instance.supports
+
+
+ def init(check_imports):
+     log("Init")
+
+     if check_imports is not None:
+         import importlib.util
+         for imp in check_imports:
+             spec = importlib.util.find_spec(imp)
+             if spec is None:
+                 log(f"{imp} is required, please check requirements are installed.", type="ERROR", always=True)
+                 return False
+
+     install_js()
+     return True
+
+
+ async def download_to_file(url, destination, update_callback, is_ext_subpath=True, session=None):
+     close_session = False
+     if session is None:
+         close_session = True
+         loop = None
+         try:
+             loop = asyncio.get_event_loop()
+         except:
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+
+         session = aiohttp.ClientSession(loop=loop)
+     if is_ext_subpath:
+         destination = get_ext_dir(destination)
+     try:
+         proxy = os.getenv("HTTP_PROXY") or os.getenv("http_proxy")
+         print("proxy:", proxy)
+         proxy_auth = None
+         if proxy:
+             proxy_auth = aiohttp.BasicAuth(os.getenv("PROXY_USER", ""), os.getenv("PROXY_PASS", ""))
+
+         async with session.get(url, proxy=proxy, proxy_auth=proxy_auth) as response:
+             size = int(response.headers.get('content-length', 0)) or None
+
+             with tqdm(
+                 unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1], total=size,
+             ) as progressbar:
+                 with open(destination, mode='wb') as f:
+                     perc = 0
+                     async for chunk in response.content.iter_chunked(2048):
+                         f.write(chunk)
+                         progressbar.update(len(chunk))
+                         if update_callback is not None and progressbar.total is not None and progressbar.total != 0:
+                             last = perc
+                             perc = round(progressbar.n / progressbar.total, 2)
+                             if perc != last:
+                                 last = perc
+                                 await update_callback(perc)
+     finally:
+         if close_session and session is not None:
+             await session.close()
+
+
+ def wait_for_async(async_fn):
+     try:
+         import concurrent.futures
+         # Check if we're in a running event loop
+         asyncio.get_running_loop()
+         # We're in a running loop, so run the async function in a separate thread
+         with concurrent.futures.ThreadPoolExecutor() as executor:
+             future = executor.submit(asyncio.run, async_fn())
+             return future.result()  # This blocks until complete
+     except RuntimeError:
+         # No running loop, safe to use asyncio.run()
+         return asyncio.run(async_fn())
+
+
+ def update_node_status(client_id, node, text, progress=None):
+     if client_id is None:
+         client_id = PromptServer.instance.client_id
+
+     if client_id is None:
+         return
+
+     PromptServer.instance.send_sync("pysssss/update_status", {
+         "node": node,
+         "progress": progress,
+         "text": text
+     }, client_id)
+
+
+ async def update_node_status_async(client_id, node, text, progress=None):
+     if client_id is None:
+         client_id = PromptServer.instance.client_id
+
+     if client_id is None:
+         return
+
+     await PromptServer.instance.send("pysssss/update_status", {
+         "node": node,
+         "progress": progress,
+         "text": text
+     }, client_id)
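
The helper the rest of this commit leans on most is `wait_for_async()`, which drives a coroutine to completion from synchronous node code whether or not an event loop is already running. A small usage sketch (assumption: it runs inside this package, so the relative import resolves the same way it does in `wd14tagger.py`):

```python
# Sketch: run a coroutine to completion from synchronous code using
# wait_for_async() defined above. Pass the coroutine *function* (or a lambda),
# not an already-created coroutine, exactly as wd14tagger.py does.
import asyncio

from .pysssss import wait_for_async  # assumption: imported from within this package


async def slow_answer():
    await asyncio.sleep(0.1)
    return 42


result = wait_for_async(slow_answer)  # blocks until the coroutine finishes
print(result)  # 42
```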
custom_nodes/comfyui-WD14-Tagger/requirements.txt ADDED
@@ -0,0 +1 @@
+ onnxruntime
custom_nodes/comfyui-WD14-Tagger/wd14tagger.py ADDED
@@ -0,0 +1,210 @@
+ # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
+
+ import comfy.utils
+ import asyncio
+ import aiohttp
+ import numpy as np
+ import csv
+ import os
+ import sys
+ import onnxruntime as ort
+ from onnxruntime import InferenceSession
+ from PIL import Image
+ from server import PromptServer
+ from aiohttp import web
+ import folder_paths
+ from .pysssss import get_ext_dir, get_comfy_dir, download_to_file, update_node_status, wait_for_async, get_extension_config, log
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
+
+ config = get_extension_config()
+
+ defaults = {
+     "model": "wd-v1-4-moat-tagger-v2",
+     "threshold": 0.35,
+     "character_threshold": 0.85,
+     "replace_underscore": False,
+     "trailing_comma": False,
+     "exclude_tags": "",
+     "ortProviders": ["CUDAExecutionProvider", "CPUExecutionProvider"],
+     "HF_ENDPOINT": "https://huggingface.co"
+ }
+ defaults.update(config.get("settings", {}))
+
+ if "wd14_tagger" in folder_paths.folder_names_and_paths:
+     models_dir = folder_paths.get_folder_paths("wd14_tagger")[0]
+     if not os.path.exists(models_dir):
+         os.makedirs(models_dir)
+ else:
+     models_dir = get_ext_dir("models", mkdir=True)
+ known_models = list(config["models"].keys())
+
+ log("Available ORT providers: " + ", ".join(ort.get_available_providers()), "DEBUG", True)
+ log("Using ORT providers: " + ", ".join(defaults["ortProviders"]), "DEBUG", True)
+
+ def get_installed_models():
+     models = filter(lambda x: x.endswith(".onnx"), os.listdir(models_dir))
+     models = [m for m in models if os.path.exists(os.path.join(models_dir, os.path.splitext(m)[0] + ".csv"))]
+     return models
+
+
+ async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, client_id=None, node=None):
+     if model_name.endswith(".onnx"):
+         model_name = model_name[0:-5]
+     installed = list(get_installed_models())
+     if not any(model_name + ".onnx" in s for s in installed):
+         await download_model(model_name, client_id, node)
+
+     name = os.path.join(models_dir, model_name + ".onnx")
+     model = InferenceSession(name, providers=defaults["ortProviders"])
+
+     input = model.get_inputs()[0]
+     height = input.shape[1]
+
+     # Reduce to max size and pad with white
+     ratio = float(height)/max(image.size)
+     new_size = tuple([int(x*ratio) for x in image.size])
+     image = image.resize(new_size, Image.LANCZOS)
+     square = Image.new("RGB", (height, height), (255, 255, 255))
+     square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2))
+
+     image = np.array(square).astype(np.float32)
+     image = image[:, :, ::-1]  # RGB -> BGR
+     image = np.expand_dims(image, 0)
+
+     # Read all tags from csv and locate start of each category
+     tags = []
+     general_index = None
+     character_index = None
+     with open(os.path.join(models_dir, model_name + ".csv")) as f:
+         reader = csv.reader(f)
+         next(reader)
+         for row in reader:
+             if general_index is None and row[2] == "0":
+                 general_index = reader.line_num - 2
+             elif character_index is None and row[2] == "4":
+                 character_index = reader.line_num - 2
+             if replace_underscore:
+                 tags.append(row[1].replace("_", " "))
+             else:
+                 tags.append(row[1])
+
+     label_name = model.get_outputs()[0].name
+     probs = model.run([label_name], {input.name: image})[0]
+
+     result = list(zip(tags, probs[0]))
+
+     # rating = max(result[:general_index], key=lambda x: x[1])
+     general = [item for item in result[general_index:character_index] if item[1] > threshold]
+     character = [item for item in result[character_index:] if item[1] > character_threshold]
+
+     all = character + general
+     remove = [s.strip() for s in exclude_tags.lower().split(",")]
+     all = [tag for tag in all if tag[0] not in remove]
+
+     res = ("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all))
+
+     print(res)
+     return res
+
+
+ async def download_model(model, client_id, node):
+     hf_endpoint = os.getenv("HF_ENDPOINT", defaults["HF_ENDPOINT"])
+     if not hf_endpoint.startswith("https://"):
+         hf_endpoint = f"https://{hf_endpoint}"
+     if hf_endpoint.endswith("/"):
+         hf_endpoint = hf_endpoint.rstrip("/")
+
+     url = config["models"][model]
+     url = url.replace("{HF_ENDPOINT}", hf_endpoint)
+     url = f"{url}/resolve/main/"
+     async with aiohttp.ClientSession(loop=asyncio.get_event_loop()) as session:
+         async def update_callback(perc):
+             nonlocal client_id
+             message = ""
+             if perc < 100:
+                 message = f"Downloading {model}"
+             update_node_status(client_id, node, message, perc)
+
+         try:
+             await download_to_file(
+                 f"{url}model.onnx", os.path.join(models_dir, f"{model}.onnx"), update_callback, session=session)
+             await download_to_file(
+                 f"{url}selected_tags.csv", os.path.join(models_dir, f"{model}.csv"), update_callback, session=session)
+         except aiohttp.client_exceptions.ClientConnectorError as err:
+             log("Unable to download model. Download files manually or try using a HF mirror/proxy website by setting the environment variable HF_ENDPOINT=https://.....", "ERROR", True)
+             raise
+
+         update_node_status(client_id, node, None)
+
+     return web.Response(status=200)
+
+
+ @PromptServer.instance.routes.get("/pysssss/wd14tagger/tag")
+ async def get_tags(request):
+     if "filename" not in request.rel_url.query:
+         return web.Response(status=404)
+
+     type = request.query.get("type", "output")
+     if type not in ["output", "input", "temp"]:
+         return web.Response(status=400)
+
+     target_dir = get_comfy_dir(type)
+     image_path = os.path.abspath(os.path.join(
+         target_dir, request.query.get("subfolder", ""), request.query["filename"]))
+     c = os.path.commonpath((image_path, target_dir))
+     if os.path.commonpath((image_path, target_dir)) != target_dir:
+         return web.Response(status=403)
+
+     if not os.path.isfile(image_path):
+         return web.Response(status=404)
+
+     image = Image.open(image_path)
+
+     models = get_installed_models()
+     default = defaults["model"] + ".onnx"
+     model = default if default in models else models[0]
+
+     return web.json_response(await tag(image, model, client_id=request.rel_url.query.get("clientId", ""), node=request.rel_url.query.get("node", "")))
+
+
+ class WD14Tagger:
+     @classmethod
+     def INPUT_TYPES(s):
+         extra = [name for name, _ in (os.path.splitext(m) for m in get_installed_models()) if name not in known_models]
+         models = known_models + extra
+         return {"required": {
+             "image": ("IMAGE", ),
+             "model": (models, { "default": defaults["model"] }),
+             "threshold": ("FLOAT", {"default": defaults["threshold"], "min": 0.0, "max": 1, "step": 0.05}),
+             "character_threshold": ("FLOAT", {"default": defaults["character_threshold"], "min": 0.0, "max": 1, "step": 0.05}),
+             "replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}),
+             "trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}),
+             "exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}),
+         }}
+
+     RETURN_TYPES = ("STRING",)
+     OUTPUT_IS_LIST = (True,)
+     FUNCTION = "tag"
+     OUTPUT_NODE = True
+
+     CATEGORY = "image"
+
+     def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False):
+         tensor = image*255
+         tensor = np.array(tensor, dtype=np.uint8)
+
+         pbar = comfy.utils.ProgressBar(tensor.shape[0])
+         tags = []
+         for i in range(tensor.shape[0]):
+             image = Image.fromarray(tensor[i])
+             tags.append(wait_for_async(lambda: tag(image, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma)))
+             pbar.update(1)
+         return {"ui": {"tags": tags}, "result": (tags,)}
+
+
+ NODE_CLASS_MAPPINGS = {
+     "WD14Tagger|pysssss": WD14Tagger,
+ }
+ NODE_DISPLAY_NAME_MAPPINGS = {
+     "WD14Tagger|pysssss": "WD14 Tagger 🐍",
+ }