Spaces:
Running
Running
FoodDesert
committed on
Commit
•
f1da0db
1
Parent(s):
1c36512
Upload 3 files
Browse files
adding comma checking
- SquirrelIcon.png +0 -0
- app.py +68 -16
SquirrelIcon.png
ADDED
app.py
CHANGED
@@ -10,8 +10,7 @@ import re
|
|
10 |
import random
|
11 |
import compress_fasttext
|
12 |
from collections import OrderedDict
|
13 |
-
from lark import Lark
|
14 |
-
from lark import Token
|
15 |
from lark.exceptions import ParseError
|
16 |
import json
|
17 |
import zipfile
|
@@ -115,6 +114,19 @@ See SamplePrompts.csv for the list of prompts used and their descriptions.
|
|
115 |
|
116 |
nsfw_threshold = 0.95 # Assuming the threshold value is defined here
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
grammar=r"""
|
119 |
!start: (prompt | /[][():]/+)*
|
120 |
prompt: (emphasized | plain | comma | WHITESPACE)*
|
@@ -125,6 +137,7 @@ WHITESPACE: /\s+/
|
|
125 |
plain: /([^,\\\[\]():|]|\\.)+/
|
126 |
%import common.SIGNED_NUMBER -> NUMBER
|
127 |
"""
|
|
|
128 |
# Initialize the parser
|
129 |
parser = Lark(grammar, start='start')
|
130 |
|
@@ -134,15 +147,14 @@ def extract_tags(tree):
|
|
134 |
def _traverse(node):
    """Walk the parse tree depth-first and record every '__ANON_1' token.

    Each hit is appended to ``tags_with_positions`` (a list owned by the
    enclosing scope) as a ``(token text, start offset)`` pair. Tokens of
    any other type are ignored; every non-Token node is descended into
    through its ``children``.
    """
    if not isinstance(node, Token):
        # Interior node: recurse into each child in order.
        for child in node.children:
            _traverse(child)
    elif node.type == '__ANON_1':
        # NOTE(review): '__ANON_1' is a parser-generated terminal name,
        # presumably the regex of the `plain` rule -- confirm, and note it
        # may shift if anonymous terminals are added to the grammar.
        tags_with_positions.append((node.value, node.start_pos))
|
143 |
_traverse(tree)
|
144 |
return tags_with_positions
|
145 |
-
|
146 |
|
147 |
# The ten literal tags "score:0" through "score:9".
# NOTE(review): presumably the tags that remove_special_tags strips -- confirm.
special_tags = [f"score:{digit}" for digit in range(10)]
|
148 |
def remove_special_tags(original_string):
|
@@ -384,11 +396,14 @@ def create_top_artists_table(top_artists):
|
|
384 |
return html_str
|
385 |
|
386 |
|
387 |
-
def create_html_placeholder(title="", placeholder_height=400, placeholder_width="100%"):
    """Return HTML for a centred heading followed by an empty spacer block.

    Args:
        title: Text placed inside the ``<h1>`` heading (inserted as-is).
        placeholder_height: Spacer height; ``px`` is appended to the value.
        placeholder_width: Spacer width, used verbatim as a CSS value.

    Returns:
        The heading ``<div>`` and the transparent spacer ``<div>`` as a
        single HTML string.
    """
    heading = f"<div style='text-align: center;'><h1>{title}</h1></div>"
    spacer = (
        f"<div style='height: {placeholder_height}px; width: {placeholder_width}; "
        "margin: 20px; background: transparent;'></div>"
    )
    return heading + spacer
