SmilingWolf committed
Commit 0c14216 · Parent(s): 69cd139

Add support for SigLIP-trained weights.
Same network structure for now, this is just to make it easier to compare the two while experimenting.
- app.py +40 -10
- data/wd-v1-4-convnext-tagger-v2/siglip.msgpack +3 -0
app.py CHANGED
@@ -14,10 +14,13 @@ from Models.CLIP import CLIP
 
 def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
     pos = pos_img_embs + pos_tags_embs
+    faiss.normalize_L2(pos)
 
     neg = neg_img_embs + neg_tags_embs
+    faiss.normalize_L2(neg)
 
     result = pos - neg
+    faiss.normalize_L2(result)
     return result
 
 
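Why the new normalize_L2 calls: the app queries a cosine index (see index/cosine_infos.json further down), and with FAISS a cosine search is an inner-product search over unit vectors. Sums and differences of unit vectors are generally not unit-length, so each intermediate result is renormalized in place before the lookup. A minimal sketch of the idea with made-up dimensions and data; none of the names below are from the repo:

import faiss
import numpy as np

dim, n_db = 512, 1000  # hypothetical embedding size and database size
rng = np.random.default_rng(0)

db = rng.standard_normal((n_db, dim)).astype(np.float32)
faiss.normalize_L2(db)  # in place: every row becomes a unit vector

index = faiss.IndexFlatIP(dim)  # inner product == cosine on unit vectors
index.add(db)

img_embs = rng.standard_normal((1, dim)).astype(np.float32)
tag_embs = rng.standard_normal((1, dim)).astype(np.float32)
faiss.normalize_L2(img_embs)
faiss.normalize_L2(tag_embs)

query = img_embs + tag_embs  # sum of unit vectors is not unit-length...
faiss.normalize_L2(query)    # ...so renormalize before the lookup
scores, ids = index.search(query, 5)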
@@ -48,12 +51,9 @@ def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
 
 class Predictor:
     def __init__(self):
+        self.loaded_variant = None
         self.base_model = "wd-v1-4-convnext-tagger-v2"
 
-        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
-            data = f.read()
-
-        self.params = flax.serialization.msgpack_restore(data)["model"]
         self.model = CLIP()
 
         self.tags_df = pd.read_csv("data/selected_tags.csv")
@@ -64,12 +64,27 @@ class Predictor:
         config = json.loads(open("index/cosine_infos.json").read())["index_param"]
         faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
 
+    def load_params(self, variant):
+        if self.loaded_variant == variant:
+            return
+
+        if variant == "CLIP":
+            with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
+                data = f.read()
+        elif variant == "SigLIP":
+            with open(f"data/{self.base_model}/siglip.msgpack", "rb") as f:
+                data = f.read()
+
+        self.params = flax.serialization.msgpack_restore(data)["model"]
+        self.loaded_variant = variant
+
     def predict(
         self,
         pos_img_input,
         neg_img_input,
         positive_tags,
         negative_tags,
+        selected_model,
         selected_ratings,
         n_neighbours,
         api_username,
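The new load_params makes variant switching lazy: weights are re-read from disk only when the requested variant differs from the one already in memory, so back-to-back queries against the same variant pay the deserialization cost once. A stripped-down sketch of the same caching pattern; the paths mirror the diff, but the class itself is illustrative, not the app's code:

from flax import serialization

WEIGHT_PATHS = {
    "CLIP": "data/wd-v1-4-convnext-tagger-v2/clip.msgpack",
    "SigLIP": "data/wd-v1-4-convnext-tagger-v2/siglip.msgpack",
}


class LazyWeights:
    def __init__(self):
        self.loaded_variant = None
        self.params = None

    def load(self, variant):
        if self.loaded_variant == variant:
            return self.params  # cache hit: skip the disk read and restore
        with open(WEIGHT_PATHS[variant], "rb") as f:
            restored = serialization.msgpack_restore(f.read())
        self.params = restored["model"]
        self.loaded_variant = variant
        return self.params

Because both checkpoints share the same network structure (per the commit message), the same CLIP() module definition can consume either parameter tree.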
@@ -78,6 +93,8 @@ class Predictor:
         tags_df = self.tags_df
         model = self.model
 
+        self.load_params(selected_model)
+
         num_classes = len(tags_df)
 
         output_shape = model.out_units
@@ -172,10 +189,10 @@ def main():
                 positive_tags = gr.Textbox(label="Positive tags")
                 negative_tags = gr.Textbox(label="Negative tags")
             with gr.Column():
-                selected_ratings = gr.CheckboxGroup(
-                    choices=["General", "Sensitive", "Questionable", "Explicit"],
-                    value=["General", "Sensitive"],
-                    label="Ratings",
+                selected_model = gr.Radio(
+                    choices=["CLIP", "SigLIP"],
+                    value="CLIP",
+                    label="Model",
                 )
                 n_neighbours = gr.Slider(
                     minimum=1,
@@ -185,8 +202,14 @@ def main():
                     label="# of images",
                 )
             with gr.Column():
-                api_username = gr.Textbox(label="Danbooru API Username")
-                api_key = gr.Textbox(label="Danbooru API Key")
+                selected_ratings = gr.CheckboxGroup(
+                    choices=["General", "Sensitive", "Questionable", "Explicit"],
+                    value=["General", "Sensitive"],
+                    label="Ratings",
+                )
+                with gr.Row():
+                    api_username = gr.Textbox(label="Danbooru API Username")
+                    api_key = gr.Textbox(label="Danbooru API Key")
 
         find_btn = gr.Button("Find similar images")
 
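On the UI side, the Radio value reaches the click handler as a plain string ("CLIP" or "SigLIP"), which is exactly what load_params branches on; the CheckboxGroup arrives as a list of the checked labels. A minimal, self-contained sketch of that wiring; the handler here is a stand-in, not the app's predict:

import gradio as gr


def run(selected_model, selected_ratings):
    # In the real app this slot is Predictor.predict; here we just echo.
    return f"{selected_model} variant, ratings={selected_ratings}"


with gr.Blocks() as demo:
    selected_model = gr.Radio(
        choices=["CLIP", "SigLIP"], value="CLIP", label="Model"
    )
    selected_ratings = gr.CheckboxGroup(
        choices=["General", "Sensitive", "Questionable", "Explicit"],
        value=["General", "Sensitive"],
        label="Ratings",
    )
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(
        fn=run, inputs=[selected_model, selected_ratings], outputs=[out]
    )

demo.launch()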
@@ -199,6 +222,7 @@ def main():
                 None,
                 "marcille_donato",
                 "",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -209,6 +233,7 @@ def main():
                 None,
                 "yellow_eyes,red_horns",
                 "",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -219,6 +244,7 @@ def main():
                 None,
                 "artoria_pendragon_(fate),solo",
                 "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -229,6 +255,7 @@ def main():
                 None,
                 "fujimaru_ritsuka_(female)",
                 "solo",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -239,6 +266,7 @@ def main():
                 "examples/46657164_p1.jpg",
                 "",
                 "",
+                "CLIP",
                 ["General", "Sensitive"],
                 5,
                 "",
@@ -250,6 +278,7 @@ def main():
                 neg_img_input,
                 positive_tags,
                 negative_tags,
+                selected_model,
                 selected_ratings,
                 n_neighbours,
                 api_username,
@@ -268,6 +297,7 @@ def main():
                 neg_img_input,
                 positive_tags,
                 negative_tags,
+                selected_model,
                 selected_ratings,
                 n_neighbours,
                 api_username,
data/wd-v1-4-convnext-tagger-v2/siglip.msgpack ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b541d6ed39a4df5ca2edd7e3431e936bbb61c9499026ad3365361af13aa06d06
+size 48689369
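The weights themselves live in Git LFS, so the repository stores only this pointer: the oid is the SHA-256 of the roughly 48 MB payload. A quick sketch for checking a separately downloaded copy against the pointer, assuming the real file is already on disk at the path below:

import hashlib
import os

# Values copied from the LFS pointer above.
EXPECTED_OID = "b541d6ed39a4df5ca2edd7e3431e936bbb61c9499026ad3365361af13aa06d06"
EXPECTED_SIZE = 48689369

path = "data/wd-v1-4-convnext-tagger-v2/siglip.msgpack"

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)

assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch"
assert h.hexdigest() == EXPECTED_OID, "sha256 mismatch"
print("LFS object matches its pointer")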