FoodDesert committed
Commit c442eda
1 Parent(s): 273251e

Upload 2 files
Files changed (2)
  1. ConvertSampleImagesToJpeg.ipynb +147 -0
  2. app.py +16 -100
ConvertSampleImagesToJpeg.ipynb ADDED
@@ -0,0 +1,147 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "4aa04654",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "098e115f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "import os\n",
+    "import json\n",
+    "from PIL import Image\n",
+    "from sd_parsers import ParserManager\n",
+    "\n",
+    "# Directory with PNG images\n",
+    "image_directory = 'E:/image/holder/Tagset_Completer/sampleimages/02landscape'\n",
+    "\n",
+    "# Initialize the ParserManager\n",
+    "parser_manager = ParserManager()\n",
+    "\n",
+    "# Dictionary for artist names to corresponding JPG file names\n",
+    "artist_to_file_map = {}\n",
+    "\n",
+    "# Iterate through PNG files in the directory\n",
+    "for png_file in glob.glob(os.path.join(image_directory, '*.png')):\n",
+    "    with Image.open(png_file) as img:\n",
+    "        # Extract metadata using ParserManager\n",
+    "        prompt_info = parser_manager.parse(img)\n",
+    "        if prompt_info and prompt_info.prompts:\n",
+    "            first_prompt_text = list(prompt_info.prompts)[0].value.split(',')[0].strip()\n",
+    "            if first_prompt_text.startswith(\"by \"):\n",
+    "                first_prompt_text = first_prompt_text[3:] # Remove \"by \" prefix\n",
+    "            artist_to_file_map[first_prompt_text] = os.path.basename(png_file).replace('.png', '.jpg')\n",
+    "        else:\n",
+    "            artist_to_file_map[\"\"] = os.path.basename(png_file).replace('.png', '.jpg')\n",
+    "\n",
+    "# Save the mapping to a JSON file in the same directory\n",
+    "json_path = os.path.join(image_directory, 'artist_to_file_map.json')\n",
+    "with open(json_path, 'w') as json_file:\n",
+    "    json.dump(artist_to_file_map, json_file, indent=4)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "ac5cba7f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Iterate through PNG files in the directory\n",
+    "for png_file in glob.glob(os.path.join(image_directory, '*.png')):\n",
+    "    # Open the image\n",
+    "    with Image.open(png_file) as img:\n",
+    "        # Convert the image to RGB mode in case it's RGBA or P mode\n",
+    "        img = img.convert('RGB')\n",
+    "        # Define the output filename replacing .png with .jpg\n",
+    "        jpg_file = png_file.rsplit('.', 1)[0] + '.jpg'\n",
+    "        # Save the image in JPG format\n",
+    "        img.save(jpg_file, 'JPEG')\n",
+    "    # Optionally, remove the original PNG file\n",
+    "    os.remove(png_file)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "32bfb9cc",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3648a9fc",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "09f74cbd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "d2e18c17",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "354fda37",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ac4e5911",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
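
Note: the two code cells above form a small pipeline, and their order matters: the artist-name pass reads prompt metadata out of the PNGs, and the conversion pass then re-encodes them as JPEG, which (via PIL) does not carry over the PNG text chunks that sd_parsers reads. A condensed sketch of the combined flow, using the same calls as the notebook (the directory argument is a placeholder):

    import glob
    import json
    import os
    from PIL import Image
    from sd_parsers import ParserManager

    def convert_samples(image_directory):
        """Map artist -> jpg filename from PNG prompt metadata, then re-encode to JPEG."""
        parser_manager = ParserManager()
        artist_to_file_map = {}
        for png_file in glob.glob(os.path.join(image_directory, '*.png')):
            jpg_name = os.path.basename(png_file).replace('.png', '.jpg')
            with Image.open(png_file) as img:
                # Read the prompt metadata before re-encoding; the JPEG won't keep it.
                prompt_info = parser_manager.parse(img)
                artist = ""
                if prompt_info and prompt_info.prompts:
                    artist = list(prompt_info.prompts)[0].value.split(',')[0].strip()
                    artist = artist.removeprefix("by ")
                artist_to_file_map[artist] = jpg_name
                # JPEG has no alpha channel, so normalize to RGB first.
                img.convert('RGB').save(os.path.join(image_directory, jpg_name), 'JPEG')
            os.remove(png_file)
        with open(os.path.join(image_directory, 'artist_to_file_map.json'), 'w') as json_file:
            json.dump(artist_to_file_map, json_file, indent=4)
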
app.py CHANGED
@@ -185,39 +185,10 @@ def load_model_components(file_path):
     model_components['row_to_tag'] = {idx: tag for tag, idx in model_components['tag_to_row_index'].items()}

     return model_components
+
 # Load all components at the start
 tf_idf_components = load_model_components('tf_idf_files_420.joblib')

-# Load the model and data once at startup
-with h5py.File('complete_artist_data.hdf5', 'r') as f:
-    # Deserialize the vectorizer
-    vectorizer_bytes = f['vectorizer'][()].tobytes()
-    # Use io.BytesIO to convert bytes back to a file-like object for joblib to load
-    vectorizer_buffer = BytesIO(vectorizer_bytes)
-    vectorizer = load(vectorizer_buffer)
-
-    # Load X_artist
-    X_artist = f['X_artist'][:]
-    # Load artist names and decode to strings
-    artist_names = [name.decode() for name in f['artist_names'][:]]
-
-
-with h5py.File('conditional_tag_probabilities_matrix.h5', 'r') as f:
-    # Reconstruct the sparse co-occurrence matrix
-    conditional_co_occurrence_matrix = csr_matrix(
-        (f['co_occurrence_data'][:], f['co_occurrence_indices'][:], f['co_occurrence_indptr'][:]),
-        shape=f['co_occurrence_shape'][:]
-    )
-
-    # Reconstruct the vocabulary
-    conditional_words = f['vocabulary_words'][:]
-    conditional_indices = f['vocabulary_indices'][:]
-    conditional_vocabulary = {key.decode('utf-8'): value for key, value in zip(conditional_words, conditional_indices)}
-
-    # Load the document count
-    conditional_doc_count = f['doc_count'][()]
-    conditional_smoothing = 100. / conditional_doc_count
-

 nsfw_tags = set() # Initialize an empty set to store words meeting the threshold
 # Open and read the CSV file
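
The startup block removed here rebuilt a scipy CSR matrix from the four raw arrays stored in the HDF5 file. For reference, a minimal self-contained round-trip of that storage pattern, with a toy matrix and a hypothetical file name in place of the app's real data:

    import h5py
    import numpy as np
    from scipy.sparse import csr_matrix

    matrix = csr_matrix(np.array([[3, 1], [0, 2]]))  # toy co-occurrence matrix

    # Store the CSR components the same way the removed loader expected them.
    with h5py.File('toy_co_occurrence.h5', 'w') as f:  # hypothetical file name
        f['co_occurrence_data'] = matrix.data
        f['co_occurrence_indices'] = matrix.indices
        f['co_occurrence_indptr'] = matrix.indptr
        f['co_occurrence_shape'] = matrix.shape

    # Rebuild it exactly as the removed code did.
    with h5py.File('toy_co_occurrence.h5', 'r') as f:
        rebuilt = csr_matrix(
            (f['co_occurrence_data'][:], f['co_occurrence_indices'][:], f['co_occurrence_indptr'][:]),
            shape=f['co_occurrence_shape'][:]
        )

    assert (rebuilt != matrix).nnz == 0  # same contents
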
@@ -349,50 +320,6 @@ def build_tag_id_wiki_dict(filename='wiki_pages-2023-08-08.csv'):

     return tag_data

-
-#Imagine we are adding smoothing_value to the number of times word_j occurs in each document for smoothing.
-#Note the intention is that sum_i(P(word_i|word_j)) =(approx) # of words in a document rather than 1.
-def conditional_probability(word_i, word_j, co_occurrence_matrix, vocabulary, doc_count, smoothing_value=0.01):
-    word_i_index = vocabulary.get(word_i)
-    word_j_index = vocabulary.get(word_j)
-
-    if word_i_index is not None and word_j_index is not None:
-        # Directly access the sparse matrix elements
-        word_j_count = co_occurrence_matrix[word_j_index, word_j_index]
-        smoothed_word_j_count = word_j_count + (smoothing_value * doc_count)
-
-        word_i_count = co_occurrence_matrix[word_i_index, word_i_index]
-
-        co_occurrence_count = co_occurrence_matrix[word_i_index, word_j_index]
-        smoothed_co_occurrence_count = co_occurrence_count + (smoothing_value * word_i_count)
-
-        # Calculate the conditional probability with smoothing
-        conditional_prob = smoothed_co_occurrence_count / smoothed_word_j_count
-
-        return conditional_prob
-    elif word_i_index is None:
-        return 0
-    else:
-        return None
-
-
-def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix, vocabulary, doc_count, smoothing_value=0.01):
-    probabilities = []
-
-    # Collect the conditional probabilities of the target word given each context word, ignoring None values
-    for context_word in context_words:
-        prob = conditional_probability(target_word, context_word, co_occurrence_matrix, vocabulary, doc_count, smoothing_value)
-        if prob is not None:
-            probabilities.append(prob)
-
-    # Compute the geometric mean of the probabilities, avoiding division by zero
-    if probabilities: # Check if the list is not empty
-        geometric_mean = np.prod(probabilities) ** (1.0 / len(probabilities))
-    else:
-        geometric_mean = 0.5 # Or assign some default value if all probabilities are None
-
-    return geometric_mean
-

 def create_html_tables_for_tags(subtable_heading, word_similarity_tuples, tag2count, tag2idwiki):
     # Wrap the tag part in a <span> with styles for bold and larger font
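
The scorer removed above estimated P(word_i | word_j) from a co-occurrence matrix whose diagonal holds per-word document counts, smoothed as (count_ij + s * count_i) / (count_j + s * doc_count), and then combined one estimate per context word with a geometric mean. A small worked instance with made-up counts:

    doc_count = 100                      # illustrative totals, not the app's data
    count_i, count_j, count_ij = 40, 20, 10
    s = 0.01                             # the removed smoothing_value default

    p_i_given_j = (count_ij + s * count_i) / (count_j + s * doc_count)
    # = (10 + 0.4) / (20 + 1.0) = 10.4 / 21.0 ≈ 0.495

    # geometric_mean_given_words then averaged one such probability per context word:
    probs = [0.495, 0.2, 0.8]
    product = 1.0
    for p in probs:
        product *= p
    geometric_mean = product ** (1.0 / len(probs))  # ≈ 0.43
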
@@ -511,7 +438,7 @@ def create_html_placeholder(title="", content="", placeholder_height=400, placeh
     return html_placeholder


-def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
+def find_similar_tags(test_tags, tag_to_context_similarity, context_similarity_weight, allow_nsfw_tags):
     #Initialize stuff
     if not hasattr(find_similar_tags, "fasttext_small_model"):
         find_similar_tags.fasttext_small_model = compress_fasttext.models.CompressedFastTextKeyedVectors.load('e621FastTextModel010Replacement_small.bin')
@@ -584,10 +511,8 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
         #Adjust score based on context
         for i in range(len(result)):
             word, score = result[i] # Unpack the tuple
-            geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag != modified_tag], conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing)
-            adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
-            result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
-            #print(word, score, geometric_mean, adjusted_score)
+            context_score = tag_to_context_similarity.get(word,0)
+            result[i] = (word, .5 * ((context_similarity_weight * context_score) + ((1 - context_similarity_weight) * score)))

         result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
         html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
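
The replacement is a plain linear blend of the fastText similarity with a precomputed context similarity, halved. With the slider at its default of 0.5 and illustrative scores:

    context_similarity_weight = 0.5   # slider default
    score = 0.8                       # fastText similarity (illustrative)
    context_score = 0.4               # tag_to_context_similarity.get(word, 0) (illustrative)

    adjusted = .5 * ((context_similarity_weight * context_score)
                     + ((1 - context_similarity_weight) * score))
    # = .5 * (0.2 + 0.4) = 0.3

Since the results are re-sorted immediately afterwards, the leading .5 only rescales the scores; the ranking depends on the weighted blend alone.
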
@@ -650,32 +575,30 @@ def augment_bad_entities_with_regex(text):
     return bad_entities


-def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
+def find_similar_artists(original_tags_string, top_n, context_similarity_weight, allow_nsfw_tags):
     try:
         new_tags_string = original_tags_string.lower()
         new_tags_string, removed_tags = remove_special_tags(new_tags_string)

         # Parse the prompt
         parsed = parser.parse(new_tags_string)
-
         # Extract tags from the parsed tree
         new_image_tags = extract_tags(parsed)
-
         tag_data = build_tag_offsets_dicts(new_image_tags)
-
-        ###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
-        unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
+
+        #Suggested tags stuff
+        suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
+        suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
+        suggested_tags = get_tfidf_reduced_similar_tags([item["tf_idf_matrix_tag"] for item in tag_data] + removed_tags, allow_nsfw_tags)
+
+
+        unseen_tags_data, bad_entities = find_similar_tags(tag_data, suggested_tags, context_similarity_weight, allow_nsfw_tags)
+

         #Bad tags stuff
         bad_entities.extend(augment_bad_entities_with_regex(new_tags_string))
         bad_entities.sort(key=lambda x: x['start'])
         bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
-
-        #Suggested tags stuff
-        suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
-
-        suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
-        suggested_tags = get_tfidf_reduced_similar_tags([item["tf_idf_matrix_tag"] for item in tag_data] + removed_tags, allow_nsfw_tags)

         # Create a set of tags that should be filtered out
         filter_tags = {entry["original_tag"].strip() for entry in tag_data}
@@ -690,13 +613,6 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
         suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)

         #Artist stuff
-        #artist_matrix_tags = [tag_info['artist_matrix_tag'] for tag_info in tag_data if tag_info['node_type'] == "tag"]
-        #X_new_image = vectorizer.transform([','.join(artist_matrix_tags + removed_tags)])
-        #similarities = cosine_similarity(X_new_image, X_artist)[0]
-        #
-        #top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
-        #top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices if artist_names[i].lower() != "by conditional dnp"][:top_n]
-
         excluded_artists = ["by conditional dnp", "by unknown artist"]
         top_artists = [(key, value) for key, value in suggested_artist_tags_filtered.items() if key.lower() not in excluded_artists][:top_n]
         top_artists_str = create_top_artists_table(top_artists)
@@ -737,7 +653,7 @@ with gr.Blocks(css=css) as app:
         with gr.Column(scale=3):
             with gr.Group():
                 with gr.Row():
-                    similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
+                    context_similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Context Similarity Weight")
                     allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
                 with gr.Row():
                     with gr.Column(scale=2):
@@ -759,7 +675,7 @@ with gr.Blocks(css=css) as app:

     submit_button.click(
         find_similar_artists,
-        inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
+        inputs=[image_tags, num_artists, context_similarity_weight, allow_nsfw],
         outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
     )
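
The last two hunks rename the slider and thread it through to the callback. A stripped-down sketch of the same Gradio wiring, with a hypothetical echo function standing in for find_similar_artists:

    import gradio as gr

    def run(tags, weight, nsfw):  # stand-in for find_similar_artists
        return f"tags={tags!r}, context weight={weight}, nsfw={nsfw}"

    with gr.Blocks() as demo:
        image_tags = gr.Textbox(label="Tags")
        context_similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1,
                                              label="Context Similarity Weight")
        allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
        output = gr.Textbox(label="Result")
        submit_button = gr.Button("Submit")
        # Components listed in `inputs` arrive as positional arguments, in order.
        submit_button.click(run, inputs=[image_tags, context_similarity_weight, allow_nsfw], outputs=[output])

    demo.launch()
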