|
393 |
|
394 |
|
@@ -420,6 +435,7 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
420 |
modified_tag = tag_info['modified_tag']
|
421 |
start_pos = tag_info['start_pos']
|
422 |
end_pos = tag_info['end_pos']
|
|
|
423 |
|
424 |
#print(original_tag, modified_tag, start_pos, end_pos)
|
425 |
|
@@ -432,6 +448,9 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
432 |
continue
|
433 |
encountered_modified_tags.add(modified_tag)
|
434 |
|
|
|
|
|
|
|
435 |
|
436 |
modified_tag_for_search = modified_tag.replace(' ','_')
|
437 |
similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
|
@@ -471,12 +490,12 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
471 |
result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
|
472 |
html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
473 |
|
474 |
-
bad_entities.append({"entity":"Unknown", "start":start_pos, "end":end_pos})
|
475 |
|
476 |
tags_added=True
|
477 |
# If no tags were processed, add a message
|
478 |
if not tags_added:
|
479 |
-
html_content = create_html_placeholder(title="Unknown Tags")
|
480 |
|
481 |
return html_content, bad_entities # Return list of lists for Dataframe
|
482 |
|
@@ -484,7 +503,7 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
484 |
def build_tag_offsets_dicts(new_image_tags_with_positions):
|
485 |
# Structure the data for HighlightedText
|
486 |
tag_data = []
|
487 |
-
for tag_text, start_pos in new_image_tags_with_positions:
|
488 |
# Modify the tag
|
489 |
modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
|
490 |
artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
|
@@ -496,10 +515,37 @@ def build_tag_offsets_dicts(new_image_tags_with_positions):
|
|
496 |
"start_pos": start_pos,
|
497 |
"end_pos": end_pos,
|
498 |
"modified_tag": modified_tag,
|
499 |
-
"artist_matrix_tag": artist_matrix_tag
|
|
|
500 |
})
|
501 |
return tag_data
|
502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
|
504 |
def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
|
505 |
try:
|
@@ -508,26 +554,29 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
|
|
508 |
|
509 |
# Parse the prompt
|
510 |
parsed = parser.parse(new_tags_string)
|
|
|
511 |
# Extract tags from the parsed tree
|
512 |
new_image_tags = extract_tags(parsed)
|
|
|
513 |
tag_data = build_tag_offsets_dicts(new_image_tags)
|
514 |
|
515 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
516 |
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
517 |
|
|
|
|
|
518 |
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
519 |
-
#bad_tags_illustrated_string = {"text":original_tags_string, "entities":bad_entities}
|
520 |
|
521 |
#modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
|
522 |
#X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
|
523 |
-
artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data]
|
|
|
524 |
X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
|
525 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
526 |
|
527 |
top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
|
528 |
top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices if artist_names[i].lower() != "by conditional dnp"][:top_n]
|
529 |
|
530 |
-
#top_artists_str = "\n".join([f"{rank+1}. {artist[3:]} ({score:.4f})" for rank, (artist, score) in enumerate(top_artists)])
|
531 |
top_artists_str = create_top_artists_table(top_artists)
|
532 |
dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
|
533 |
|
@@ -538,7 +587,7 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
|
|
538 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
539 |
image_galleries.append(artists) # Extend the list with artist tuples
|
540 |
|
541 |
-
return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
|
542 |
except ParseError as e:
|
543 |
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
544 |
|
@@ -548,8 +597,11 @@ with gr.Blocks() as app:
|
|
548 |
with gr.Row():
|
549 |
with gr.Column(scale=3):
|
550 |
image_tags = gr.Textbox(label="Enter Prompt", placeholder="e.g. fox, outside, detailed background, ...")
|
551 |
-
bad_tags_illustrated_string = gr.HighlightedText(show_legend=True,label="Annotated Prompt")
|
552 |
with gr.Column(scale=1):
|
|
|
|
|
|
|
553 |
gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
|
554 |
submit_button = gr.Button("Submit")
|
555 |
with gr.Row():
|
|
|
10 |
import random
|
11 |
import compress_fasttext
|
12 |
from collections import OrderedDict
|
13 |
+
from lark import Lark, Tree, Token
|
|
|
14 |
from lark.exceptions import ParseError
|
15 |
import json
|
16 |
import zipfile
|
|
|
114 |
|
115 |
nsfw_threshold = 0.95 # Assuming the threshold value is defined here
|
116 |
|
117 |
+
#grammar=r"""
|
118 |
+
#!start: (prompt | /[][():]/+)*
|
119 |
+
#prompt: (emphasized | plain | commas | WHITESPACE)*
|
120 |
+
#!emphasized: "(" prompt ")"
|
121 |
+
# | "(" prompt ":" [WHITESPACE] NUMBER [WHITESPACE] ")"
|
122 |
+
#!comma: ","
|
123 |
+
#commas: double_comma | comma
|
124 |
+
#double_comma: comma WHITESPACE* comma
|
125 |
+
#WHITESPACE: /\s+/
|
126 |
+
#plain: /([^,\\\[\]():|]|\\.)+/
|
127 |
+
#%import common.SIGNED_NUMBER -> NUMBER
|
128 |
+
#"""
|
129 |
+
|
130 |
grammar=r"""
|
131 |
!start: (prompt | /[][():]/+)*
|
132 |
prompt: (emphasized | plain | comma | WHITESPACE)*
|
|
|
137 |
plain: /([^,\\\[\]():|]|\\.)+/
|
138 |
%import common.SIGNED_NUMBER -> NUMBER
|
139 |
"""
|
140 |
+
|
141 |
# Initialize the parser
|
142 |
parser = Lark(grammar, start='start')
|
143 |
|
|
|
147 |
def _traverse(node):
    """Walk the parse tree depth-first and record every '__ANON_1' token.

    Each hit is appended to ``tags_with_positions`` (a list owned by the
    enclosing scope) as a ``(token text, start offset, "tag")`` triple.
    Tokens of any other type are ignored; every non-Token node is
    descended into through its ``children``.
    """
    if not isinstance(node, Token):
        # Interior node: recurse into each child in order.
        for child in node.children:
            _traverse(child)
    elif node.type == '__ANON_1':
        # NOTE(review): '__ANON_1' is a parser-generated terminal name,
        # presumably the regex of the `plain` rule -- confirm, and note it
        # may shift if anonymous terminals are added to the grammar.
        tags_with_positions.append((node.value, node.start_pos, "tag"))
