x1101 commited on
Commit
03df755
1 Parent(s): 0a3ee81

Upload tag_autocomplete_helper.py

Browse files
Files changed (1) hide show
  1. tag_autocomplete_helper.py +284 -0
tag_autocomplete_helper.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This helper script scans folders for wildcards and embeddings and writes them
2
+ # to a temporary file to expose it to the javascript side
3
+
4
+ import gradio as gr
5
+ from pathlib import Path
6
+ from modules import scripts, script_callbacks, shared, sd_hijack
7
+ #from modules.paths import script_path, extensions_dir
8
+ import yaml
9
+
10
+ # Webui root path
11
+ FILE_DIR = Path().absolute()
12
+ #FILE_DIR = Path(script_path)
13
+
14
+
15
+ # The extension base path
16
+ EXT_PATH = FILE_DIR.joinpath('extensions')
17
+ #EXT_PATH = Path(extensions_dir)
18
+
19
+
20
+ # Tags base path
21
+ TAGS_PATH = Path(scripts.basedir()).joinpath('tags')
22
+
23
+ # The path to the folder containing the wildcards and embeddings
24
+ WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
25
+ EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
26
+ HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
27
+
28
+ try:
29
+ LORA_PATH = Path(shared.cmd_opts.lora_dir)
30
+ except AttributeError:
31
+ LORA_PATH = None
32
+
33
+ def find_ext_wildcard_paths():
34
+ """Returns the path to the extension wildcards folder"""
35
+ found = list(EXT_PATH.glob('*/wildcards/'))
36
+ return found
37
+
38
+
39
+ # The path to the extension wildcards folder
40
+ WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
41
+
42
+ # The path to the temporary files
43
+ STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp') # In the webui root, on windows it exists by default, on linux it doesn't
44
+ TEMP_PATH = TAGS_PATH.joinpath('temp') # Extension specific temp files
45
+
46
+
47
+ def get_wildcards():
48
+ """Returns a list of all wildcards. Works on nested folders."""
49
+ wildcard_files = list(WILDCARD_PATH.rglob("*.txt"))
50
+ resolved = [w.relative_to(WILDCARD_PATH).as_posix(
51
+ ) for w in wildcard_files if w.name != "put wildcards here.txt"]
52
+ return resolved
53
+
54
+
55
+ def get_ext_wildcards():
56
+ """Returns a list of all extension wildcards. Works on nested folders."""
57
+ wildcard_files = []
58
+
59
+ for path in WILDCARD_EXT_PATHS:
60
+ wildcard_files.append(path.relative_to(FILE_DIR).as_posix())
61
+ wildcard_files.extend(p.relative_to(path).as_posix() for p in path.rglob("*.txt") if p.name != "put wildcards here.txt")
62
+ wildcard_files.append("-----")
63
+
64
+ return wildcard_files
65
+
66
+
67
+ def get_ext_wildcard_tags():
68
+ """Returns a list of all tags found in extension YAML files found under a Tags: key."""
69
+ wildcard_tags = {} # { tag: count }
70
+ yaml_files = []
71
+ for path in WILDCARD_EXT_PATHS:
72
+ yaml_files.extend(p for p in path.rglob("*.yml"))
73
+ yaml_files.extend(p for p in path.rglob("*.yaml"))
74
+ count = 0
75
+ for path in yaml_files:
76
+ try:
77
+ with open(path, encoding="utf8") as file:
78
+ data = yaml.safe_load(file)
79
+ for item in data:
80
+ if data[item] and 'Tags' in data[item]:
81
+ wildcard_tags[count] = ','.join(data[item]['Tags'])
82
+ count += 1
83
+ else:
84
+ print('Issue with tags found in ' + path.name + ' at item ' + item)
85
+ except yaml.YAMLError as exc:
86
+ print(exc)
87
+ # Sort by count
88
+ sorted_tags = sorted(wildcard_tags.items(), key=lambda item: item[1], reverse=True)
89
+ output = []
90
+ for tag, count in sorted_tags:
91
+ output.append(f"{tag},{count}")
92
+ return output
93
+
94
+
95
+ def get_embeddings(sd_model):
96
+ """Write a list of all embeddings with their version"""
97
+
98
+ # Version constants
99
+ V1_SHAPE = 768
100
+ V2_SHAPE = 1024
101
+ emb_v1 = []
102
+ emb_v2 = []
103
+ results = []
104
+
105
+ try:
106
+ # Get embedding dict from sd_hijack to separate v1/v2 embeddings
107
+ emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
108
+ emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings
109
+ # Get the shape of the first item in the dict
110
+ emb_a_shape = -1
111
+ emb_b_shape = -1
112
+ if (len(emb_type_a) > 0):
113
+ emb_a_shape = next(iter(emb_type_a.items()))[1].shape
114
+ if (len(emb_type_b) > 0):
115
+ emb_b_shape = next(iter(emb_type_b.items()))[1].shape
116
+
117
+ # Add embeddings to the correct list
118
+ if (emb_a_shape == V1_SHAPE):
119
+ emb_v1 = list(emb_type_a.keys())
120
+ elif (emb_a_shape == V2_SHAPE):
121
+ emb_v2 = list(emb_type_a.keys())
122
+
123
+ if (emb_b_shape == V1_SHAPE):
124
+ emb_v1 = list(emb_type_b.keys())
125
+ elif (emb_b_shape == V2_SHAPE):
126
+ emb_v2 = list(emb_type_b.keys())
127
+
128
+ # Get shape of current model
129
+ #vec = sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
130
+ #model_shape = vec.shape[1]
131
+ # Show relevant entries at the top
132
+ #if (model_shape == V1_SHAPE):
133
+ # results = [e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2]
134
+ #elif (model_shape == V2_SHAPE):
135
+ # results = [e + ",v2" for e in emb_v2] + [e + ",v1" for e in emb_v1]
136
+ #else:
137
+ # raise AttributeError # Fallback to old method
138
+ results = sorted([e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2], key=lambda x: x.lower())
139
+ except AttributeError:
140
+ print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
141
+ # Get a list of all embeddings in the folder
142
+ all_embeds = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.rglob("*") if e.suffix in {".bin", ".pt", ".png",'.webp', '.jxl', '.avif'}]
143
+ # Remove files with a size of 0
144
+ all_embeds = [e for e in all_embeds if EMB_PATH.joinpath(e).stat().st_size > 0]
145
+ # Remove file extensions
146
+ all_embeds = [e[:e.rfind('.')] for e in all_embeds]
147
+ results = [e + "," for e in all_embeds]
148
+
149
+ write_to_temp_file('emb.txt', results)
150
+
151
+ def get_hypernetworks():
152
+ """Write a list of all hypernetworks"""
153
+
154
+ # Get a list of all hypernetworks in the folder
155
+ all_hypernetworks = [str(h.name) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
156
+ # Remove file extensions
157
+ return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())
158
+
159
+ def get_lora():
160
+ """Write a list of all lora"""
161
+
162
+ # Get a list of all lora in the folder
163
+ all_lora = [str(l.name) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors", ".ckpt", ".pt"}]
164
+ # Remove file extensions
165
+ return sorted([l[:l.rfind('.')] for l in all_lora], key=lambda x: x.lower())
166
+
167
+
168
+ def write_tag_base_path():
169
+ """Writes the tag base path to a fixed location temporary file"""
170
+ with open(STATIC_TEMP_PATH.joinpath('tagAutocompletePath.txt'), 'w', encoding="utf-8") as f:
171
+ f.write(TAGS_PATH.relative_to(FILE_DIR).as_posix())
172
+
173
+
174
+ def write_to_temp_file(name, data):
175
+ """Writes the given data to a temporary file"""
176
+ with open(TEMP_PATH.joinpath(name), 'w', encoding="utf-8") as f:
177
+ f.write(('\n'.join(data)))
178
+
179
+
180
+ csv_files = []
181
+ csv_files_withnone = []
182
+ def update_tag_files():
183
+ """Returns a list of all potential tag files"""
184
+ global csv_files, csv_files_withnone
185
+ files = [str(t.relative_to(TAGS_PATH)) for t in TAGS_PATH.glob("*.csv")]
186
+ csv_files = files
187
+ csv_files_withnone = ["None"] + files
188
+
189
+
190
+
191
+ # Write the tag base path to a fixed location temporary file
192
+ # to enable the javascript side to find our files regardless of extension folder name
193
+ if not STATIC_TEMP_PATH.exists():
194
+ STATIC_TEMP_PATH.mkdir(exist_ok=True)
195
+
196
+ write_tag_base_path()
197
+ update_tag_files()
198
+
199
+ # Check if the temp path exists and create it if not
200
+ if not TEMP_PATH.exists():
201
+ TEMP_PATH.mkdir(parents=True, exist_ok=True)
202
+
203
+ # Set up files to ensure the script doesn't fail to load them
204
+ # even if no wildcards or embeddings are found
205
+ write_to_temp_file('wc.txt', [])
206
+ write_to_temp_file('wce.txt', [])
207
+ write_to_temp_file('wcet.txt', [])
208
+ write_to_temp_file('hyp.txt', [])
209
+ write_to_temp_file('lora.txt', [])
210
+ # Only reload embeddings if the file doesn't exist, since they are already re-written on model load
211
+ if not TEMP_PATH.joinpath("emb.txt").exists():
212
+ write_to_temp_file('emb.txt', [])
213
+
214
+ # Write wildcards to wc.txt if found
215
+ if WILDCARD_PATH.exists():
216
+ wildcards = [WILDCARD_PATH.relative_to(FILE_DIR).as_posix()] + get_wildcards()
217
+ if wildcards:
218
+ write_to_temp_file('wc.txt', wildcards)
219
+
220
+ # Write extension wildcards to wce.txt if found
221
+ if WILDCARD_EXT_PATHS is not None:
222
+ wildcards_ext = get_ext_wildcards()
223
+ if wildcards_ext:
224
+ write_to_temp_file('wce.txt', wildcards_ext)
225
+ # Write yaml extension wildcards to wcet.txt if found
226
+ wildcards_yaml_ext = get_ext_wildcard_tags()
227
+ if wildcards_yaml_ext:
228
+ write_to_temp_file('wcet.txt', wildcards_yaml_ext)
229
+
230
+ # Write embeddings to emb.txt if found
231
+ if EMB_PATH.exists():
232
+ # Get embeddings after the model loaded callback
233
+ script_callbacks.on_model_loaded(get_embeddings)
234
+
235
+ if HYP_PATH.exists():
236
+ hypernets = get_hypernetworks()
237
+ if hypernets:
238
+ write_to_temp_file('hyp.txt', hypernets)
239
+
240
+ if LORA_PATH is not None and LORA_PATH.exists():
241
+ lora = get_lora()
242
+ if lora:
243
+ write_to_temp_file('lora.txt', lora)
244
+
245
+ # Register autocomplete options
246
+ def on_ui_settings():
247
+ TAC_SECTION = ("tac", "Tag Autocomplete")
248
+ # Main tag file
249
+ shared.opts.add_option("tac_tagFile", shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
250
+ # Active in settings
251
+ shared.opts.add_option("tac_active", shared.OptionInfo(True, "Enable Tag Autocompletion", section=TAC_SECTION))
252
+ shared.opts.add_option("tac_activeIn.txt2img", shared.OptionInfo(True, "Active in txt2img (Requires restart)", section=TAC_SECTION))
253
+ shared.opts.add_option("tac_activeIn.img2img", shared.OptionInfo(True, "Active in img2img (Requires restart)", section=TAC_SECTION))
254
+ shared.opts.add_option("tac_activeIn.negativePrompts", shared.OptionInfo(True, "Active in negative prompts (Requires restart)", section=TAC_SECTION))
255
+ shared.opts.add_option("tac_activeIn.thirdParty", shared.OptionInfo(True, "Active in third party textboxes [Dataset Tag Editor] (Requires restart)", section=TAC_SECTION))
256
+ shared.opts.add_option("tac_activeIn.modelList", shared.OptionInfo("", "List of model names (with file extension) or their hashes to use as black/whitelist, separated by commas.", section=TAC_SECTION))
257
+ shared.opts.add_option("tac_activeIn.modelListMode", shared.OptionInfo("Blacklist", "Mode to use for model list", gr.Dropdown, lambda: {"choices": ["Blacklist","Whitelist"]}, section=TAC_SECTION))
258
+ # Results related settings
259
+ shared.opts.add_option("tac_slidingPopup", shared.OptionInfo(True, "Move completion popup together with text cursor", section=TAC_SECTION))
260
+ shared.opts.add_option("tac_maxResults", shared.OptionInfo(5, "Maximum results", section=TAC_SECTION))
261
+ shared.opts.add_option("tac_showAllResults", shared.OptionInfo(False, "Show all results", section=TAC_SECTION))
262
+ shared.opts.add_option("tac_resultStepLength", shared.OptionInfo(100, "How many results to load at once", section=TAC_SECTION))
263
+ shared.opts.add_option("tac_delayTime", shared.OptionInfo(100, "Time in ms to wait before triggering completion again (Requires restart)", section=TAC_SECTION))
264
+ shared.opts.add_option("tac_useWildcards", shared.OptionInfo(True, "Search for wildcards", section=TAC_SECTION))
265
+ shared.opts.add_option("tac_useEmbeddings", shared.OptionInfo(True, "Search for embeddings", section=TAC_SECTION))
266
+ shared.opts.add_option("tac_useHypernetworks", shared.OptionInfo(True, "Search for hypernetworks", section=TAC_SECTION))
267
+ shared.opts.add_option("tac_useLoras", shared.OptionInfo(True, "Search for Loras", section=TAC_SECTION))
268
+ shared.opts.add_option("tac_showWikiLinks", shared.OptionInfo(False, "Show '?' next to tags, linking to its Danbooru or e621 wiki page (Warning: This is an external site and very likely contains NSFW examples!)", section=TAC_SECTION))
269
+ # Insertion related settings
270
+ shared.opts.add_option("tac_replaceUnderscores", shared.OptionInfo(True, "Replace underscores with spaces on insertion", section=TAC_SECTION))
271
+ shared.opts.add_option("tac_escapeParentheses", shared.OptionInfo(True, "Escape parentheses on insertion", section=TAC_SECTION))
272
+ shared.opts.add_option("tac_appendComma", shared.OptionInfo(True, "Append comma on tag autocompletion", section=TAC_SECTION))
273
+ # Alias settings
274
+ shared.opts.add_option("tac_alias.searchByAlias", shared.OptionInfo(True, "Search by alias", section=TAC_SECTION))
275
+ shared.opts.add_option("tac_alias.onlyShowAlias", shared.OptionInfo(False, "Only show alias", section=TAC_SECTION))
276
+ # Translation settings
277
+ shared.opts.add_option("tac_translation.translationFile", shared.OptionInfo("None", "Translation filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
278
+ shared.opts.add_option("tac_translation.oldFormat", shared.OptionInfo(False, "Translation file uses old 3-column translation format instead of the new 2-column one", section=TAC_SECTION))
279
+ shared.opts.add_option("tac_translation.searchByTranslation", shared.OptionInfo(True, "Search by translation", section=TAC_SECTION))
280
+ # Extra file settings
281
+ shared.opts.add_option("tac_extra.extraFile", shared.OptionInfo("extra-quality-tags.csv", "Extra filename (for small sets of custom tags)", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
282
+ shared.opts.add_option("tac_extra.addMode", shared.OptionInfo("Insert before", "Mode to add the extra tags to the main tag list", gr.Dropdown, lambda: {"choices": ["Insert before","Insert after"]}, section=TAC_SECTION))
283
+
284
+ script_callbacks.on_ui_settings(on_ui_settings)