import json
import logging

import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests

from Models.CLIP import CLIP

logger = logging.getLogger(__name__)

# Human-readable rating names mapped to the single-letter codes found in the
# "rating" field of a Danbooru post.
_RATINGS_TO_LETTERS = {
    "General": "g",
    "Sensitive": "s",
    "Questionable": "q",
    "Explicit": "e",
}
_ALL_RATINGS = tuple(_RATINGS_TO_LETTERS)


def danbooru_id_to_url(
    image_id,
    selected_ratings,
    api_username="",
    api_key="",
    timeout=10.0,
):
    """Resolve a Danbooru post id to the URL of its large image file.

    Args:
        image_id: Id of the post to look up.
        selected_ratings: Iterable of rating names ("General", "Sensitive",
            "Questionable", "Explicit") that the caller accepts.
        api_username: Optional Danbooru login. Only sent together with
            ``api_key``.
        api_key: Optional Danbooru API key. Only sent together with
            ``api_username``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The post's ``large_file_url``, or ``None`` when the request fails,
        the response is not HTTP 200 or not valid JSON, the post has no
        ``large_file_url``, or its rating is not among ``selected_ratings``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    acceptable_ratings = {_RATINGS_TO_LETTERS[x] for x in selected_ratings}

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    # Let requests build the query string so the credentials are URL-escaped.
    params = None
    if api_username != "" and api_key != "":
        params = {"api_key": api_key, "login": api_username}

    try:
        r = requests.get(
            post_url,
            headers=headers,
            params=params,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # One unreachable post should not abort the whole result list; it is
        # handled like any other post that cannot be resolved.
        logger.warning("Danbooru request for post %s failed: %s", image_id, exc)
        return None

    if r.status_code != 200:
        return None

    try:
        content = json.loads(r.text)
    except ValueError:
        logger.warning("Danbooru returned invalid JSON for post %s", image_id)
        return None

    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")


class Predictor:
    """Find indexed images whose embedding is close to a tag query."""

    def __init__(self):
        """Load the model parameters, the tag list and the Faiss index.

        All paths are relative to the current working directory.
        """
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()

        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()

        self.tags_df = pd.read_csv("data/selected_tags.csv")

        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Use a context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def _tag_indices(self, tags_text):
        """Return the ``tags_df`` row indices for a comma-separated tag string.

        Surrounding whitespace is stripped and empty entries are dropped, so
        "a, b" matches both tags. Names absent from ``tags_df`` are ignored.
        """
        names = [name.strip() for name in (tags_text or "").split(",")]
        names = [name for name in names if name]
        tags_df = self.tags_df
        return tags_df.index[tags_df["name"].isin(names)].tolist()

    def _encode_tag_indices(self, tag_idxs):
        """Encode a multi-hot tag vector with the model's ``encode_text``.

        Returns a freshly allocated float32 NumPy array.
        """
        tags = np.zeros((1, len(self.tags_df)), dtype=np.float32)
        # Relies on tags_df having the default 0..n-1 index, so that index
        # labels coincide with column positions in the tag vector.
        tags[0][tag_idxs] = 1

        model = self.model
        emb = model.apply(
            {"params": self.params},
            tags,
            method=model.encode_text,
        )
        # np.array copies: faiss.normalize_L2 works in place and needs a
        # writable float32 buffer, which device_get is not guaranteed to
        # return (NOTE(review): confirm for the JAX version in use).
        return np.array(jax.device_get(emb), dtype=np.float32)

    def predict(self, positive_tags, negative_tags, n_neighbours=5):
        """Return gallery entries for the images closest to a tag query.

        Args:
            positive_tags: Comma-separated tag names to search for.
            negative_tags: Comma-separated tag names whose embedding is
                subtracted from the positive embedding.
            n_neighbours: Number of neighbours requested from the index.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<image id>/<distance>"``. Neighbours whose URL cannot be
            resolved are omitted, so the list may be shorter than
            ``n_neighbours``.
        """
        positive_tags_idxs = self._tag_indices(positive_tags)
        negative_tags_idxs = self._tag_indices(negative_tags)

        emb_from_logits = self._encode_tag_indices(positive_tags_idxs)
        if negative_tags_idxs:
            neg_emb_from_logits = self._encode_tag_indices(negative_tags_idxs)
            emb_from_logits = emb_from_logits - neg_emb_from_logits

        faiss.normalize_L2(emb_from_logits)

        dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)

        results = []
        for idx, dist in zip(indexes[0], dists[0]):
            # Faiss pads with label -1 when it finds fewer than k neighbours;
            # as a NumPy index that would silently select the last image id.
            if idx < 0:
                continue
            image_id = int(self.images_ids[idx])
            current_url = danbooru_id_to_url(image_id, _ALL_RATINGS)
            if current_url is not None:
                results.append((current_url, f"{image_id}/{dist:.2f}"))
        return results


def main():
    """Build the Gradio interface and launch it."""
    predictor = Predictor()

    with gr.Blocks() as demo:
        with gr.Row():
            positive_tags = gr.Textbox(label="Positive tags")
            negative_tags = gr.Textbox(label="Negative tags")
        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        find_btn.click(
            fn=predictor.predict,
            inputs=[positive_tags, negative_tags],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()