SmilingWolf commited on
Commit
23fa49c
1 Parent(s): 488df98

First commit

Browse files
README.md CHANGED
@@ -9,5 +9,3 @@ app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  pinned: false
10
  license: apache-2.0
11
  ---
 
 
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import faiss
4
+ import flax
5
+ import gradio as gr
6
+ import jax
7
+ import numpy as np
8
+ import pandas as pd
9
+ import requests
10
+
11
+ from Models.CLIP import CLIP
12
+
13
+
14
+ def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
15
+ headers = {"User-Agent": "image_similarity_tool"}
16
+ ratings_to_letters = {
17
+ "General": "g",
18
+ "Sensitive": "s",
19
+ "Questionable": "q",
20
+ "Explicit": "e",
21
+ }
22
+
23
+ acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]
24
+
25
+ image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
26
+ if api_username != "" and api_key != "":
27
+ image_url = f"{image_url}?api_key={api_key}&login={api_username}"
28
+
29
+ r = requests.get(image_url, headers=headers)
30
+ if r.status_code != 200:
31
+ return None
32
+
33
+ content = json.loads(r.text)
34
+ image_url = content["large_file_url"] if "large_file_url" in content else None
35
+ image_url = image_url if content["rating"] in acceptable_ratings else None
36
+ return image_url
37
+
38
+
39
+ class Predictor:
40
+ def __init__(self):
41
+ self.base_model = "wd-v1-4-convnext-tagger-v2"
42
+
43
+ with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
44
+ data = f.read()
45
+
46
+ self.params = flax.serialization.msgpack_restore(data)["model"]
47
+ self.model = CLIP()
48
+
49
+ self.tags_df = pd.read_csv("data/selected_tags.csv")
50
+
51
+ self.images_ids = np.load("index/cosine_ids.npy")
52
+
53
+ self.knn_index = faiss.read_index("index/cosine_knn.index")
54
+
55
+ config = json.loads(open("index/cosine_infos.json").read())["index_param"]
56
+ faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
57
+
58
+ def predict(self, positive_tags, negative_tags, n_neighbours=5):
59
+ tags_df = self.tags_df
60
+ model = self.model
61
+
62
+ num_classes = len(tags_df)
63
+
64
+ positive_tags = positive_tags.split(",")
65
+ negative_tags = negative_tags.split(",")
66
+
67
+ positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
68
+ negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()
69
+
70
+ tags = np.zeros((1, num_classes), dtype=np.float32)
71
+ tags[0][positive_tags_idxs] = 1
72
+ emb_from_logits = model.apply(
73
+ {"params": self.params},
74
+ tags,
75
+ method=model.encode_text,
76
+ )
77
+ emb_from_logits = jax.device_get(emb_from_logits)
78
+
79
+ if len(negative_tags_idxs) > 0:
80
+ tags = np.zeros((1, num_classes), dtype=np.float32)
81
+ tags[0][negative_tags_idxs] = 1
82
+
83
+ neg_emb_from_logits = model.apply(
84
+ {"params": self.params},
85
+ tags,
86
+ method=model.encode_text,
87
+ )
88
+ neg_emb_from_logits = jax.device_get(neg_emb_from_logits)
89
+ emb_from_logits = emb_from_logits - neg_emb_from_logits
90
+
91
+ faiss.normalize_L2(emb_from_logits)
92
+
93
+ dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
94
+ neighbours_ids = self.images_ids[indexes][0]
95
+ neighbours_ids = [int(x) for x in neighbours_ids]
96
+
97
+ captions = []
98
+ image_urls = []
99
+ for image_id, dist in zip(neighbours_ids, dists[0]):
100
+ current_url = danbooru_id_to_url(
101
+ image_id,
102
+ [
103
+ "General",
104
+ "Sensitive",
105
+ "Questionable",
106
+ "Explicit",
107
+ ],
108
+ )
109
+ if current_url is not None:
110
+ image_urls.append(current_url)
111
+ captions.append(f"{image_id}/{dist:.2f}")
112
+ return list(zip(image_urls, captions))
113
+
114
+
115
+ def main():
116
+ predictor = Predictor()
117
+
118
+ with gr.Blocks() as demo:
119
+ with gr.Row():
120
+ positive_tags = gr.Textbox(label="Positive tags")
121
+ negative_tags = gr.Textbox(label="Negative tags")
122
+
123
+ find_btn = gr.Button("Find similar images")
124
+
125
+ similar_images = gr.Gallery(label="Similar images", columns=[5])
126
+
127
+ find_btn.click(
128
+ fn=predictor.predict,
129
+ inputs=[positive_tags, negative_tags],
130
+ outputs=[similar_images],
131
+ )
132
+
133
+ demo.queue()
134
+ demo.launch()
135
+
136
+
137
+ if __name__ == "__main__":
138
+ main()
data/wd-v1-4-convnext-tagger-v2/clip.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3be3b97824313f01d9f1d74c43e441199b7ea485f5698d2008739f34c3e41200
3
+ size 48689306
index/cosine_ids.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df724519c8c1981e49d80e2430261deb4fb6edf6d9c04e134427879710747394
3
+ size 21830676
index/cosine_infos.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"index_key": "OPQ256_1280,IVF16384_HNSW32,PQ256x8", "index_param": "nprobe=16,efSearch=32,ht=2048", "index_path": "/home/SmilingWolf/eval/index/ConvNextBV1_01_14_2023_08h37m46s_cosine_knn.index", "size in bytes": 1535843672, "avg_search_speed_ms": 10.164478485783887, "99p_search_speed_ms": 12.419190758373587, "reconstruction error %": 22.007358074188232, "nb vectors": 5457637, "vectors dimension": 1024, "compression ratio": 14.555180035276402}
index/cosine_knn.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a718ab8370df8b9d84002c55f945ef241e4cc3450d306c2ecd97661f51022ad
3
+ size 1535843672
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ faiss
2
+ jax[cpu]
3
+ flax