import argparse
import functools
import json
from pathlib import Path

import faiss
import gradio as gr
import numpy as np
import PIL.Image
import requests
import tensorflow as tf
from huggingface_hub import hf_hub_download

from Utils import dbimutils

TITLE = "## Danbooru Explorer"
DESCRIPTION = """
Image similarity-based retrieval tool using:
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2) as feature extractor
- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing
"""

CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV_MODEL_REVISION = "v2.0"
CONV_FEXT_LAYER = "predictions_norm"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def download_model(model_repo, model_revision):
    # Fetch the TensorFlow SavedModel files from the Hugging Face Hub and
    # return the local directory they were downloaded to.
    model_files = [
        {"filename": "saved_model.pb", "subfolder": ""},
        {"filename": "keras_metadata.pb", "subfolder": ""},
        {"filename": "variables.index", "subfolder": "variables"},
        {"filename": "variables.data-00000-of-00001", "subfolder": "variables"},
    ]

    model_file_paths = []
    for elem in model_files:
        model_file_paths.append(
            Path(hf_hub_download(model_repo, revision=model_revision, **elem))
        )

    model_path = model_file_paths[0].parents[0]
    return model_path


def load_model(model_repo, model_revision, feature_extraction_layer):
    # Load the full tagger model, then cut it at the feature extraction layer.
    model_path = download_model(model_repo, model_revision)
    full_model = tf.keras.models.load_model(model_path)
    model = tf.keras.models.Model(
        full_model.inputs, full_model.get_layer(feature_extraction_layer).output
    )
    return model


def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    if api_username != "" and api_key != "":
        image_url = f"{image_url}?api_key={api_key}&login={api_username}"

    r = requests.get(image_url, headers=headers)
    if r.status_code != 200:
        return None

    content = json.loads(r.text)
    image_url = content["large_file_url"] if "large_file_url" in content else None
    image_url = image_url if content["rating"] in acceptable_ratings else None
    return image_url


class SimilaritySearcher:
    def __init__(self, model, images_ids):
        self.knn_index = None
        self.knn_metric = None

        self.model = model
        self.images_ids = images_ids

    def change_index(self, knn_metric):
        # Lazily (re)load the FAISS index for the requested metric and apply
        # the search parameters recommended by autofaiss.
        if knn_metric == self.knn_metric:
            return

        if knn_metric == "ip":
            self.knn_index = faiss.read_index("index/ip_knn.index")
            config = json.loads(open("index/ip_infos.json").read())["index_param"]
        elif knn_metric == "cosine":
            self.knn_index = faiss.read_index("index/cosine_knn.index")
            config = json.loads(open("index/cosine_infos.json").read())["index_param"]

        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
        self.knn_metric = knn_metric

    def predict(
        self, image, selected_ratings, knn_metric, api_username, api_key, n_neighbours
    ):
        _, height, width, _ = self.model.inputs[0].shape
        self.change_index(knn_metric)

        # Alpha to white
        image = image.convert("RGBA")
        new_image = PIL.Image.new("RGBA", image.size, "WHITE")
        new_image.paste(image, mask=image)
        image = new_image.convert("RGB")
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        # Pad to square and resize to the model's expected input size
        image = dbimutils.make_square(image, height)
        image = dbimutils.smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)
        target = self.model(image).numpy()

        # Cosine similarity is computed as inner product over L2-normalized vectors
        if self.knn_metric == "cosine":
            faiss.normalize_L2(target)

        dists, indexes = self.knn_index.search(target, k=n_neighbours)
        neighbours_ids = self.images_ids[indexes][0]
        neighbours_ids = [int(x) for x in neighbours_ids]

        captions = []
        for image_id, dist in zip(neighbours_ids, dists[0]):
            captions.append(f"{image_id}/{dist:.2f}")

        image_urls = []
        for image_id in neighbours_ids:
            current_url = danbooru_id_to_url(
                image_id, selected_ratings, api_username, api_key
            )
            if current_url is not None:
                image_urls.append(current_url)

        return list(zip(image_urls, captions))


def main():
    args = parse_args()

    model = load_model(CONV_MODEL_REPO, CONV_MODEL_REVISION, CONV_FEXT_LAYER)
    images_ids = np.load("index/cosine_ids.npy")

    searcher = SimilaritySearcher(model=model, images_ids=images_ids)

    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    api_key = gr.Textbox(label="Danbooru API Key")
                with gr.Row():
                    selected_ratings = gr.CheckboxGroup(
                        choices=["General", "Sensitive", "Questionable", "Explicit"],
                        value=["General", "Sensitive"],
                        label="Ratings",
                    )
                    selected_metric = gr.Radio(
                        choices=["cosine"],
                        value="cosine",
                        label="Metric selection",
                        visible=False,
                    )
                    n_neighbours = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1, label="# of images"
                    )
                find_btn = gr.Button("Find similar images")

        similar_images = gr.Gallery(label="Similar images")
        similar_images.style(grid=5)

        find_btn.click(
            fn=searcher.predict,
            inputs=[
                input,
                selected_ratings,
                selected_metric,
                api_username,
                api_key,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch(share=args.share)


if __name__ == "__main__":
    main()
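
# Note: the app expects a pre-built `index/` directory containing
# `cosine_knn.index`, `cosine_infos.json` and `cosine_ids.npy` (plus the
# optional `ip_*` variants read by SimilaritySearcher.change_index).
# Below is a minimal sketch of how such files could be produced with
# autofaiss from precomputed feature vectors; the `features.npy` and
# `ids.npy` file names are illustrative assumptions, not files shipped
# with this repository.
#
#     import faiss
#     import numpy as np
#     from autofaiss import build_index
#
#     embeddings = np.load("features.npy").astype(np.float32)  # (N, D) feature vectors
#     np.save("index/cosine_ids.npy", np.load("ids.npy"))      # matching Danbooru post ids
#     faiss.normalize_L2(embeddings)  # cosine similarity == inner product on unit vectors
#     build_index(
#         embeddings=embeddings,
#         index_path="index/cosine_knn.index",
#         index_infos_path="index/cosine_infos.json",
#         metric_type="ip",
#     )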