SmilingWolf committed
Commit 0c14216
Parent: 69cd139

Add support for SigLIP-trained weights.


Same network structure for now; this is just to make it easier to
compare the two while experimenting.
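
For context, the two variants share the same towers and differ only in how they were trained: CLIP uses a batch-wise softmax contrastive loss, while SigLIP treats every image/text pair as an independent sigmoid classification. A minimal sketch of the two objectives in JAX (illustrative only, not code from this repo; assumes equal-sized, L2-normalized embedding batches):

    import jax.numpy as jnp
    from jax.nn import log_sigmoid, log_softmax

    def clip_loss(img_embs, txt_embs, temperature=0.07):
        # Softmax contrastive loss: each image competes against every
        # text in the batch (and vice versa); diagonal pairs match.
        logits = img_embs @ txt_embs.T / temperature
        diag = jnp.arange(logits.shape[0])
        loss_i = -log_softmax(logits, axis=1)[diag, diag]
        loss_t = -log_softmax(logits, axis=0)[diag, diag]
        return (loss_i + loss_t).mean() / 2

    def siglip_loss(img_embs, txt_embs, t=10.0, b=-10.0):
        # Sigmoid loss: every (image, text) pair is an independent binary
        # classification; no batch-wide normalization term. t and b stand
        # in for SigLIP's learned temperature and bias.
        logits = img_embs @ txt_embs.T * t + b
        labels = 2.0 * jnp.eye(logits.shape[0]) - 1.0  # +1 diag, -1 off-diag
        return -log_sigmoid(labels * logits).sum(axis=1).mean()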

app.py CHANGED
@@ -14,10 +14,13 @@ from Models.CLIP import CLIP
 
 def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
     pos = pos_img_embs + pos_tags_embs
+    faiss.normalize_L2(pos)
 
     neg = neg_img_embs + neg_tags_embs
+    faiss.normalize_L2(neg)
 
     result = pos - neg
+    faiss.normalize_L2(result)
     return result
 
 
@@ -48,12 +51,9 @@ def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
 
 class Predictor:
     def __init__(self):
+        self.loaded_variant = None
         self.base_model = "wd-v1-4-convnext-tagger-v2"
 
-        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
-            data = f.read()
-
-        self.params = flax.serialization.msgpack_restore(data)["model"]
         self.model = CLIP()
 
         self.tags_df = pd.read_csv("data/selected_tags.csv")
@@ -64,12 +64,27 @@ class Predictor:
         config = json.loads(open("index/cosine_infos.json").read())["index_param"]
         faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
 
+    def load_params(self, variant):
+        if self.loaded_variant == variant:
+            return
+
+        if variant == "CLIP":
+            with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
+                data = f.read()
+        elif variant == "SigLIP":
+            with open(f"data/{self.base_model}/siglip.msgpack", "rb") as f:
+                data = f.read()
+
+        self.params = flax.serialization.msgpack_restore(data)["model"]
+        self.loaded_variant = variant
+
     def predict(
         self,
         pos_img_input,
         neg_img_input,
         positive_tags,
         negative_tags,
+        selected_model,
         selected_ratings,
         n_neighbours,
         api_username,
@@ -78,6 +93,8 @@ class Predictor:
         tags_df = self.tags_df
         model = self.model
 
+        self.load_params(selected_model)
+
        num_classes = len(tags_df)
 
         output_shape = model.out_units
@@ -172,10 +189,10 @@ def main():
                 positive_tags = gr.Textbox(label="Positive tags")
                 negative_tags = gr.Textbox(label="Negative tags")
             with gr.Column():
-                selected_ratings = gr.CheckboxGroup(
-                    choices=["General", "Sensitive", "Questionable", "Explicit"],
-                    value=["General", "Sensitive"],
-                    label="Ratings",
+                selected_model = gr.Radio(
+                    choices=["CLIP", "SigLIP"],
+                    value="CLIP",
+                    label="Model",
                 )
                 n_neighbours = gr.Slider(
                     minimum=1,
@@ -185,8 +202,14 @@ def main():
                     label="# of images",
                 )
             with gr.Column():
-                api_username = gr.Textbox(label="Danbooru API Username")
-                api_key = gr.Textbox(label="Danbooru API Key")
+                selected_ratings = gr.CheckboxGroup(
+                    choices=["General", "Sensitive", "Questionable", "Explicit"],
+                    value=["General", "Sensitive"],
+                    label="Ratings",
+                )
+                with gr.Row():
+                    api_username = gr.Textbox(label="Danbooru API Username")
+                    api_key = gr.Textbox(label="Danbooru API Key")
 
         find_btn = gr.Button("Find similar images")
 
@@ -199,6 +222,7 @@ def main():
                 None,
                 "marcille_donato",
                 "",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -209,6 +233,7 @@ def main():
                 None,
                 "yellow_eyes,red_horns",
                 "",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -219,6 +244,7 @@ def main():
                 None,
                 "artoria_pendragon_(fate),solo",
                 "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -229,6 +255,7 @@ def main():
                 None,
                 "fujimaru_ritsuka_(female)",
                 "solo",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -239,6 +266,7 @@ def main():
                 "examples/46657164_p1.jpg",
                 "",
                 "",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -250,6 +278,7 @@ def main():
             neg_img_input,
             positive_tags,
             negative_tags,
+            selected_model,
             selected_ratings,
             n_neighbours,
             api_username,
@@ -268,6 +297,7 @@ def main():
             neg_img_input,
             positive_tags,
             negative_tags,
+            selected_model,
             selected_ratings,
             n_neighbours,
             api_username,
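
A note on the combine_embeddings change above: the added faiss.normalize_L2 calls are the functional piece. The kNN index is configured from index/cosine_infos.json, i.e. cosine similarity, which an inner-product search only computes correctly on unit-length vectors; normalizing before summing also keeps image and tag embeddings on equal footing regardless of their raw norms. A self-contained sketch of the same pattern (dimensions and data are made up for illustration):

    import faiss
    import numpy as np

    dim = 1024  # embedding width; placeholder value
    pos = np.random.randn(1, dim).astype(np.float32)
    neg = np.random.randn(1, dim).astype(np.float32)

    faiss.normalize_L2(pos)  # in-place, row-wise unit L2 norm
    faiss.normalize_L2(neg)

    query = pos - neg
    faiss.normalize_L2(query)

    # On unit vectors, inner product equals cosine similarity.
    index = faiss.IndexFlatIP(dim)
    database = np.random.randn(100, dim).astype(np.float32)
    faiss.normalize_L2(database)
    index.add(database)
    scores, ids = index.search(query, 5)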
 
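The new load_params method makes the weight swap lazy: a variant's msgpack file is re-read only when the radio selection actually changes, so repeated queries against the same variant skip deserialization. A usage sketch, assuming the Predictor defined above:

    predictor = Predictor()

    predictor.load_params("CLIP")    # reads clip.msgpack once
    predictor.load_params("CLIP")    # no-op: loaded_variant already matches
    predictor.load_params("SigLIP")  # single reload from siglip.msgpack
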
data/wd-v1-4-convnext-tagger-v2/siglip.msgpack ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b541d6ed39a4df5ca2edd7e3431e936bbb61c9499026ad3365361af13aa06d06
+size 48689369