John6666 commited on
Commit
ead8433
1 Parent(s): 41d8f56

Upload 15 files

Browse files
Files changed (3) hide show
  1. character_series_dict.csv +0 -0
  2. tag_group.csv +0 -0
  3. tagger.py +67 -27
character_series_dict.csv ADDED
The diff for this file is too large to render. See raw diff
 
tag_group.csv CHANGED
The diff for this file is too large to render. See raw diff
 
tagger.py CHANGED
@@ -59,26 +59,7 @@ def load_dict_from_csv(filename):
59
  return dict
60
 
61
 
62
- def get_series_dict():
63
- import re
64
- with open('characterfull.txt', 'r') as f:
65
- lines = f.readlines()
66
- series_dict = {}
67
- for line in lines:
68
- parts = line.strip().split(', ')
69
- if len(parts) >= 3:
70
- name = parts[-2].replace("\\", "")
71
- if name.endswith(")"):
72
- names = name.split("(")
73
- character_name = "(".join(names[:-1])
74
- if character_name.endswith(" "):
75
- name = character_name[:-1]
76
- series = re.sub(r'\\[()]', '', parts[-1])
77
- series_dict[name] = series
78
- return series_dict
79
-
80
-
81
- anime_series_dict = get_series_dict()
82
 
83
 
84
  def character_list_to_series_list(character_list):
@@ -248,7 +229,7 @@ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "Non
248
  def list_uniq(l):
249
  return sorted(set(l), key=l.index)
250
 
251
- animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
252
  animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
253
  pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
254
  pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
@@ -289,11 +270,11 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
289
  return p.search(tag)
290
 
291
  un_tags = ['solo']
292
- group_list = ['metatags', 'other-lists', 'artists', 'more-2', 'objects', 'other-groups', 'games', 'plants', 'real-world', 'more-1', 'genres', 'characters', 'image-style', 'see-also', 'creatures', 'sex', 'attire-accessories', 'body']
293
  keep_group_dict = {
294
- "body": ['characters', 'body'],
295
- "dress": ['characters', 'attire-accessories', 'body'],
296
- "all": ['metatags', 'other-lists', 'artists', 'more-2', 'objects', 'other-groups', 'games', 'plants', 'real-world', 'more-1', 'genres', 'characters', 'image-style', 'see-also', 'creatures', 'sex', 'attire-accessories', 'body']
297
  }
298
 
299
  def is_necessary(tag, keep_tags, group_dict):
@@ -309,7 +290,7 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
309
  return True
310
 
311
  if keep_tags == "all": return input_prompt
312
- keep_group = keep_group_dict.get(keep_tags, ['characters', 'body'])
313
  explicit_group = list(set(group_list) ^ set(keep_group))
314
 
315
  tags = input_prompt.split(",") if input_prompt else []
@@ -329,6 +310,59 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
329
  return output_prompt
330
 
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
333
  results = {
334
  k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
@@ -400,7 +434,13 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
400
 
401
  return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
402
 
403
-
 
 
 
 
 
 
404
  def compose_prompt_to_copy(character: str, series: str, general: str):
405
  characters = character.split(",") if character else []
406
  serieses = series.split(",") if series else []
 
59
  return dict
60
 
61
 
62
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  def character_list_to_series_list(character_list):
 
229
  def list_uniq(l):
230
  return sorted(set(l), key=l.index)
231
 
232
+ animagine_ps = to_list("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
233
  animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
234
  pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
235
  pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
 
270
  return p.search(tag)
271
 
272
  un_tags = ['solo']
273
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
274
  keep_group_dict = {
275
+ "body": ['groups', 'body_parts'],
276
+ "dress": ['groups', 'body_parts', 'attire'],
277
+ "all": group_list,
278
  }
279
 
280
  def is_necessary(tag, keep_tags, group_dict):
 
290
  return True
291
 
292
  if keep_tags == "all": return input_prompt
293
+ keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
294
  explicit_group = list(set(group_list) ^ set(keep_group))
295
 
296
  tags = input_prompt.split(",") if input_prompt else []
 
310
  return output_prompt
311
 
312
 
313
+ def sort_taglist(tags: list[str]):
314
+ if not tags: return []
315
+ character_tags: list[str] = []
316
+ series_tags: list[str] = []
317
+ people_tags: list[str] = []
318
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
319
+ group_tags = {}
320
+ other_tags: list[str] = []
321
+ rating_tags: list[str] = []
322
+
323
+ group_dict = tag_group_dict
324
+ group_set = set(group_dict.keys())
325
+ character_set = set(anime_series_dict.keys())
326
+ series_set = set(anime_series_dict.values())
327
+ rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
328
+
329
+ for tag in tags:
330
+ tag = tag.strip().replace("_", " ")
331
+ if tag in PEOPLE_TAGS:
332
+ people_tags.append(tag)
333
+ elif tag in rating_set:
334
+ rating_tags.append(tag)
335
+ elif tag in group_set:
336
+ elem = group_dict[tag]
337
+ group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
338
+ elif tag in character_set:
339
+ character_tags.append(tag)
340
+ elif tag in series_set:
341
+ series_tags.append(tag)
342
+ else:
343
+ other_tags.append(tag)
344
+
345
+ output_group_tags: list[str] = []
346
+ for k in group_list:
347
+ output_group_tags.extend(group_tags.get(k, []))
348
+
349
+ rating_tags = [rating_tags[0]] if rating_tags else []
350
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
351
+
352
+ output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
353
+
354
+ return output_tags
355
+
356
+
357
+ def sort_tags(tags: str):
358
+ if not tags: return ""
359
+ taglist: list[str] = []
360
+ for tag in tags.split(","):
361
+ taglist.append(tag.strip())
362
+ taglist = list(filter(lambda x: x != "", taglist))
363
+ return ", ".join(sort_taglist(taglist))
364
+
365
+
366
  def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
367
  results = {
368
  k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
 
434
 
435
  return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
436
 
437
+
438
+ def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
439
+ if algo and not "Use WD Tagger" in algo:
440
+ return "", "", input_tags, gr.update(interactive=True),
441
+ return predict_tags(image, general_threshold, character_threshold)
442
+
443
+
444
  def compose_prompt_to_copy(character: str, series: str, general: str):
445
  characters = character.split(",") if character else []
446
  serieses = series.split(",") if series else []