FoodDesert committed on
Commit
83610fc
1 Parent(s): 5b174ea

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +90 -29
  3. tfidfreducedfiles.joblib +3 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Tagset Completer
3
  emoji: 🐿️
4
  colorFrom: gray
5
  colorTo: gray
 
1
  ---
2
+ title: Prompt Squirrel
3
  emoji: 🐿️
4
  colorFrom: gray
5
  colorTo: gray
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from sklearn.metrics.pairwise import cosine_similarity
3
  from scipy.sparse import csr_matrix
4
  import numpy as np
 
5
  from joblib import load
6
  import h5py
7
  from io import BytesIO
@@ -19,6 +20,7 @@ import io
19
  import os
20
  import glob
21
  import itertools
 
22
 
23
 
24
 
@@ -32,7 +34,7 @@ Since Stable Diffusion's initial release in 2022, users have developed a myriad
32
  Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
33
  This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, Fluffusion, and Pony Diffusion v6 were trained.
34
 
35
- When you enter a txt2img prompt and press the "submit" button, the Tagset Completer parses your prompt and checks that all your tags are valid e621 tags.
36
  If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
37
  Additionally, in the "Top Artists" text box, it lists the artists who would most likely draw an image having the set of tags you provided.
38
  This is useful to align your prompt with the expected input to an e621-trained model.
@@ -114,18 +116,12 @@ See SamplePrompts.csv for the list of prompts used and their descriptions.
114
 
115
  nsfw_threshold = 0.95 # Assuming the threshold value is defined here
116
 
117
- #grammar=r"""
118
- #!start: (prompt | /[][():]/+)*
119
- #prompt: (emphasized | plain | commas | WHITESPACE)*
120
- #!emphasized: "(" prompt ")"
121
- # | "(" prompt ":" [WHITESPACE] NUMBER [WHITESPACE] ")"
122
- #!comma: ","
123
- #commas: double_comma | comma
124
- #double_comma: comma WHITESPACE* comma
125
- #WHITESPACE: /\s+/
126
- #plain: /([^,\\\[\]():|]|\\.)+/
127
- #%import common.SIGNED_NUMBER -> NUMBER
128
- #"""
129
 
130
  grammar=r"""
131
  !start: (prompt | /[][():]/+)*
@@ -353,11 +349,11 @@ def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix,
353
  return geometric_mean
354
 
355
 
356
- def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
357
  # Wrap the tag part in a <span> with styles for bold and larger font
358
- html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{tag}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
359
  # Loop through the results and add table rows for each
360
- for word, sim in result:
361
  word_with_underscores = word.replace(' ', '_')
362
  count = tag2count.get(word_with_underscores, 0) # Get the count if available, otherwise default to 0
363
  tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
@@ -379,7 +375,7 @@ def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
379
 
380
  def create_top_artists_table(top_artists):
381
  # Add a heading above the table
382
- html_str = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
383
  html_str += "<h1>Top Artists</h1>" # Heading for the table
384
  # Start the table with increased font size and no borders between rows
385
  html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
@@ -396,16 +392,70 @@ def create_top_artists_table(top_artists):
396
  return html_str
397
 
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
400
  # Include a title in the same style as the top artists table heading
401
- html_placeholder = f"<div style='text-align: center;'><h1>{title}</h1></div>"
402
  # Conditionally add content if present
403
  if content:
404
  html_placeholder += f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
405
  # Add the placeholder div with specified height and width
406
  html_placeholder += f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px auto; background: transparent;'></div>"
407
  return html_placeholder
408
-
409
 
410
  def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
411
  #Initialize stuff
@@ -425,7 +475,7 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
425
  transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
426
 
427
  # Find similar tags and prepare data for tables
428
- html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
429
  html_content += "<h1>Unknown Tags</h1>" # Heading for the table
430
  tags_added = False
431
  bad_entities = []
@@ -561,14 +611,21 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
561
 
562
  ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
563
  unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
564
-
 
565
  bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
566
  bad_entities.sort(key=lambda x: x['start'])
567
  bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
568
 
569
- #modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
570
- #X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
571
- #artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data]
 
 
 
 
 
 
572
  artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
573
  X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
574
  similarities = cosine_similarity(X_new_image, X_artist)[0]
@@ -586,12 +643,12 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
586
  image_galleries.append(baseline) # Add baseline as its own gallery item
587
  image_galleries.append(artists) # Extend the list with artist tuples
588
 
589
- return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
590
  except ParseError as e:
591
- return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
592
 
593
 
594
- with gr.Blocks() as app:
595
  with gr.Group():
596
  with gr.Row():
597
  with gr.Column(scale=3):
@@ -609,7 +666,11 @@ with gr.Blocks() as app:
609
  with gr.Row():
610
  similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
611
  allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
612
- unseen_tags = gr.HTML(label="Unknown Tags", value=create_html_placeholder(title="Unknown Tags"))
 
 
 
 
613
  with gr.Column(scale=1):
614
  with gr.Group():
615
  num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
@@ -626,7 +687,7 @@ with gr.Blocks() as app:
626
  submit_button.click(
627
  find_similar_artists,
628
  inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
629
- outputs=[unseen_tags, bad_tags_illustrated_string, top_artists, dynamic_prompts] + galleries
630
  )
631
 
632
  gr.Markdown(faq_content)
 
2
  from sklearn.metrics.pairwise import cosine_similarity
3
  from scipy.sparse import csr_matrix
4
  import numpy as np
5
+ import joblib
6
  from joblib import load
7
  import h5py
8
  from io import BytesIO
 
20
  import os
21
  import glob
22
  import itertools
23
+ from itertools import islice
24
 
25
 
26
 
 
34
  Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
35
  This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, Fluffusion, and Pony Diffusion v6 were trained.
36
 
37
+ When you enter a txt2img prompt and press the "submit" button, Prompt Squirrel parses your prompt and checks that all your tags are valid e621 tags.
38
  If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
39
  Additionally, in the "Top Artists" text box, it lists the artists who would most likely draw an image having the set of tags you provided.
40
  This is useful to align your prompt with the expected input to an e621-trained model.
 
116
 
117
  nsfw_threshold = 0.95 # Assuming the threshold value is defined here
118
 
119
+ css = """
120
+ .scrollable-content {
121
+ max-height: 500px;
122
+ overflow-y: auto;
123
+ }
124
+ """
 
 
 
 
 
 
125
 
126
  grammar=r"""
