File size: 4,226 Bytes
23fa49c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import json

import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests

from Models.CLIP import CLIP


def danbooru_id_to_url(
    image_id, selected_ratings, api_username="", api_key="", timeout=10
):
    """Resolve a Danbooru post id to the URL of its large image file.

    Args:
        image_id: Id of the post, interpolated into the
            ``/posts/{image_id}.json`` endpoint.
        selected_ratings: Iterable of rating names to accept. Each must be
            one of "General", "Sensitive", "Questionable" or "Explicit".
        api_username: Optional login; only sent together with ``api_key``.
        api_key: Optional API key; only sent together with ``api_username``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The post's ``large_file_url``, or ``None`` when the request fails,
        the response is not a JSON object, the post has no
        ``large_file_url``, or its rating is not among ``selected_ratings``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    # Unknown rating names fail here, before any network traffic happens.
    acceptable_ratings = {ratings_to_letters[x] for x in selected_ratings}
    if not acceptable_ratings:
        # No rating can match, so the result is None whatever the post says.
        return None

    headers = {"User-Agent": "image_similarity_tool"}
    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"

    # Credentials go through `params` so that requests URL-encodes them
    # instead of splicing raw text into the query string.
    params = None
    if api_username != "" and api_key != "":
        params = {"api_key": api_key, "login": api_username}

    try:
        r = requests.get(post_url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException:
        # Treat connection errors and timeouts like a non-200 response.
        return None
    if r.status_code != 200:
        return None

    try:
        content = json.loads(r.text)
    except ValueError:
        return None
    if not isinstance(content, dict):
        return None

    # `.get` avoids a KeyError when the post has no "rating" field.
    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")


class Predictor:
    """Search the image index for posts matching positive/negative tags.

    The constructor loads the model parameters, the tag table, the image-id
    array and the FAISS index from fixed relative paths (``data/`` and
    ``index/``), so it must run from a directory that contains them.
    """

    # Rating names handed to danbooru_id_to_url; every rating is accepted.
    _ALL_RATINGS = ("General", "Sensitive", "Questionable", "Explicit")

    def __init__(self):
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()

        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()

        self.tags_df = pd.read_csv("data/selected_tags.csv")

        self.images_ids = np.load("index/cosine_ids.npy")

        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Use a context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    @staticmethod
    def _parse_tags(raw_tags):
        """Split a comma-separated tag string into clean, non-empty names."""
        if not raw_tags:
            return []
        stripped = (part.strip() for part in raw_tags.split(","))
        return [tag for tag in stripped if tag]

    def _tag_indexes(self, tag_names):
        """Return the tags_df index labels whose "name" is in tag_names."""
        tags_df = self.tags_df
        return tags_df.index[tags_df["name"].isin(tag_names)].tolist()

    def _embed_tag_indexes(self, tag_idxs):
        """Encode a (1, num_tags) multi-hot vector with the model's encode_text."""
        tags = np.zeros((1, len(self.tags_df)), dtype=np.float32)
        tags[0][tag_idxs] = 1
        emb = self.model.apply(
            {"params": self.params},
            tags,
            method=self.model.encode_text,
        )
        return jax.device_get(emb)

    def predict(self, positive_tags: str, negative_tags: str, n_neighbours: int = 5):
        """Find images whose index embedding is closest to the tag query.

        Args:
            positive_tags: Comma-separated tag names to search for.
                Whitespace around each name is ignored.
            negative_tags: Comma-separated tag names whose embedding is
                subtracted from the positive one. Names not present in the
                tag table are ignored.
            n_neighbours: Number of neighbours requested from the index.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<image_id>/<distance>"``. Neighbours for which no URL could
            be resolved are omitted, so the list may be shorter than
            ``n_neighbours``.
        """
        positive_tags_idxs = self._tag_indexes(self._parse_tags(positive_tags))
        negative_tags_idxs = self._tag_indexes(self._parse_tags(negative_tags))

        emb_from_logits = self._embed_tag_indexes(positive_tags_idxs)
        if negative_tags_idxs:
            neg_emb_from_logits = self._embed_tag_indexes(negative_tags_idxs)
            emb_from_logits = emb_from_logits - neg_emb_from_logits

        # Normalizes the query in place before the search.
        faiss.normalize_L2(emb_from_logits)

        dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)

        results = []
        for index, dist in zip(indexes[0], dists[0]):
            # FAISS pads missing neighbours with label -1; using that as an
            # array index would silently pick the last image id.
            if index < 0:
                continue
            image_id = int(self.images_ids[index])
            current_url = danbooru_id_to_url(image_id, list(self._ALL_RATINGS))
            if current_url is not None:
                results.append((current_url, f"{image_id}/{dist:.2f}"))
        return results


def main():
    """Build the Gradio tag-search demo around a Predictor and launch it."""
    predictor = Predictor()

    with gr.Blocks() as app:
        # Both tag inputs are created inside the same row.
        with gr.Row():
            positive_box = gr.Textbox(label="Positive tags")
            negative_box = gr.Textbox(label="Negative tags")

        search_button = gr.Button("Find similar images")

        results_gallery = gr.Gallery(label="Similar images", columns=[5])

        # A click passes both textbox values to predict and routes its
        # return value to the gallery.
        search_inputs = [positive_box, negative_box]
        search_outputs = [results_gallery]
        search_button.click(
            fn=predictor.predict,
            inputs=search_inputs,
            outputs=search_outputs,
        )

    app.queue()
    app.launch()


# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()