FoodDesert commited on
Commit
e7aeeed
1 Parent(s): cb15d1f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -22
app.py CHANGED
@@ -130,16 +130,17 @@ parser = Lark(grammar, start='start')
130
 
131
  # Function to extract tags
132
  def extract_tags(tree):
133
- tags = []
134
  def _traverse(node):
135
  if isinstance(node, Token) and node.type == '__ANON_1':
136
- tags.append(node.value.strip())
 
 
137
  elif not isinstance(node, Token):
138
  for child in node.children:
139
  _traverse(child)
140
-
141
  _traverse(tree)
142
- return tags
143
 
144
 
145
  special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
@@ -341,7 +342,7 @@ def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix,
341
 
342
  def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
343
  # Wrap the tag part in a <span> with styles for bold and larger font
344
- html_str = f"<div style='display: inline-block; margin: 10px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{tag}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
345
  # Loop through the results and add table rows for each
346
  for word, sim in result:
347
  word_with_underscores = word.replace(' ', '_')
@@ -404,24 +405,35 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
404
  if not hasattr(find_similar_tags, "tag2idwiki"):
405
  find_similar_tags.tag2idwiki = build_tag_id_wiki_dict()
406
 
407
- transformed_tags = [tag.replace(' ', '_') for tag in test_tags]
 
408
 
409
  # Find similar tags and prepare data for tables
410
  html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
411
  html_content += "<h1>Unknown Tags</h1>" # Heading for the table
412
  tags_added = False
413
- for tag in test_tags:
414
- if tag in special_tags:
 
 
 
 
 
 
 
 
 
 
415
  continue
416
 
417
- modified_tag_for_search = tag.replace(' ','_')
418
  similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
419
  result, seen = [], set(transformed_tags)
420
 
421
  if modified_tag_for_search in find_similar_tags.tag2aliases:
422
- if tag in find_similar_tags.tag2aliases and "_" in tag: #Implicitly tell the user that they should get rid of the underscore
423
  result.append(modified_tag_for_search.replace('_',' '), 1)
424
- seen.add(tag)
425
  else: #The user correctly did not put underscores in their tag
426
  continue
427
  else:
@@ -444,36 +456,60 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
444
  #Adjust score based on context
445
  for i in range(len(result)):
446
  word, score = result[i] # Unpack the tuple
447
- geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag != tag], conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing)
448
  adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
449
  result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
450
  #print(word, score, geometric_mean, adjusted_score)
451
 
452
  result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
453
- html_content += create_html_tables_for_tags(tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
 
 
 
454
  tags_added=True
455
  # If no tags were processed, add a message
456
  if not tags_added:
457
  html_content = create_html_placeholder(title="Unknown Tags")
458
 
459
- return html_content # Return list of lists for Dataframe
460
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
- def find_similar_artists(new_tags_string, top_n, similarity_weight, allow_nsfw_tags):
463
  try:
464
- new_tags_string = new_tags_string.lower()
465
  new_tags_string, removed_tags = remove_special_tags(new_tags_string)
466
 
467
  # Parse the prompt
468
  parsed = parser.parse(new_tags_string)
469
  # Extract tags from the parsed tree
470
  new_image_tags = extract_tags(parsed)
471
- new_image_tags = [tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip() for tag in new_image_tags]
472
 
473
  ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
474
- unseen_tags_data = find_similar_tags(new_image_tags, similarity_weight, allow_nsfw_tags)
 
 
475
 
476
- X_new_image = vectorizer.transform([','.join(new_image_tags + removed_tags)])
 
477
  similarities = cosine_similarity(X_new_image, X_artist)[0]
478
 
479
  top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
@@ -490,7 +526,7 @@ def find_similar_artists(new_tags_string, top_n, similarity_weight, allow_nsfw_t
490
  image_galleries.append(baseline) # Add baseline as its own gallery item
491
  image_galleries.append(artists) # Extend the list with artist tuples
492
 
493
- return (unseen_tags_data, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries) #image_galleries[0], image_galleries[1] DOES work. Find a generic alternative.
494
  except ParseError as e:
495
  return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
496
 
@@ -504,6 +540,8 @@ with gr.Blocks() as app:
504
  similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
505
  num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
506
  allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
 
 
507
  with gr.Row():
508
  with gr.Column(scale=1):
509
  top_artists = gr.HTML(label="Top Artists", value=create_html_placeholder(title="Top Artists"))
@@ -521,7 +559,7 @@ with gr.Blocks() as app:
521
  submit_button.click(
522
  find_similar_artists,
523
  inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
524
- outputs=[unseen_tags, top_artists, dynamic_prompts] + galleries
525
  )
526
 
527
  gr.Markdown(faq_content)
 
130
 
131
  # Function to extract tags
132
  def extract_tags(tree):
133
+ tags_with_positions = []
134
  def _traverse(node):
135
  if isinstance(node, Token) and node.type == '__ANON_1':
136
+ tag_position = node.start_pos
137
+ tag_text = node.value.strip()
138
+ tags_with_positions.append((tag_text, tag_position))
139
  elif not isinstance(node, Token):
140
  for child in node.children:
141
  _traverse(child)
 
142
  _traverse(tree)
143
+ return tags_with_positions
144
 
145
 
146
  special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
 
342
 
343
  def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
344
  # Wrap the tag part in a <span> with styles for bold and larger font
345
+ html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{tag}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
346
  # Loop through the results and add table rows for each
347
  for word, sim in result:
348
  word_with_underscores = word.replace(' ', '_')
 
405
  if not hasattr(find_similar_tags, "tag2idwiki"):
406
  find_similar_tags.tag2idwiki = build_tag_id_wiki_dict()
407
 
408
+ modified_tags = [tag_info['modified_tag'] for tag_info in test_tags]
409
+ transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
410
 
411
  # Find similar tags and prepare data for tables
412
  html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
413
  html_content += "<h1>Unknown Tags</h1>" # Heading for the table
414
  tags_added = False
415
+ bad_entities = []
416
+ for tag_info in test_tags:
417
+ original_tag = tag_info['original_tag']
418
+ modified_tag = tag_info['modified_tag']
419
+ start_pos = tag_info['start_pos']
420
+ end_pos = tag_info['end_pos']
421
+
422
+
423
+ print(original_tag, modified_tag, start_pos, end_pos)
424
+
425
+
426
+ if modified_tag in special_tags:
427
  continue
428
 
429
+ modified_tag_for_search = modified_tag.replace(' ','_')
430
  similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
431
  result, seen = [], set(transformed_tags)
432
 
433
  if modified_tag_for_search in find_similar_tags.tag2aliases:
434
+ if modified_tag in find_similar_tags.tag2aliases and "_" in modified_tag: #Implicitly tell the user that they should get rid of the underscore
435
  result.append(modified_tag_for_search.replace('_',' '), 1)
436
+ seen.add(modified_tag)
437
  else: #The user correctly did not put underscores in their tag
438
  continue
439
  else:
 
456
  #Adjust score based on context
457
  for i in range(len(result)):
458
  word, score = result[i] # Unpack the tuple
459
+ geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag != modified_tag], conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing)
460
  adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