127
  !start: (prompt | /[][():]/+)*
 
349
  return geometric_mean
350
 
351
 
352
+ def create_html_tables_for_tags(subtable_heading, word_similarity_tuples, tag2count, tag2idwiki):
353
  # Wrap the tag part in a <span> with styles for bold and larger font
354
+ html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{subtable_heading}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
355
  # Loop through the results and add table rows for each
356
+ for word, sim in word_similarity_tuples:
357
  word_with_underscores = word.replace(' ', '_')
358
  count = tag2count.get(word_with_underscores, 0) # Get the count if available, otherwise default to 0
359
  tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
 
375
 
376
  def create_top_artists_table(top_artists):
377
  # Add a heading above the table
378
+ html_str = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
379
  html_str += "<h1>Top Artists</h1>" # Heading for the table
380
  # Start the table with increased font size and no borders between rows
381
  html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
 
392
  return html_str
393
 
394
 
395
def construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded):
    """Build a single-row, IDF-weighted vector for a pseudo document.

    Args:
        pseudo_doc_terms: iterable of terms making up the pseudo document.
        idf_loaded: mapping from term to its weight; terms absent from it
            contribute 0.
        tag_to_row_loaded: mapping from term to its position in the vector.
            Its length determines the vector length.

    Returns:
        A 2D numpy array of shape ``(1, len(tag_to_row_loaded))``. Positions
        of terms present in ``tag_to_row_loaded`` hold that term's weight;
        every other position is 0. Terms missing from ``tag_to_row_loaded``
        are ignored.
    """
    weights = np.zeros(len(tag_to_row_loaded))

    # Only terms known to the mapping have a slot in the vector.
    known_terms = (term for term in pseudo_doc_terms if term in tag_to_row_loaded)
    for term in known_terms:
        weights[tag_to_row_loaded[term]] = idf_loaded.get(term, 0)

    # 2D shape so the result can be fed straight into a transform() call.
    return weights.reshape(1, -1)
407
+
408
def get_top_indices(reduced_pseudo_vector, reduced_matrix, top_n=None):
    """Rank the rows of ``reduced_matrix`` by cosine similarity to a query.

    Args:
        reduced_pseudo_vector: 2D array-like query passed as the first
            argument to ``cosine_similarity`` (presumably a single row;
            the similarity result is flattened before ranking).
        reduced_matrix: 2D array-like whose rows are compared to the query.
        top_n: optional cap on how many indices to return. ``None`` (the
            default) returns the complete ranking, which is what existing
            callers receive.

    Returns:
        1D numpy array of indices into the flattened similarity array,
        ordered from most to least similar.
    """
    similarities = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()

    # Negate so argsort's ascending order yields highest similarity first.
    sorted_indices = np.argsort(-similarities)

    if top_n is None:
        return sorted_indices
    # Clamp at 0 so a negative top_n cannot silently drop entries from the
    # tail through negative slicing.
    return sorted_indices[:max(top_n, 0)]
