FoodDesert commited on
Commit
f1da0db
1 Parent(s): 1c36512

Upload 3 files

Browse files

adding comma checking

Files changed (2) hide show
  1. SquirrelIcon.png +0 -0
  2. app.py +68 -16
SquirrelIcon.png ADDED
app.py CHANGED
@@ -10,8 +10,7 @@ import re
10
  import random
11
  import compress_fasttext
12
  from collections import OrderedDict
13
- from lark import Lark
14
- from lark import Token
15
  from lark.exceptions import ParseError
16
  import json
17
  import zipfile
@@ -115,6 +114,19 @@ See SamplePrompts.csv for the list of prompts used and their descriptions.
115
 
116
  nsfw_threshold = 0.95 # Assuming the threshold value is defined here
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  grammar=r"""
119
  !start: (prompt | /[][():]/+)*
120
  prompt: (emphasized | plain | comma | WHITESPACE)*
@@ -125,6 +137,7 @@ WHITESPACE: /\s+/
125
  plain: /([^,\\\[\]():|]|\\.)+/
126
  %import common.SIGNED_NUMBER -> NUMBER
127
  """
 
128
  # Initialize the parser
129
  parser = Lark(grammar, start='start')
130
 
@@ -134,15 +147,14 @@ def extract_tags(tree):
134
  def _traverse(node):
135
  if isinstance(node, Token) and node.type == '__ANON_1':
136
  tag_position = node.start_pos
137
- #tag_text = node.value.strip()
138
  tag_text = node.value
139
- tags_with_positions.append((tag_text, tag_position))
140
  elif not isinstance(node, Token):
141
  for child in node.children:
142
  _traverse(child)
143
  _traverse(tree)
144
  return tags_with_positions
145
-
146
 
147
  special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
148
  def remove_special_tags(original_string):
@@ -384,11 +396,14 @@ def create_top_artists_table(top_artists):
384
  return html_str
385
 
386
 
387
- def create_html_placeholder(title="", placeholder_height=400, placeholder_width="100%"):
388
  # Include a title in the same style as the top artists table heading
389
  html_placeholder = f"<div style='text-align: center;'><h1>{title}</h1></div>"
 
 
 
390
  # Add the placeholder div with specified height and width
391
- html_placeholder += f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px; background: transparent;'></div>"
392
  return html_placeholder
393
 
394
 
@@ -420,6 +435,7 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
420
  modified_tag = tag_info['modified_tag']
421
  start_pos = tag_info['start_pos']
422
  end_pos = tag_info['end_pos']
 
423
 
424
  #print(original_tag, modified_tag, start_pos, end_pos)
425
 
@@ -432,6 +448,9 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
432
  continue
433
  encountered_modified_tags.add(modified_tag)
434
 
 
 
 
435
 
436
  modified_tag_for_search = modified_tag.replace(' ','_')
437
  similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
@@ -471,12 +490,12 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
471
  result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
472
  html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
473
 
474
- bad_entities.append({"entity":"Unknown", "start":start_pos, "end":end_pos})
475
 
476
  tags_added=True
477
  # If no tags were processed, add a message
478
  if not tags_added:
479
- html_content = create_html_placeholder(title="Unknown Tags")
480
 
481
  return html_content, bad_entities # Return list of lists for Dataframe
482
 
@@ -484,7 +503,7 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
484
  def build_tag_offsets_dicts(new_image_tags_with_positions):
485
  # Structure the data for HighlightedText
486
  tag_data = []
487
- for tag_text, start_pos in new_image_tags_with_positions:
488
  # Modify the tag
489
  modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
490
  artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
@@ -496,10 +515,37 @@ def build_tag_offsets_dicts(new_image_tags_with_positions):
496
  "start_pos": start_pos,
497
  "end_pos": end_pos,
498
  "modified_tag": modified_tag,
499
- "artist_matrix_tag": artist_matrix_tag
 
500
  })
501
  return tag_data
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
  def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
505
  try:
@@ -508,26 +554,29 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
508
 
509
  # Parse the prompt
510
  parsed = parser.parse(new_tags_string)
 
511
  # Extract tags from the parsed tree
512
  new_image_tags = extract_tags(parsed)
 
513
  tag_data = build_tag_offsets_dicts(new_image_tags)
514
 
515
  ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
516
  unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
517
 
 
 
518
  bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
519
- #bad_tags_illustrated_string = {"text":original_tags_string, "entities":bad_entities}
520
 