|
155 |
_traverse(tree)
|
156 |
return tags_with_positions
|
157 |
+
|
158 |
|
159 |
# The ten literal tags "score:0" through "score:9".
# NOTE(review): presumably the tags that remove_special_tags strips -- confirm.
special_tags = [f"score:{digit}" for digit in range(10)]
|
160 |
def remove_special_tags(original_string):
|
|
|
396 |
return html_str
|
397 |
|
398 |
|
399 |
+
def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
    """Return HTML for a centred heading, an optional message, and a spacer.

    Args:
        title: Text placed inside the ``<h1>`` heading (inserted as-is).
        content: Optional message; when truthy it is rendered as a centred
            paragraph between the heading and the spacer.
        placeholder_height: Spacer height; ``px`` is appended to the value.
        placeholder_width: Spacer width, used verbatim as a CSS value.

    Returns:
        The concatenated HTML fragments as a single string.
    """
    sections = [f"<div style='text-align: center;'><h1>{title}</h1></div>"]
    if content:
        sections.append(
            f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
        )
    # Empty transparent block that reserves vertical space on the page.
    sections.append(
        f"<div style='height: {placeholder_height}px; width: {placeholder_width}; "
        "margin: 20px auto; background: transparent;'></div>"
    )
    return "".join(sections)
|
408 |
|
409 |
|
|
|
435 |
modified_tag = tag_info['modified_tag']
|
436 |
start_pos = tag_info['start_pos']
|
437 |
end_pos = tag_info['end_pos']
|
438 |
+
node_type = tag_info['node_type']
|
439 |
|
440 |
#print(original_tag, modified_tag, start_pos, end_pos)
|
441 |
|
|
|
448 |
continue
|
449 |
encountered_modified_tags.add(modified_tag)
|
450 |
|
451 |
+
if node_type == "double_comma":
|
452 |
+
bad_entities.append({"entity":"Double Comma", "start":start_pos, "end":end_pos})
|
453 |
+
continue
|
454 |
|
455 |
modified_tag_for_search = modified_tag.replace(' ','_')
|
456 |
similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
|
|
|
490 |
result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
|
491 |
html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
492 |
|
493 |
+
bad_entities.append({"entity":"Unknown Tag", "start":start_pos, "end":end_pos})
|
494 |
|
495 |
tags_added=True
|
496 |
# If no tags were processed, add a message
|
497 |
if not tags_added:
|
498 |
+
html_content = create_html_placeholder(title="Unknown Tags", content="No Unknown Tags Found")
|
499 |
|
500 |
return html_content, bad_entities # Return list of lists for Dataframe
|
501 |
|
|
|
503 |
def build_tag_offsets_dicts(new_image_tags_with_positions):
|
504 |
# Structure the data for HighlightedText
|
505 |
tag_data = []
|
506 |
+
for tag_text, start_pos, nodetype in new_image_tags_with_positions:
|
507 |
# Modify the tag
|
508 |
modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
|
509 |
artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
|
|
|
515 |
"start_pos": start_pos,
|
516 |
"end_pos": end_pos,
|
517 |
"modified_tag": modified_tag,
|
518 |
+
"artist_matrix_tag": artist_matrix_tag,
|
519 |
+
"node_type": nodetype
|
520 |
})
|
521 |
return tag_data
|
522 |
|
523 |
+
|
524 |
+
# Emphasis weight as the prompt grammar accepts it (common.SIGNED_NUMBER):
# optional sign, digits with optional fraction or a bare ".5", optional exponent.
_WEIGHT_NUMBER = r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?'