417
+
418
def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
    """Rank known tags by similarity to a pseudo document of tags.

    The pseudo document is turned into an IDF-weighted vector, projected
    with the stored SVD model, and compared by cosine similarity against
    every row of the stored reduced matrix.

    Args:
        pseudo_doc_terms: iterable of tag strings describing the query.
        allow_nsfw_tags: when falsy, tags whose underscore form is in the
            module-level ``nsfw_tags`` collection are left out.

    Returns:
        An ``OrderedDict`` mapping tag -> similarity, in descending order of
        similarity.
    """
    # Load the bundled components once and cache them on the function
    # object so later calls skip the disk read.
    if not hasattr(get_tfidf_reduced_similar_tags, "components"):
        get_tfidf_reduced_similar_tags.components = joblib.load('tfidfreducedfiles.joblib')

    components = get_tfidf_reduced_similar_tags.components
    idf_loaded = components['idf']
    tag_to_row_loaded = components['tag_to_row']
    reduced_matrix_loaded = components['reduced_matrix']
    svd_loaded = components['svd_model']

    pseudo_vector = construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded)
    reduced_pseudo_vector = svd_loaded.transform(pseudo_vector)

    # Compute the similarities once and rank them here; ranking through a
    # second cosine_similarity pass would only repeat the same work.
    similarities = cosine_similarity(reduced_pseudo_vector, reduced_matrix_loaded).flatten()
    ranked_rows = np.argsort(-similarities)

    # Materialise the tag list once instead of once per ranked row.
    # NOTE(review): this relies on the insertion order of tag_to_row
    # matching the row order of reduced_matrix -- confirm against the code
    # that produced the joblib file.
    tags_by_position = list(tag_to_row_loaded)

    ranked = ((tags_by_position[i], similarities[i]) for i in ranked_rows)
    if not allow_nsfw_tags:
        ranked = (
            (tag, similarity)
            for tag, similarity in ranked
            if tag.replace(' ', '_') not in nsfw_tags
        )

    # ranked_rows is already in descending similarity order and filtering
    # keeps that order, so no further sort is needed.
    return OrderedDict(ranked)
447
+
448
+
449
  def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
450
  # Include a title in the same style as the top artists table heading
451
+ html_placeholder = f"<div class=\"scrollable-content\" style='text-align: center;'><h1>{title}</h1></div>"
452
  # Conditionally add content if present
453
  if content:
454
  html_placeholder += f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
455
  # Add the placeholder div with specified height and width
456
  html_placeholder += f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px auto; background: transparent;'></div>"
457
  return html_placeholder
458
+
459
 
460
  def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
461
  #Initialize stuff
 
475
  transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
476
 
477
  # Find similar tags and prepare data for tables
478
+ html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
479
  html_content += "<h1>Unknown Tags</h1>" # Heading for the table
480
  tags_added = False
481
  bad_entities = []
 
611
 
612
  ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
613
  unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
614
+
615
+ #Bad tags stuff
616
  bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
617
  bad_entities.sort(key=lambda x: x['start'])
618
  bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
619
 
620
+ #Suggested tags stuff
621
+ suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
622
+
623
+ suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
624
+ suggested_tags = get_tfidf_reduced_similar_tags([item["artist_matrix_tag"] for item in tag_data], allow_nsfw_tags)
625
+ topnsuggestions = list(islice(suggested_tags.items(), 100))
626
+ suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
627
+
628
+ #Artist stuff
629
  artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
630
  X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
631
  similarities = cosine_similarity(X_new_image, X_artist)[0]
 
643
  image_galleries.append(baseline) # Add baseline as its own gallery item
644
  image_galleries.append(artists) # Extend the list with artist tuples
645
 
646
+ return (unseen_tags_data, bad_tags_illustrated_string, suggested_tags_html_content, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
647
  except ParseError as e:
648
+ return [], "Parse Error: Check for mismatched parentheses or something", "", "", None, None
649
 
650
 
651
+ with gr.Blocks(css=css) as app:
652
  with gr.Group():
653
  with gr.Row():
654
  with gr.Column(scale=3):
 
666
  with gr.Row():
667
  similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
668
  allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
669
+ with gr.Row():
670
+ with gr.Column(scale=2):
671
+ unseen_tags = gr.HTML(label="Unknown Tags", value=create_html_placeholder(title="Unknown Tags"))
672
+ with gr.Column(scale=1):
673
+ suggested_tags = gr.HTML(label="Suggested Tags", value=create_html_placeholder(title="Suggested Tags"))
674
  with gr.Column(scale=1):
675
  with gr.Group():
676
  num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
 
687
  submit_button.click(
688
  find_similar_artists,
689
  inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
690
+ outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
691
  )
692
 
693
  gr.Markdown(faq_content)
tfidfreducedfiles.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a325f75a94c8a6c47034fba0e96a89039a3550463f916690b74c16d139f32504
3
+ size 68245886