521
  #modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
522
  #X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
523
- artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data]
 
524
  X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
525
  similarities = cosine_similarity(X_new_image, X_artist)[0]
526
 
527
  top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
528
  top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices if artist_names[i].lower() != "by conditional dnp"][:top_n]
529
 
530
- #top_artists_str = "\n".join([f"{rank+1}. {artist[3:]} ({score:.4f})" for rank, (artist, score) in enumerate(top_artists)])
531
  top_artists_str = create_top_artists_table(top_artists)
532
  dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
533
 
@@ -538,7 +587,7 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
538
  image_galleries.append(baseline) # Add baseline as its own gallery item
539
  image_galleries.append(artists) # Extend the list with artist tuples
540
 
541
- return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries) #image_galleries[0], image_galleries[1] DOES work. Find a generic alternative.
542
  except ParseError as e:
543
  return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
544
 
@@ -548,8 +597,11 @@ with gr.Blocks() as app:
548
  with gr.Row():
549
  with gr.Column(scale=3):
550
  image_tags = gr.Textbox(label="Enter Prompt", placeholder="e.g. fox, outside, detailed background, ...")
551
- bad_tags_illustrated_string = gr.HighlightedText(show_legend=True,label="Annotated Prompt")
552
  with gr.Column(scale=1):
 
 
 
553
  gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
554
  submit_button = gr.Button("Submit")
555
  with gr.Row():
 
10
  import random
11
  import compress_fasttext
12
  from collections import OrderedDict
13
+ from lark import Lark, Tree, Token
 
14
  from lark.exceptions import ParseError
15
  import json
16
  import zipfile
 
114
 
115
  nsfw_threshold = 0.95 # Assuming the threshold value is defined here
116
 
117
+ #grammar=r"""
118
+ #!start: (prompt | /[][():]/+)*
119
+ #prompt: (emphasized | plain | commas | WHITESPACE)*
120
+ #!emphasized: "(" prompt ")"
121
+ # | "(" prompt ":" [WHITESPACE] NUMBER [WHITESPACE] ")"
122
+ #!comma: ","
123
+ #commas: double_comma | comma
124
+ #double_comma: comma WHITESPACE* comma
125
+ #WHITESPACE: /\s+/
126
+ #plain: /([^,\\\[\]():|]|\\.)+/
127
+ #%import common.SIGNED_NUMBER -> NUMBER
128
+ #"""
129
+
130
  grammar=r"""
131
  !start: (prompt | /[][():]/+)*
132
  prompt: (emphasized | plain | comma | WHITESPACE)*
 
137
  plain: /([^,\\\[\]():|]|\\.)+/
138
  %import common.SIGNED_NUMBER -> NUMBER
139
  """
140
+
141
  # Initialize the parser
142
  parser = Lark(grammar, start='start')
143
 
 
147
  def _traverse(node):
148
  if isinstance(node, Token) and node.type == '__ANON_1':
149
  tag_position = node.start_pos
 
150
  tag_text = node.value
151
+ tags_with_positions.append((tag_text, tag_position, "tag"))
152
  elif not isinstance(node, Token):
153
  for child in node.children:
154
  _traverse(child)
155
  _traverse(tree)
156
  return tags_with_positions
157
+
158
 
159
  special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
160
  def remove_special_tags(original_string):
 
396
  return html_str
397
 
398
 
399
+ def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
400
  # Include a title in the same style as the top artists table heading
401
  html_placeholder = f"<div style='text-align: center;'><h1>{title}</h1></div>"
402
+ # Conditionally add content if present
403
+ if content:
404
+ html_placeholder += f"<div style='text-align: center; margin-bottom: 20px;'><p>{content}</p></div>"
405
  # Add the placeholder div with specified height and width
406
+ html_placeholder += f"<div style='height: {placeholder_height}px; width: {placeholder_width}; margin: 20px auto; background: transparent;'></div>"
407
  return html_placeholder
408
 
409
 
 
435
  modified_tag = tag_info['modified_tag']
436
  start_pos = tag_info['start_pos']
437
  end_pos = tag_info['end_pos']
438
+ node_type = tag_info['node_type']
439
 
440
  #print(original_tag, modified_tag, start_pos, end_pos)
441
 
 
448
  continue
449
  encountered_modified_tags.add(modified_tag)
450
 
451
+ if node_type == "double_comma":
452
+ bad_entities.append({"entity":"Double Comma", "start":start_pos, "end":end_pos})
453
+ continue
454
 
