SmilingWolf committed on
Commit
0bd8f65
1 Parent(s): 8444c60

Add image support

Browse files
.gitattributes CHANGED
@@ -34,3 +34,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.index filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.index filter=lfs diff=lfs merge=lfs -text
37
+
38
+ # Byte-compiled / optimized / DLL files
39
+ __pycache__/
40
+ *.py[cod]
41
+ *$py.class
app.py CHANGED
@@ -7,10 +7,20 @@ import jax
7
  import numpy as np
8
  import pandas as pd
9
  import requests
 
10
 
11
  from Models.CLIP import CLIP
12
 
13
 
 
 
 
 
 
 
 
 
 
14
  def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
15
  headers = {"User-Agent": "image_similarity_tool"}
16
  ratings_to_letters = {
@@ -56,6 +66,8 @@ class Predictor:
56
 
57
  def predict(
58
  self,
 
 
59
  positive_tags,
60
  negative_tags,
61
  selected_ratings,
@@ -68,38 +80,68 @@ class Predictor:
68
 
69
  num_classes = len(tags_df)
70
 
 
 
 
 
 
 
71
  positive_tags = positive_tags.split(",")
72
  negative_tags = negative_tags.split(",")
73
 
74
  positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
75
  negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()
76
 
77
- tags = np.zeros((1, num_classes), dtype=np.float32)
78
- tags[0][positive_tags_idxs] = 1
79
- emb_from_logits = model.apply(
80
- {"params": self.params},
81
- tags,
82
- method=model.encode_text,
83
- )
84
- emb_from_logits = jax.device_get(emb_from_logits)
85
- faiss.normalize_L2(emb_from_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  if len(negative_tags_idxs) > 0:
88
  tags = np.zeros((1, num_classes), dtype=np.float32)
89
  tags[0][negative_tags_idxs] = 1
90
 
91
- neg_emb_from_logits = model.apply(
92
  {"params": self.params},
93
  tags,
94
  method=model.encode_text,
95
  )
96
- neg_emb_from_logits = jax.device_get(neg_emb_from_logits)
97
- faiss.normalize_L2(neg_emb_from_logits)
98
-
99
- emb_from_logits = emb_from_logits - neg_emb_from_logits
100
- faiss.normalize_L2(emb_from_logits)
 
 
 
 
101
 
102
- dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
103
  neighbours_ids = self.images_ids[indexes][0]
104
  neighbours_ids = [int(x) for x in neighbours_ids]
105
 
@@ -122,10 +164,19 @@ def main():
122
  predictor = Predictor()
123
 
124
  with gr.Blocks() as demo:
 
 
 
125
  with gr.Row():
126
  with gr.Column():
127
  positive_tags = gr.Textbox(label="Positive tags")
128
  negative_tags = gr.Textbox(label="Negative tags")
 
 
 
 
 
 
129
  n_neighbours = gr.Slider(
130
  minimum=1,
131
  maximum=20,
@@ -133,15 +184,10 @@ def main():
133
  step=1,
134
  label="# of images",
135
  )
136
-
137
  with gr.Column():
138
  api_username = gr.Textbox(label="Danbooru API Username")
139
  api_key = gr.Textbox(label="Danbooru API Key")
140
- selected_ratings = gr.CheckboxGroup(
141
- choices=["General", "Sensitive", "Questionable", "Explicit"],
142
- value=["General", "Sensitive"],
143
- label="Ratings",
144
- )
145
  find_btn = gr.Button("Find similar images")
146
 
147
  similar_images = gr.Gallery(label="Similar images", columns=[5])
@@ -149,6 +195,8 @@ def main():
149
  examples = gr.Examples(
150
  [
151
  [
 
 
152
  "marcille_donato",
153
  "",
154
  ["General", "Sensitive"],
@@ -157,6 +205,8 @@ def main():
157
  "",
158
  ],
159
  [
 
 
160
  "yellow_eyes,red_horns",
161
  "",
162
  ["General", "Sensitive"],
@@ -165,6 +215,8 @@ def main():
165
  "",
166
  ],
167
  [
 
 
168
  "artoria_pendragon_(fate),solo",
169
  "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
170
  ["General", "Sensitive"],
@@ -172,8 +224,30 @@ def main():
172
  "",
173
  "",
174
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  ],
176
  inputs=[
 
 
177
  positive_tags,
178
  negative_tags,
179
  selected_ratings,
@@ -190,6 +264,8 @@ def main():
190
  find_btn.click(
191
  fn=predictor.predict,
192
  inputs=[
 
 
193
  positive_tags,
194
  negative_tags,
195
  selected_ratings,
 
7
  import numpy as np
8
  import pandas as pd
9
  import requests
10
+ from imgutils.tagging import wd14
11
 
12
  from Models.CLIP import CLIP
13
 
14
 
15
def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
    """Merge image and tag embeddings into a single unit-length query.

    The positive image and tag embeddings are summed, the negative ones are
    summed, and the negative sum is subtracted from the positive sum. The
    difference is then L2-normalised along the last axis so the vector handed
    to the nearest-neighbour search has unit length, like each of the
    individual embeddings the caller normalises beforehand.

    Args:
        pos_img_embs: Embedding of the positive image (zeros when absent).
        pos_tags_embs: Embedding of the positive tags (zeros when absent).
        neg_img_embs: Embedding of the negative image (zeros when absent).
        neg_tags_embs: Embedding of the negative tags (zeros when absent).
            The caller in this file passes ``(1, out_units)`` float32 arrays
            for all four arguments.

    Returns:
        Array of the same shape as the inputs holding the combined query.
        Rows whose combination is exactly zero are returned unchanged
        (all zeros) instead of being divided by zero.
    """
    pos = pos_img_embs + pos_tags_embs
    neg = neg_img_embs + neg_tags_embs
    result = pos - neg

    # Re-normalise after the subtraction: the difference of unit vectors is
    # not unit length. Zero-norm rows get a divisor of 1 to avoid NaNs.
    norms = np.linalg.norm(result, axis=-1, keepdims=True)
    safe_norms = np.where(norms > 0, norms, 1)
    return result / safe_norms
22
+
23
+
24
  def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
25
  headers = {"User-Agent": "image_similarity_tool"}
26
  ratings_to_letters = {
 
66
 
67
  def predict(
68
  self,
69
+ pos_img_input,
70
+ neg_img_input,
71
  positive_tags,
72
  negative_tags,
73
  selected_ratings,
 
80
 
81
  num_classes = len(tags_df)
82
 
83
+ output_shape = model.out_units
84
+ pos_img_embs = np.zeros((1, output_shape), dtype=np.float32)
85
+ neg_img_embs = np.zeros((1, output_shape), dtype=np.float32)
86
+ pos_tags_embs = np.zeros((1, output_shape), dtype=np.float32)
87
+ neg_tags_embs = np.zeros((1, output_shape), dtype=np.float32)
88
+
89
  positive_tags = positive_tags.split(",")
90
  negative_tags = negative_tags.split(",")
91
 
92
  positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
93
  negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()
94
 
95
+ if pos_img_input is not None:
96
+ pos_img_embs = wd14.get_wd14_tags(
97
+ pos_img_input,
98
+ model_name="ConvNext",
99
+ fmt=("embedding"),
100
+ )
101
+ pos_img_embs = np.expand_dims(pos_img_embs, 0)
102
+ faiss.normalize_L2(pos_img_embs)
103
+
104
+ if neg_img_input is not None:
105
+ neg_img_embs = wd14.get_wd14_tags(
106
+ neg_img_input,
107
+ model_name="ConvNext",
108
+ fmt=("embedding"),
109
+ )
110
+ neg_img_embs = np.expand_dims(neg_img_embs, 0)
111
+ faiss.normalize_L2(neg_img_embs)
112
+
113
+ if len(positive_tags_idxs) > 0:
114
+ tags = np.zeros((1, num_classes), dtype=np.float32)
115
+ tags[0][positive_tags_idxs] = 1
116
+
117
+ pos_tags_embs = model.apply(
118
+ {"params": self.params},
119
+ tags,
120
+ method=model.encode_text,
121
+ )
122
+ pos_tags_embs = jax.device_get(pos_tags_embs)
123
+ faiss.normalize_L2(pos_tags_embs)
124
 
125
  if len(negative_tags_idxs) > 0:
126
  tags = np.zeros((1, num_classes), dtype=np.float32)
127
  tags[0][negative_tags_idxs] = 1
128
 
129
+ neg_tags_embs = model.apply(
130
  {"params": self.params},
131
  tags,
132
  method=model.encode_text,
133
  )
134
+ neg_tags_embs = jax.device_get(neg_tags_embs)
135
+ faiss.normalize_L2(neg_tags_embs)
136
+
137
+ embeddings = combine_embeddings(
138
+ pos_img_embs,
139
+ pos_tags_embs,
140
+ neg_img_embs,
141
+ neg_tags_embs,
142
+ )
143
 
144
+ dists, indexes = self.knn_index.search(embeddings, k=n_neighbours)
145
  neighbours_ids = self.images_ids[indexes][0]
146
  neighbours_ids = [int(x) for x in neighbours_ids]
147
 
 
164
  predictor = Predictor()
165
 
166
  with gr.Blocks() as demo:
167
+ with gr.Row():
168
+ pos_img_input = gr.Image(type="pil", label="Positive input")
169
+ neg_img_input = gr.Image(type="pil", label="Negative input")
170
  with gr.Row():
171
  with gr.Column():
172
  positive_tags = gr.Textbox(label="Positive tags")
173
  negative_tags = gr.Textbox(label="Negative tags")
174
+ with gr.Column():
175
+ selected_ratings = gr.CheckboxGroup(
176
+ choices=["General", "Sensitive", "Questionable", "Explicit"],
177
+ value=["General", "Sensitive"],
178
+ label="Ratings",
179
+ )
180
  n_neighbours = gr.Slider(
181
  minimum=1,
182
  maximum=20,
 
184
  step=1,
185
  label="# of images",
186
  )
 
187
  with gr.Column():
188
  api_username = gr.Textbox(label="Danbooru API Username")
189
  api_key = gr.Textbox(label="Danbooru API Key")
190
+
 
 
 
 
191
  find_btn = gr.Button("Find similar images")
192
 
193
  similar_images = gr.Gallery(label="Similar images", columns=[5])
 
195
  examples = gr.Examples(
196
  [
197
  [
198
+ None,
199
+ None,
200
  "marcille_donato",
201
  "",
202
  ["General", "Sensitive"],
 
205
  "",
206
  ],
207
  [
208
+ None,
209
+ None,
210
  "yellow_eyes,red_horns",
211
  "",
212
  ["General", "Sensitive"],
 
215
  "",
216
  ],
217
  [
218
+ None,
219
+ None,
220
  "artoria_pendragon_(fate),solo",
221
  "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
222
  ["General", "Sensitive"],
 
224
  "",
225
  "",
226
  ],
227
+ [
228
+ "examples/60378883_p0.jpg",
229
+ None,
230
+ "fujimaru_ritsuka_(female)",
231
+ "solo",
232
+ ["General", "Sensitive"],
233
+ 5,
234
+ "",
235
+ "",
236
+ ],
237
+ [
238
+ "examples/DaRlExxUwAAcUOS-orig.jpg",
239
+ "examples/46657164_p1.jpg",
240
+ "",
241
+ "",
242
+ ["General", "Sensitive"],
243
+ 5,
244
+ "",
245
+ "",
246
+ ],
247
  ],
248
  inputs=[
249
+ pos_img_input,
250
+ neg_img_input,
251
  positive_tags,
252
  negative_tags,
253
  selected_ratings,
 
264
  find_btn.click(
265
  fn=predictor.predict,
266
  inputs=[
267
+ pos_img_input,
268
+ neg_img_input,
269
  positive_tags,
270
  negative_tags,
271
  selected_ratings,
examples/46657164_p1.jpg ADDED
examples/60378883_p0.jpg ADDED
examples/DaRlExxUwAAcUOS-orig.jpg ADDED
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  faiss-cpu
2
  jax[cpu]
3
  flax
 
 
 
1
  faiss-cpu
2
  jax[cpu]
3
  flax
4
+ imgutils
5
+ onnxruntime