"""Gradio demo: find Danbooru images similar to a combination of example
images and tags, using CLIP embeddings and a FAISS k-NN index."""

import json

import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests
from imgutils.tagging import wd14

from Models.CLIP import CLIP


def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
    # Sum the positive contributions, then subtract the negative ones.
    pos = pos_img_embs + pos_tags_embs
    neg = neg_img_embs + neg_tags_embs
    result = pos - neg
    return result


def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    if api_username != "" and api_key != "":
        image_url = f"{image_url}?api_key={api_key}&login={api_username}"

    r = requests.get(image_url, headers=headers)
    if r.status_code != 200:
        return None

    content = json.loads(r.text)
    image_url = content["large_file_url"] if "large_file_url" in content else None
    image_url = image_url if content["rating"] in acceptable_ratings else None
    return image_url


class Predictor:
    def __init__(self):
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()

        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()
        self.tags_df = pd.read_csv("data/selected_tags.csv")

        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")
        with open("index/cosine_infos.json") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def predict(
        self,
        pos_img_input,
        neg_img_input,
        positive_tags,
        negative_tags,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        tags_df = self.tags_df
        model = self.model

        num_classes = len(tags_df)
        output_shape = model.out_units

        pos_img_embs = np.zeros((1, output_shape), dtype=np.float32)
        neg_img_embs = np.zeros((1, output_shape), dtype=np.float32)
        pos_tags_embs = np.zeros((1, output_shape), dtype=np.float32)
        neg_tags_embs = np.zeros((1, output_shape), dtype=np.float32)

        # Tolerate whitespace after commas in user-typed tag lists.
        positive_tags = [t.strip() for t in positive_tags.split(",")]
        negative_tags = [t.strip() for t in negative_tags.split(",")]

        positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
        negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()

        if pos_img_input is not None:
            pos_img_embs = wd14.get_wd14_tags(
                pos_img_input,
                model_name="ConvNext",
                fmt="embedding",
            )
            pos_img_embs = np.expand_dims(pos_img_embs, 0)
            faiss.normalize_L2(pos_img_embs)

        if neg_img_input is not None:
            neg_img_embs = wd14.get_wd14_tags(
                neg_img_input,
                model_name="ConvNext",
                fmt="embedding",
            )
            neg_img_embs = np.expand_dims(neg_img_embs, 0)
            faiss.normalize_L2(neg_img_embs)

        if len(positive_tags_idxs) > 0:
            # Encode the positive tag set as a multi-hot vector and embed it
            # with the CLIP text tower.
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0][positive_tags_idxs] = 1
            pos_tags_embs = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            pos_tags_embs = jax.device_get(pos_tags_embs)
            faiss.normalize_L2(pos_tags_embs)

        if len(negative_tags_idxs) > 0:
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0][negative_tags_idxs] = 1
            neg_tags_embs = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            neg_tags_embs = jax.device_get(neg_tags_embs)
            faiss.normalize_L2(neg_tags_embs)

        embeddings = combine_embeddings(
            pos_img_embs,
            pos_tags_embs,
            neg_img_embs,
            neg_tags_embs,
        )
        # gr.Slider can deliver floats; FAISS expects an integer k.
        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))
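        # `indexes` holds row positions inside the FAISS index, not Danbooru
        # post ids, so map them back through the ids array loaded in
        # __init__. The "cosine_*" file names suggest an inner-product index
        # over L2-normalised vectors, so `dists` are cosine-style scores.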
        neighbours_ids = self.images_ids[indexes][0]
        neighbours_ids = [int(x) for x in neighbours_ids]

        captions = []
        image_urls = []
        for image_id, dist in zip(neighbours_ids, dists[0]):
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                image_urls.append(current_url)
                captions.append(f"{image_id}/{dist:.2f}")
        return list(zip(image_urls, captions))


def main():
    predictor = Predictor()

    with gr.Blocks() as demo:
        with gr.Row():
            pos_img_input = gr.Image(type="pil", label="Positive input")
            neg_img_input = gr.Image(type="pil", label="Negative input")
        with gr.Row():
            with gr.Column():
                positive_tags = gr.Textbox(label="Positive tags")
                negative_tags = gr.Textbox(label="Negative tags")
            with gr.Column():
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                n_neighbours = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="# of images",
                )
            with gr.Column():
                api_username = gr.Textbox(label="Danbooru API Username")
                api_key = gr.Textbox(label="Danbooru API Key")
        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])
        gr.Examples(
            [
                [
                    None,
                    None,
                    "marcille_donato",
                    "",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    None,
                    None,
                    "yellow_eyes,red_horns",
                    "",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    None,
                    None,
                    "artoria_pendragon_(fate),solo",
                    "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/60378883_p0.jpg",
                    None,
                    "fujimaru_ritsuka_(female)",
                    "solo",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/DaRlExxUwAAcUOS-orig.jpg",
                    "examples/46657164_p1.jpg",
                    "",
                    "",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
            ],
            inputs=[
                pos_img_input,
                neg_img_input,
                positive_tags,
                negative_tags,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
            fn=predictor.predict,
            run_on_click=True,
            cache_examples=False,
        )
        find_btn.click(
            fn=predictor.predict,
            inputs=[
                pos_img_input,
                neg_img_input,
                positive_tags,
                negative_tags,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()
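# Programmatic usage sketch (not executed; the Gradio UI above is the real
# entry point). This assumes the data/ and index/ artifacts referenced in
# Predictor.__init__ are present and that Danbooru is reachable:
#
#   predictor = Predictor()
#   results = predictor.predict(
#       pos_img_input=None,
#       neg_img_input=None,
#       positive_tags="marcille_donato",
#       negative_tags="",
#       selected_ratings=["General", "Sensitive"],
#       n_neighbours=5,
#       api_username="",
#       api_key="",
#   )
#   # results is a list of (image_url, "post_id/score") tuples.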