455
  modified_tag_for_search = modified_tag.replace(' ','_')
456
  similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
 
490
  result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
491
  html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
492
 
493
+ bad_entities.append({"entity":"Unknown Tag", "start":start_pos, "end":end_pos})
494
 
495
  tags_added=True
496
  # If no tags were processed, add a message
497
  if not tags_added:
498
+ html_content = create_html_placeholder(title="Unknown Tags", content="No Unknown Tags Found")
499
 
500
  return html_content, bad_entities # Return list of lists for Dataframe
501
 
 
503
  def build_tag_offsets_dicts(new_image_tags_with_positions):
504
  # Structure the data for HighlightedText
505
  tag_data = []
506
+ for tag_text, start_pos, nodetype in new_image_tags_with_positions:
507
  # Modify the tag
508
  modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
509
  artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
 
515
  "start_pos": start_pos,
516
  "end_pos": end_pos,
517
  "modified_tag": modified_tag,
518
+ "artist_matrix_tag": artist_matrix_tag,
519
+ "node_type": nodetype
520
  })
521
  return tag_data
522
 
523
+
524
+ def augment_bad_entities_with_regex(text):
525
+ bad_entities = []
526
+
527
+ #comma at end
528
+ match = re.search(r',(?=\s*$)', text)
529
+ if match:
530
+ index = match.start()
531
+ bad_entities.append({"entity":"Remove Final Comma", "start":index, "end":index+1})
532
+ match = re.search(r'\([^()]*(,)\s*\)\s*$', text)
533
+ if match:
534
+ index = match.start(1)
535
+ bad_entities.append({"entity":"Remove Final Comma", "start":index, "end":index+1})
536
+ match = re.search(r'\([^()]*(,)\s*:\s*\d+(\.\d+)?\s*\)\s*$', text)
537
+ if match:
538
+ index = match.start(1)
539
+ bad_entities.append({"entity":"Remove Final Comma", "start":index, "end":index+1})
540
+
541
+ #comma after parentheses
542
+ match = re.search(r'\)\s*(,)\s*[^\s]',text)
543
+ if match:
544
+ index = match.start(1)
545
+ bad_entities.append({"entity":"Move Comma Inside Parentheses", "start":index, "end":index+1})
546
+
547
+ return bad_entities
548
+
549
 
550
  def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
551
  try:
 
554
 
555
  # Parse the prompt
556
  parsed = parser.parse(new_tags_string)
557
+
558
  # Extract tags from the parsed tree
559
  new_image_tags = extract_tags(parsed)
560
+
561
  tag_data = build_tag_offsets_dicts(new_image_tags)
562
 
563
  ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
564
  unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
565
 
566
+ bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
567
+ bad_entities.sort(key=lambda x: x['start'])
568
  bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
 
569
 
570
  #modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
571
  #X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
572
+ #artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data]
573
+ artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
574
  X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
575
  similarities = cosine_similarity(X_new_image, X_artist)[0]
576
 
577
  top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
578
  top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices if artist_names[i].lower() != "by conditional dnp"][:top_n]
579
 
 
580
  top_artists_str = create_top_artists_table(top_artists)
581
  dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
582
 
 
587
  image_galleries.append(baseline) # Add baseline as its own gallery item
588
  image_galleries.append(artists) # Extend the list with artist tuples
589
 
590
+ return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
591
  except ParseError as e:
592
  return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
593
 
 
597
  with gr.Row():
598
  with gr.Column(scale=3):
599
  image_tags = gr.Textbox(label="Enter Prompt", placeholder="e.g. fox, outside, detailed background, ...")
600
+ bad_tags_illustrated_string = gr.HighlightedText(show_legend=True, color_map={"Unknown Tag":"red","Duplicate":"yellow","Remove Final Comma":"purple","Move Comma Inside Parentheses":"green"}, label="Annotated Prompt")
601
  with gr.Column(scale=1):
602
+ #gr.Image(label=" ", value="SquirrelIcon.png", height=155, width=140)
603
+ #image_path = os.path.join(os.getcwd(), "SquirrelIcon.png")
604
+ #gr.HTML('<div style="text-align: center;"><img src="{image_path}" alt="Cute Mascot" style="max-height: 100px; background: transparent;"></div><br>')
605
  gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
606
  submit_button = gr.Button("Submit")
607
  with gr.Row():