Spaces:
Running
Running
FoodDesert
committed on
Commit
•
83610fc
1
Parent(s):
5b174ea
Upload 3 files
Browse files- README.md +1 -1
- app.py +90 -29
- tfidfreducedfiles.joblib +3 -0
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 🐿️
|
4 |
colorFrom: gray
|
5 |
colorTo: gray
|
|
|
1 |
---
|
2 |
+
title: Prompt Squirrel
|
3 |
emoji: 🐿️
|
4 |
colorFrom: gray
|
5 |
colorTo: gray
|
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
from sklearn.metrics.pairwise import cosine_similarity
|
3 |
from scipy.sparse import csr_matrix
|
4 |
import numpy as np
|
|
|
5 |
from joblib import load
|
6 |
import h5py
|
7 |
from io import BytesIO
|
@@ -19,6 +20,7 @@ import io
|
|
19 |
import os
|
20 |
import glob
|
21 |
import itertools
|
|
|
22 |
|
23 |
|
24 |
|
@@ -32,7 +34,7 @@ Since Stable Diffusion's initial release in 2022, users have developed a myriad
|
|
32 |
Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
|
33 |
This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, Fluffusion, and Pony Diffusion v6 were trained.
|
34 |
|
35 |
-
When you enter a txt2img prompt and press the "submit" button,
|
36 |
If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
|
37 |
Additionally, in the "Top Artists" text box, it lists the artists who would most likely draw an image having the set of tags you provided.
|
38 |
This is useful to align your prompt with the expected input to an e621-trained model.
|
@@ -114,18 +116,12 @@ See SamplePrompts.csv for the list of prompts used and their descriptions.
|
|
114 |
|
115 |
nsfw_threshold = 0.95 # Assuming the threshold value is defined here
|
116 |
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
#commas: double_comma | comma
|
124 |
-
#double_comma: comma WHITESPACE* comma
|
125 |
-
#WHITESPACE: /\s+/
|
126 |
-
#plain: /([^,\\\[\]():|]|\\.)+/
|
127 |
-
#%import common.SIGNED_NUMBER -> NUMBER
|
128 |
-
#"""
|
129 |
|
130 |
grammar=r"""
|
131 |
!start: (prompt | /[][():]/+)*
|
@@ -353,11 +349,11 @@ def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix,
|
|
353 |
return geometric_mean
|
354 |
|
355 |
|
356 |
-
def create_html_tables_for_tags(
|
357 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
358 |
-
html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{
|
359 |
# Loop through the results and add table rows for each
|
360 |
-
for word, sim in
|
361 |
word_with_underscores = word.replace(' ', '_')
|
362 |
count = tag2count.get(word_with_underscores, 0) # Get the count if available, otherwise default to 0
|
363 |
tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
|
@@ -379,7 +375,7 @@ def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
|
|
379 |
|
380 |
def create_top_artists_table(top_artists):
|
381 |
# Add a heading above the table
|
382 |
-
html_str = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
383 |
html_str += "<h1>Top Artists</h1>" # Heading for the table
|
384 |
# Start the table with increased font size and no borders between rows
|
385 |
html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
|
@@ -396,16 +392,70 @@ def create_top_artists_table(top_artists):
|
|
396 |
return html_str
|
397 |
|
398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
    """Build the HTML shown in an output panel before any results exist.

    Args:
        title: Text placed in a centred ``<h1>`` heading.
        content: Optional text; when truthy it is added as a centred paragraph
            below the heading.
        placeholder_height: Height of the empty spacer div, in pixels.
        placeholder_width: CSS width of the empty spacer div.

    Returns:
        A single HTML string: heading, optional paragraph, then a transparent
        spacer div of the requested size.
    """
    # Heading first, styled like the top artists table heading.
    sections = [f"<div style='text-align: center;'><h1>{title}</h1></div>"]
    # The descriptive paragraph is only emitted when there is something to say.
    if content:
        sections.append(
            f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
        )
    # An empty, transparent div reserves the panel's vertical space.
    sections.append(
        f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px auto; background: transparent;'></div>"
    )
    return "".join(sections)
|
408 |
-
|
409 |
|
410 |
def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
411 |
#Initialize stuff
|
@@ -425,7 +475,7 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
425 |
transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
|
426 |
|
427 |
# Find similar tags and prepare data for tables
|
428 |
-
html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
429 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
430 |
tags_added = False
|
431 |
bad_entities = []
|
@@ -561,14 +611,21 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
|
|
561 |
|
562 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
563 |
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
564 |
-
|
|
|
565 |
bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
|
566 |
bad_entities.sort(key=lambda x: x['start'])
|
567 |
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
568 |
|
569 |
-
#
|
570 |
-
|
571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
572 |
artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
|
573 |
X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
|
574 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
@@ -586,12 +643,12 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
|
|
586 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
587 |
image_galleries.append(artists) # Extend the list with artist tuples
|
588 |
|
589 |
-
return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
|
590 |
except ParseError as e:
|
591 |
-
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
592 |
|
593 |
|
594 |
-
with gr.Blocks() as app:
|
595 |
with gr.Group():
|
596 |
with gr.Row():
|
597 |
with gr.Column(scale=3):
|
@@ -609,7 +666,11 @@ with gr.Blocks() as app:
|
|
609 |
with gr.Row():
|
610 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
611 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
612 |
-
|
|
|
|
|
|
|
|
|
613 |
with gr.Column(scale=1):
|
614 |
with gr.Group():
|
615 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
@@ -626,7 +687,7 @@ with gr.Blocks() as app:
|
|
626 |
submit_button.click(
|
627 |
find_similar_artists,
|
628 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
629 |
-
outputs=[unseen_tags, bad_tags_illustrated_string, top_artists, dynamic_prompts] + galleries
|
630 |
)
|
631 |
|
632 |
gr.Markdown(faq_content)
|
|
|
2 |
from sklearn.metrics.pairwise import cosine_similarity
|
3 |
from scipy.sparse import csr_matrix
|
4 |
import numpy as np
|
5 |
+
import joblib
|
6 |
from joblib import load
|
7 |
import h5py
|
8 |
from io import BytesIO
|
|
|
20 |
import os
|
21 |
import glob
|
22 |
import itertools
|
23 |
+
from itertools import islice
|
24 |
|
25 |
|
26 |
|
|
|
34 |
Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
|
35 |
This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, Fluffusion, and Pony Diffusion v6 were trained.
|
36 |
|
37 |
+
When you enter a txt2img prompt and press the "submit" button, Prompt Squirrel parses your prompt and checks that all your tags are valid e621 tags.
|
38 |
If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
|
39 |
Additionally, in the "Top Artists" text box, it lists the artists who would most likely draw an image having the set of tags you provided.
|
40 |
This is useful to align your prompt with the expected input to an e621-trained model.
|
|
|
116 |
|
117 |
nsfw_threshold = 0.95 # Assuming the threshold value is defined here
|
118 |
|
119 |
+
css = """
|
120 |
+
.scrollable-content {
|
121 |
+
max-height: 500px;
|
122 |
+
overflow-y: auto;
|
123 |
+
}
|
124 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
grammar=r"""
|
127 |
!start: (prompt | /[][():]/+)*
|
|
|
349 |
return geometric_mean
|
350 |
|
351 |
|
352 |
+
def create_html_tables_for_tags(subtable_heading, word_similarity_tuples, tag2count, tag2idwiki):
|
353 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
354 |
+
html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{subtable_heading}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
|
355 |
# Loop through the results and add table rows for each
|
356 |
+
for word, sim in word_similarity_tuples:
|
357 |
word_with_underscores = word.replace(' ', '_')
|
358 |
count = tag2count.get(word_with_underscores, 0) # Get the count if available, otherwise default to 0
|
359 |
tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
|
|
|
375 |
|
376 |
def create_top_artists_table(top_artists):
|
377 |
# Add a heading above the table
|
378 |
+
html_str = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
379 |
html_str += "<h1>Top Artists</h1>" # Heading for the table
|
380 |
# Start the table with increased font size and no borders between rows
|
381 |
html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
|
|
|
392 |
return html_str
|
393 |
|
394 |
|
395 |
+
def construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded):
    """Build an IDF-weighted row vector for a pseudo document.

    Args:
        pseudo_doc_terms: Iterable of terms making up the pseudo document.
        idf_loaded: Mapping of term -> IDF weight; terms absent from it
            contribute a weight of 0.
        tag_to_row_loaded: Mapping of term -> position in the output vector.
            Terms that are not keys of this mapping are skipped.

    Returns:
        A NumPy array of shape ``(1, len(tag_to_row_loaded))``; the 2D shape
        keeps it compatible with an SVD ``transform`` call.
    """
    # Start from all zeros: one slot per entry of tag_to_row_loaded.
    weights = np.zeros(len(tag_to_row_loaded))

    for tag in pseudo_doc_terms:
        try:
            slot = tag_to_row_loaded[tag]
        except KeyError:
            # Term has no slot in the vector, so it cannot be represented.
            continue
        weights[slot] = idf_loaded.get(tag, 0)

    # Single-row 2D layout expected by the downstream transform.
    return weights.reshape(1, -1)
|
407 |
+
|
408 |
+
def get_top_indices(reduced_pseudo_vector, reduced_matrix):
    """Rank the rows of ``reduced_matrix`` by cosine similarity to the query.

    Args:
        reduced_pseudo_vector: 2D array holding the query vector, in the form
            accepted by ``cosine_similarity``.
        reduced_matrix: 2D array whose rows are compared against the query.

    Returns:
        Array of indices into the flattened similarity scores, ordered from
        most to least similar. All indices are returned; a caller that wants
        only the best N has to slice the result itself.
    """
    scores = cosine_similarity(reduced_pseudo_vector, reduced_matrix)
    flat_scores = scores.flatten()

    # argsort is ascending, so sort the negated scores to get a descending ranking.
    ranking = np.argsort(-flat_scores)
    return ranking
|
417 |
+
|
418 |
+
def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
    """Score every stored tag against a pseudo document made of ``pseudo_doc_terms``.

    The terms are turned into an IDF-weighted vector, projected with the
    stored SVD model, and compared by cosine similarity against each row of
    the stored reduced matrix.

    Args:
        pseudo_doc_terms: Iterable of tag strings describing the pseudo
            document; handed unchanged to ``construct_pseudo_vector``.
        allow_nsfw_tags: When falsy, tags whose underscore-joined form is in
            the ``nsfw_tags`` collection defined elsewhere in this module are
            removed from the result.

    Returns:
        An ``OrderedDict`` mapping tag -> cosine similarity, ordered from the
        highest similarity to the lowest.
    """
    # Load the saved components on first use and cache them on the function
    # object so later calls skip the disk read.
    if not hasattr(get_tfidf_reduced_similar_tags, "components"):
        get_tfidf_reduced_similar_tags.components = joblib.load('tfidfreducedfiles.joblib')

    components = get_tfidf_reduced_similar_tags.components
    idf_loaded = components['idf']
    tag_to_row_loaded = components['tag_to_row']
    reduced_matrix_loaded = components['reduced_matrix']
    svd_loaded = components['svd_model']

    # Project the pseudo document into the same reduced space as the matrix.
    pseudo_vector = construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded)
    reduced_pseudo_vector = svd_loaded.transform(pseudo_vector)

    # Compute the similarities once and reuse them for both the ranking and
    # the reported scores.
    similarities = cosine_similarity(reduced_pseudo_vector, reduced_matrix_loaded).flatten()
    ranked_rows = np.argsort(-similarities)

    # Build the position -> tag list a single time, outside the per-row loop.
    # NOTE(review): this relies on the iteration order of tag_to_row matching
    # the row order of reduced_matrix -- confirm against the code that
    # produced tfidfreducedfiles.joblib.
    tags_by_position = list(tag_to_row_loaded)

    ranked = ((tags_by_position[row], similarities[row]) for row in ranked_rows)
    if not allow_nsfw_tags:
        ranked = (
            (tag, similarity)
            for tag, similarity in ranked
            if tag.replace(' ', '_') not in nsfw_tags
        )

    # sorted() is stable, so equal scores keep the order produced by argsort.
    return OrderedDict(sorted(ranked, key=lambda item: item[1], reverse=True))
|
447 |
+
|
448 |
+
|
449 |
def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
    """Build the HTML shown in an output panel before any results exist.

    Args:
        title: Text placed in a centred ``<h1>`` heading; the heading's div
            carries the ``scrollable-content`` CSS class.
        content: Optional text; when truthy it is added as a centred paragraph
            below the heading.
        placeholder_height: Height of the empty spacer div, in pixels.
        placeholder_width: CSS width of the empty spacer div.

    Returns:
        A single HTML string: heading, optional paragraph, then a transparent
        spacer div of the requested size.
    """
    # Heading first, styled like the top artists table heading.
    sections = [
        f"<div class=\"scrollable-content\" style='text-align: center;'><h1>{title}</h1></div>"
    ]
    # The descriptive paragraph is only emitted when there is something to say.
    if content:
        sections.append(
            f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
        )
    # An empty, transparent div reserves the panel's vertical space.
    sections.append(
        f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px auto; background: transparent;'></div>"
    )
    return "".join(sections)
|
458 |
+
|
459 |
|
460 |
def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
461 |
#Initialize stuff
|
|
|
475 |
transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
|
476 |
|
477 |
# Find similar tags and prepare data for tables
|
478 |
+
html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
479 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
480 |
tags_added = False
|
481 |
bad_entities = []
|
|
|
611 |
|
612 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
613 |
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
614 |
+
|
615 |
+
#Bad tags stuff
|
616 |
bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
|
617 |
bad_entities.sort(key=lambda x: x['start'])
|
618 |
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
619 |
|
620 |
+
#Suggested tags stuff
|
621 |
+
suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
622 |
+
|
623 |
+
suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
|
624 |
+
suggested_tags = get_tfidf_reduced_similar_tags([item["artist_matrix_tag"] for item in tag_data], allow_nsfw_tags)
|
625 |
+
topnsuggestions = list(islice(suggested_tags.items(), 100))
|
626 |
+
suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
627 |
+
|
628 |
+
#Artist stuff
|
629 |
artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
|
630 |
X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
|
631 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
|
|
643 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
644 |
image_galleries.append(artists) # Extend the list with artist tuples
|
645 |
|
646 |
+
return (unseen_tags_data, bad_tags_illustrated_string, suggested_tags_html_content, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
|
647 |
except ParseError as e:
|
648 |
+
return [], "Parse Error: Check for mismatched parentheses or something", "", "", None, None
|
649 |
|
650 |
|
651 |
+
with gr.Blocks(css=css) as app:
|
652 |
with gr.Group():
|
653 |
with gr.Row():
|
654 |
with gr.Column(scale=3):
|
|
|
666 |
with gr.Row():
|
667 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
668 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
669 |
+
with gr.Row():
|
670 |
+
with gr.Column(scale=2):
|
671 |
+
unseen_tags = gr.HTML(label="Unknown Tags", value=create_html_placeholder(title="Unknown Tags"))
|
672 |
+
with gr.Column(scale=1):
|
673 |
+
suggested_tags = gr.HTML(label="Suggested Tags", value=create_html_placeholder(title="Suggested Tags"))
|
674 |
with gr.Column(scale=1):
|
675 |
with gr.Group():
|
676 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
|
|
687 |
submit_button.click(
|
688 |
find_similar_artists,
|
689 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
690 |
+
outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
|
691 |
)
|
692 |
|
693 |
gr.Markdown(faq_content)
|
tfidfreducedfiles.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a325f75a94c8a6c47034fba0e96a89039a3550463f916690b74c16d139f32504
|
3 |
+
size 68245886
|