461
  result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
462
  #print(word, score, geometric_mean, adjusted_score)
463
 
464
  result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
465
+ html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
466
+
467
+ bad_entities.append({"entity":"UNKNOWN", "start":start_pos, "end":end_pos})
468
+
469
  tags_added=True
470
  # If no tags were processed, add a message
471
  if not tags_added:
472
  html_content = create_html_placeholder(title="Unknown Tags")
473
 
474
+ return html_content, bad_entities # Return list of lists for Dataframe
475
+
476
+
477
+ def build_tag_offsets_dicts(new_image_tags_with_positions):
478
+ # Structure the data for HighlightedText
479
+ tag_data = []
480
+ for tag_text, start_pos in new_image_tags_with_positions:
481
+ # Modify the tag
482
+ modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
483
+ # Calculate the end position based on the original tag length
484
+ end_pos = start_pos + len(tag_text)
485
+ # Append the structured data for each tag
486
+ tag_data.append({
487
+ "original_tag": tag_text,
488
+ "start_pos": start_pos,
489
+ "end_pos": end_pos,
490
+ "modified_tag": modified_tag
491
+ })
492
+ return tag_data
493
+
494
 
495
+ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
496
  try:
497
+ new_tags_string = original_tags_string.lower()
498
  new_tags_string, removed_tags = remove_special_tags(new_tags_string)
499
 
500
  # Parse the prompt
501
  parsed = parser.parse(new_tags_string)
502
  # Extract tags from the parsed tree
503
  new_image_tags = extract_tags(parsed)
504
+ tag_data = build_tag_offsets_dicts(new_image_tags)
505
 
506
  ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
507
+ unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
508
+
509
+ bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
510
 
511
+ modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
512
+ X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
513
  similarities = cosine_similarity(X_new_image, X_artist)[0]
514
 
515
  top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
 
526
  image_galleries.append(baseline) # Add baseline as its own gallery item
527
  image_galleries.append(artists) # Extend the list with artist tuples
528
 
529
+ return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries) #image_galleries[0], image_galleries[1] DOES work. Find a generic alternative.
530
  except ParseError as e:
531
  return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
532
 
 
540
  similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
541
  num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
542
  allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
543
+ with gr.Row():
544
+ bad_tags_illustrated_string = gr.HighlightedText()
545
  with gr.Row():
546
  with gr.Column(scale=1):
547
  top_artists = gr.HTML(label="Top Artists", value=create_html_placeholder(title="Top Artists"))
 
559
  submit_button.click(
560
  find_similar_artists,
561
  inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
562
+ outputs=[unseen_tags, bad_tags_illustrated_string, top_artists, dynamic_prompts] + galleries
563
  )
564
 
565
  gr.Markdown(faq_content)