|
import json |
|
|
|
import faiss |
|
import flax |
|
import gradio as gr |
|
import jax |
|
import numpy as np |
|
import pandas as pd |
|
import requests |
|
|
|
from Models.CLIP import CLIP |
|
|
|
|
|
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    """Resolve a Danbooru post id to the URL of its large image file.

    Args:
        image_id: Id of the Danbooru post to look up.
        selected_ratings: Iterable of rating names to accept. Each name must
            be one of "General", "Sensitive", "Questionable" or "Explicit".
        api_username: Optional Danbooru login. Credentials are sent only when
            both ``api_username`` and ``api_key`` are non-empty.
        api_key: Optional Danbooru API key (see ``api_username``).

    Returns:
        The post's ``large_file_url``, or ``None`` when no rating is
        selected, the request fails or does not answer with HTTP 200, the
        post exposes no ``large_file_url``, or the post's rating is not among
        ``selected_ratings``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    acceptable_ratings = {ratings_to_letters[x] for x in selected_ratings}
    if not acceptable_ratings:
        # Nothing could pass the rating filter, so skip the network round trip.
        return None

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"

    # Let requests build the query string so the credentials are URL-encoded.
    params = {}
    if api_username and api_key:
        params = {"api_key": api_key, "login": api_username}

    try:
        # A timeout keeps one stalled lookup from hanging the caller forever.
        response = requests.get(
            post_url,
            headers=headers,
            params=params,
            timeout=10,
        )
    except requests.RequestException:
        # Treat transport failures like any other failed lookup.
        return None

    if response.status_code != 200:
        return None

    content = response.json()
    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
|
|
|
|
|
class Predictor:
    """Tag-driven image search over a prebuilt faiss kNN index.

    Comma-separated tag names are turned into a multi-hot vector over the
    rows of ``data/selected_tags.csv``, embedded through the CLIP model's
    ``encode_text`` method, and matched against the index loaded from
    ``index/``. Hits are resolved to Danbooru image URLs.
    """

    # Rating names accepted when resolving neighbours to URLs (all of them).
    _ALLOWED_RATINGS = ("General", "Sensitive", "Questionable", "Explicit")

    def __init__(self):
        """Load the model parameters, tag table and kNN index from disk."""
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()

        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()

        self.tags_df = pd.read_csv("data/selected_tags.csv")

        # Maps a position in the kNN index to an image id.
        self.images_ids = np.load("index/cosine_ids.npy")

        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Use a context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def _tag_indexes(self, tags):
        """Return the ``tags_df`` row indexes named in a comma-separated string."""
        # Strip whitespace so "a, b" matches "b"; drop empty entries.
        names = [name.strip() for name in (tags or "").split(",")]
        names = [name for name in names if name]
        tags_df = self.tags_df
        return tags_df.index[tags_df["name"].isin(names)].tolist()

    def _encode_tags(self, tag_idxs):
        """Embed the multi-hot vector that has ones at ``tag_idxs``."""
        multi_hot = np.zeros((1, len(self.tags_df)), dtype=np.float32)
        multi_hot[0][tag_idxs] = 1
        embedding = self.model.apply(
            {"params": self.params},
            multi_hot,
            method=self.model.encode_text,
        )
        return jax.device_get(embedding)

    def predict(self, positive_tags, negative_tags, n_neighbours=5):
        """Find images whose embedding is close to the requested tags.

        Args:
            positive_tags: Comma-separated tag names to search for. Names not
                present in the tag table are ignored.
            negative_tags: Comma-separated tag names to steer away from; their
                embedding is subtracted from the positive one.
            n_neighbours: Number of neighbours requested from the index.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<image id>/<score>"`` with the score as returned by the index.
            Neighbours whose URL cannot be resolved are omitted.
        """
        positive_idxs = self._tag_indexes(positive_tags)
        negative_idxs = self._tag_indexes(negative_tags)

        query = self._encode_tags(positive_idxs)
        if negative_idxs:
            query = query - self._encode_tags(negative_idxs)

        # normalize_L2 works in place on the raw buffer; make sure we own a
        # writable, contiguous float32 copy before handing it over.
        query = np.array(query, dtype=np.float32, order="C")
        faiss.normalize_L2(query)

        dists, indexes = self.knn_index.search(query, k=n_neighbours)

        results = []
        for index, dist in zip(indexes[0], dists[0]):
            if index < 0:
                # faiss pads with -1 when it finds fewer than k neighbours;
                # indexing images_ids with it would pick the last id.
                continue
            image_id = int(self.images_ids[index])
            image_url = danbooru_id_to_url(image_id, self._ALLOWED_RATINGS)
            if image_url is not None:
                results.append((image_url, f"{image_id}/{dist:.2f}"))
        return results
|
|
|
|
|
def main():
    """Build the Gradio tag-search interface and start serving it."""
    engine = Predictor()

    with gr.Blocks() as app:
        # The two query boxes share one row.
        with gr.Row():
            include_box = gr.Textbox(label="Positive tags")
            exclude_box = gr.Textbox(label="Negative tags")

        search_button = gr.Button("Find similar images")

        gallery = gr.Gallery(label="Similar images", columns=[5])

        # Clicking the button runs the search and fills the gallery.
        search_button.click(
            fn=engine.predict,
            inputs=[include_box, exclude_box],
            outputs=[gallery],
        )

    app.queue()
    app.launch()
|
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|