# Each pattern pins down at most one trailing comma; the int is the regex
# group whose start offset is the comma itself (0 = the whole match).
_FINAL_COMMA_PATTERNS = (
    # "fox, outside,"   -- comma is the last non-blank character
    (re.compile(r',(?=\s*$)'), 0),
    # "(fox, outside,)" -- comma right before the closing paren at the end
    (re.compile(r'\([^()]*(,)\s*\)\s*$'), 1),
    # "(fox,:1.2)"      -- comma right before the ":weight)" at the end
    (re.compile(r'\([^()]*(,)\s*:\s*' + _WEIGHT_NUMBER + r'\s*\)\s*$'), 1),
)

# An unescaped ")" followed by a comma that still has prompt text after it.
# The lookbehind skips "\)" so escaped parentheses inside a tag name are not
# mistaken for the end of an emphasis group; the lookahead leaves the
# following character unconsumed so consecutive groups are all found.
_COMMA_AFTER_PAREN = re.compile(r'(?<!\\)\)\s*(,)(?=\s*\S)')


def augment_bad_entities_with_regex(text):
    """Find comma-placement problems in a raw prompt string.

    Two kinds of problems are reported:

    * ``"Remove Final Comma"`` -- a comma that ends the prompt, either
      directly or just inside a closing ``)`` / ``:weight)`` at the end.
    * ``"Move Comma Inside Parentheses"`` -- every comma that directly
      follows an unescaped ``)`` and is itself followed by more text.

    Args:
        text: The prompt string to inspect.

    Returns:
        A list of ``{"entity": label, "start": i, "end": i + 1}`` dicts, one
        per offending comma, where ``i`` is the comma's offset in ``text``.
        Final-comma entries come first, then parenthesis entries in text
        order. The list is empty when nothing is wrong.
    """
    bad_entities = []

    def flag(label, index):
        # Every problem is a single comma character, hence the width of 1.
        bad_entities.append({"entity": label, "start": index, "end": index + 1})

    for pattern, group in _FINAL_COMMA_PATTERNS:
        match = pattern.search(text)
        if match:
            flag("Remove Final Comma", match.start(group))

    # finditer rather than search: a prompt can contain several groups, and
    # each misplaced comma should be reported, not only the first.
    for match in _COMMA_AFTER_PAREN.finditer(text):
        flag("Move Comma Inside Parentheses", match.start(1))

    return bad_entities
|
548 |
+
|
549 |
|
550 |
def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
|
551 |
try:
|
|
|
554 |
|
555 |
# Parse the prompt
|
556 |
parsed = parser.parse(new_tags_string)
|
557 |
+
|
558 |
# Extract tags from the parsed tree
|
559 |
new_image_tags = extract_tags(parsed)
|
560 |
+
|
561 |
tag_data = build_tag_offsets_dicts(new_image_tags)
|
562 |
|
563 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
564 |
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
565 |
|
566 |
+
bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
|
567 |
+
bad_entities.sort(key=lambda x: x['start'])
|
568 |
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
|
|
569 |
|
570 |
#modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
|
571 |
#X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
|
572 |
+
#artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data]
|
573 |
+
artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
|
574 |
X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
|
575 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
576 |
|
577 |
top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
|
578 |
top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices if artist_names[i].lower() != "by conditional dnp"][:top_n]
|
579 |
|
|
|
580 |
top_artists_str = create_top_artists_table(top_artists)
|
581 |
dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
|
582 |
|
|
|
587 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
588 |
image_galleries.append(artists) # Extend the list with artist tuples
|
589 |
|
590 |
+
return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
|
591 |
except ParseError as e:
|
592 |
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
593 |
|
|
|
597 |
with gr.Row():
|
598 |
with gr.Column(scale=3):
|
599 |
image_tags = gr.Textbox(label="Enter Prompt", placeholder="e.g. fox, outside, detailed background, ...")
|
600 |
+
bad_tags_illustrated_string = gr.HighlightedText(show_legend=True, color_map={"Unknown Tag":"red","Duplicate":"yellow","Remove Final Comma":"purple","Move Comma Inside Parentheses":"green"}, label="Annotated Prompt")
|
601 |
with gr.Column(scale=1):
|
602 |
+
#gr.Image(label=" ", value="SquirrelIcon.png", height=155, width=140)
|
603 |
+
#image_path = os.path.join(os.getcwd(), "SquirrelIcon.png")
|
604 |
+
#gr.HTML('<div style="text-align: center;"><img src="{image_path}" alt="Cute Mascot" style="max-height: 100px; background: transparent;"></div><br>')
|
605 |
gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
|
606 |
submit_button = gr.Button("Submit")
|
607 |
with gr.Row():
|