import json

import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests
from imgutils.tagging import wd14

from Models.CLIP import CLIP


def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
    # Sum the positive image and tag embeddings, then L2-normalize in place.
    pos = pos_img_embs + pos_tags_embs
    faiss.normalize_L2(pos)

    # Same for the negative side.
    neg = neg_img_embs + neg_tags_embs
    faiss.normalize_L2(neg)

    # Push the query away from the negatives and renormalize so the query
    # is a unit vector, as expected by the cosine index.
    result = pos - neg
    faiss.normalize_L2(result)
    return result


def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    params = {}
    if api_username != "" and api_key != "":
        params = {"api_key": api_key, "login": api_username}

    # Use a timeout so a stalled request does not hang the UI.
    r = requests.get(image_url, headers=headers, params=params, timeout=10)
    if r.status_code != 200:
        return None

    content = r.json()
    # Skip posts whose rating was not selected in the UI.
    if content.get("rating") not in acceptable_ratings:
        return None
    # The API may omit large_file_url for some posts; return None then.
    return content.get("large_file_url")


class Predictor:
    def __init__(self):
        self.loaded_variant = None
        self.base_model = "wd-v1-4-convnext-tagger-v2"
        self.model = CLIP()
        self.tags_df = pd.read_csv("data/selected_tags.csv")
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")
        with open("index/cosine_infos.json") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def load_params(self, variant):
        # Weights are cached; reload only when the selected variant changes.
        if self.loaded_variant == variant:
            return

        if variant == "CLIP":
            path = f"data/{self.base_model}/clip.msgpack"
        elif variant == "SigLIP":
            path = f"data/{self.base_model}/siglip.msgpack"
        else:
            raise ValueError(f"Unknown model variant: {variant}")

        with open(path, "rb") as f:
            data = f.read()

        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.loaded_variant = variant

    def predict(
        self,
        pos_img_input,
        neg_img_input,
        positive_tags,
        negative_tags,
        selected_model,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        """Return (image_url, caption) pairs for the similar-images gallery."""
        tags_df = self.tags_df
        model = self.model

        self.load_params(selected_model)

        num_classes = len(tags_df)

        # Every component starts as a zero vector, so a missing input
        # contributes nothing to the combined query.
        output_shape = model.out_units
        pos_img_embs = np.zeros((1, output_shape), dtype=np.float32)
        neg_img_embs = np.zeros((1, output_shape), dtype=np.float32)
        pos_tags_embs = np.zeros((1, output_shape), dtype=np.float32)
        neg_tags_embs = np.zeros((1, output_shape), dtype=np.float32)

        # Tags are comma-separated; strip whitespace so "a, b" also matches.
        positive_tags = [t.strip() for t in positive_tags.split(",")]
        negative_tags = [t.strip() for t in negative_tags.split(",")]

        positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
        negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()

        if pos_img_input is not None:
            pos_img_embs = wd14.get_wd14_tags(
                pos_img_input,
                model_name="ConvNext",
                fmt="embedding",
            )
            pos_img_embs = np.expand_dims(pos_img_embs, 0)
            faiss.normalize_L2(pos_img_embs)

        if neg_img_input is not None:
            neg_img_embs = wd14.get_wd14_tags(
                neg_img_input,
                model_name="ConvNext",
                fmt="embedding",
            )
            neg_img_embs = np.expand_dims(neg_img_embs, 0)
            faiss.normalize_L2(neg_img_embs)

        if len(positive_tags_idxs) > 0:
            # Multi-hot vector over the tag vocabulary.
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0, positive_tags_idxs] = 1
            pos_tags_embs = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            pos_tags_embs = jax.device_get(pos_tags_embs)
            faiss.normalize_L2(pos_tags_embs)

        if len(negative_tags_idxs) > 0:
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0, negative_tags_idxs] = 1
            neg_tags_embs = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            neg_tags_embs = jax.device_get(neg_tags_embs)
            faiss.normalize_L2(neg_tags_embs)

        embeddings = combine_embeddings(
            pos_img_embs,
            pos_tags_embs,
            neg_img_embs,
            neg_tags_embs,
        )

        # faiss requires an integer k.
        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))
        # faiss pads missing results with -1; drop them before the id lookup.
        valid = indexes[0] >= 0
        neighbours_ids = [int(x) for x in self.images_ids[indexes[0][valid]]]
        neighbours_dists = dists[0][valid]

        captions = []
        image_urls = []
        for image_id, dist in zip(neighbours_ids, neighbours_dists):
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            # Posts filtered out by rating, or without a usable URL, are skipped.
            if current_url is not None:
                image_urls.append(current_url)
                captions.append(f"{image_id}/{dist:.2f}")
        return list(zip(image_urls, captions))


def main():
    predictor = Predictor()

    with gr.Blocks() as demo:
        with gr.Row():
            pos_img_input = gr.Image(type="pil", label="Positive input")
            neg_img_input = gr.Image(type="pil", label="Negative input")
        with gr.Row():
            with gr.Column():
                positive_tags = gr.Textbox(label="Positive tags")
                negative_tags = gr.Textbox(label="Negative tags")
            with gr.Column():
                selected_model = gr.Radio(
                    choices=["CLIP", "SigLIP"],
                    value="CLIP",
                    label="Model",
                )
                n_neighbours = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="# of images",
                )
            with gr.Column():
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    api_key = gr.Textbox(label="Danbooru API Key")

        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        # Shared by the examples and the button, in predict()'s argument order.
        inputs = [
            pos_img_input,
            neg_img_input,
            positive_tags,
            negative_tags,
            selected_model,
            selected_ratings,
            n_neighbours,
            api_username,
            api_key,
        ]

        gr.Examples(
            [
                [None, None, "marcille_donato", "", "CLIP", ["General", "Sensitive"], 5, "", ""],
                [None, None, "yellow_eyes,red_horns", "", "CLIP", ["General", "Sensitive"], 5, "", ""],
                [
                    None,
                    None,
                    "artoria_pendragon_(fate),solo",
                    "green_eyes",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/60378883_p0.jpg",
                    None,
                    "fujimaru_ritsuka_(female)",
                    "solo",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/DaRlExxUwAAcUOS-orig.jpg",
                    "examples/46657164_p1.jpg",
                    "",
                    "",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
            ],
            inputs=inputs,
            outputs=[similar_images],
            fn=predictor.predict,
            run_on_click=True,
            cache_examples=False,
        )

        find_btn.click(
            fn=predictor.predict,
            inputs=inputs,
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()