SmilingWolf commited on
Commit
9934ab5
1 Parent(s): 0a8bf1b

Normalize all the things. Also add an example.

Browse files
Files changed (1) hide show
  1. app.py +26 -2
app.py CHANGED
@@ -83,6 +83,7 @@ class Predictor:
83
  method=model.encode_text,
84
  )
85
  emb_from_logits = jax.device_get(emb_from_logits)
 
86
 
87
  if len(negative_tags_idxs) > 0:
88
  tags = np.zeros((1, num_classes), dtype=np.float32)
@@ -94,9 +95,10 @@ class Predictor:
94
  method=model.encode_text,
95
  )
96
  neg_emb_from_logits = jax.device_get(neg_emb_from_logits)
97
- emb_from_logits = emb_from_logits - neg_emb_from_logits
98
 
99
- faiss.normalize_L2(emb_from_logits)
 
100
 
101
  dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
102
  neighbours_ids = self.images_ids[indexes][0]
@@ -145,6 +147,28 @@ def main():
145
 
146
  similar_images = gr.Gallery(label="Similar images", columns=[5])
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  find_btn.click(
149
  fn=predictor.predict,
150
  inputs=[
 
83
  method=model.encode_text,
84
  )
85
  emb_from_logits = jax.device_get(emb_from_logits)
86
+ faiss.normalize_L2(emb_from_logits)
87
 
88
  if len(negative_tags_idxs) > 0:
89
  tags = np.zeros((1, num_classes), dtype=np.float32)
 
95
  method=model.encode_text,
96
  )
97
  neg_emb_from_logits = jax.device_get(neg_emb_from_logits)
98
+ faiss.normalize_L2(neg_emb_from_logits)
99
 
100
+ emb_from_logits = emb_from_logits - neg_emb_from_logits
101
+ faiss.normalize_L2(emb_from_logits)
102
 
103
  dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
104
  neighbours_ids = self.images_ids[indexes][0]
 
147
 
148
  similar_images = gr.Gallery(label="Similar images", columns=[5])
149
 
150
+ examples = gr.Examples(
151
+ [
152
+ [
153
+ "artoria_pendragon_(fate),solo",
154
+ "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
155
+ ["General", "Sensitive"],
156
+ 5,
157
+ "",
158
+ "",
159
+ ]
160
+ ],
161
+ inputs=[
162
+ positive_tags,
163
+ negative_tags,
164
+ selected_ratings,
165
+ n_neighbours,
166
+ api_username,
167
+ api_key,
168
+ ],
169
+ outputs=[similar_images],
170
+ )
171
+
172
  find_btn.click(
173
  fn=predictor.predict,
174
  inputs=[