|
import argparse |
|
import functools |
|
import json |
|
from pathlib import Path |
|
|
|
import faiss |
|
import gradio as gr |
|
import numpy as np |
|
import PIL.Image |
|
import requests |
|
import tensorflow as tf |
|
from huggingface_hub import hf_hub_download |
|
|
|
from Utils import dbimutils |
|
|
|
# Markdown heading shown at the top of the page.
TITLE = "## Danbooru Explorer"

# Markdown blurb shown under the heading. Assembled from adjacent literals so
# every rendered line sits on its own source line.
DESCRIPTION = (
    "\n"
    "Image similarity-based retrieval tool using:\n"
    "- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2) as feature extractor\n"
    "- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing\n"
)

# Hugging Face Hub repository and revision of the tagger model, plus the name
# of the layer whose output is used as the image embedding.
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV_MODEL_REVISION = "v2.0"
CONV_FEXT_LAYER = "predictions_norm"
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the app.

    Returns:
        A namespace with a single boolean attribute ``share``, set by the
        ``--share`` flag (False when the flag is absent).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
def download_model(model_repo, model_revision):
    """Download the model files from the Hugging Face Hub.

    Args:
        model_repo: Hub repository ID to download from.
        model_revision: Revision of the repository to fetch.

    Returns:
        A ``Path`` to the directory containing the downloaded
        ``saved_model.pb``.
    """
    # (filename, subfolder) pairs, requested in this order.
    wanted_files = (
        ("saved_model.pb", ""),
        ("keras_metadata.pb", ""),
        ("variables.index", "variables"),
        ("variables.data-00000-of-00001", "variables"),
    )

    local_paths = [
        Path(
            hf_hub_download(
                model_repo,
                revision=model_revision,
                filename=filename,
                subfolder=subfolder,
            )
        )
        for filename, subfolder in wanted_files
    ]

    # saved_model.pb is the first entry; its parent directory is the model root.
    return local_paths[0].parent
|
|
|
|
|
def load_model(model_repo, model_revision, feature_extraction_layer):
    """Load the Keras model and truncate it at the feature-extraction layer.

    Args:
        model_repo: Hub repository ID of the model.
        model_revision: Revision of the repository to fetch.
        feature_extraction_layer: Name of the layer whose output becomes the
            output of the returned model.

    Returns:
        A Keras model sharing the full model's inputs and emitting the output
        of ``feature_extraction_layer``.
    """
    saved_model_dir = download_model(model_repo, model_revision)
    tagger = tf.keras.models.load_model(saved_model_dir)
    embedding_output = tagger.get_layer(feature_extraction_layer).output
    return tf.keras.models.Model(tagger.inputs, embedding_output)
|
|
|
|
|
def danbooru_id_to_url(
    image_id, selected_ratings, api_username="", api_key="", timeout=10
):
    """Resolve a Danbooru post ID to its large image URL, applying a rating filter.

    Args:
        image_id: Danbooru post ID.
        selected_ratings: Iterable of allowed rating names; each must be one
            of "General", "Sensitive", "Questionable" or "Explicit".
        api_username: Danbooru login. Credentials are only sent when both
            ``api_username`` and ``api_key`` are non-empty.
        api_key: Danbooru API key.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The post's ``large_file_url``, or ``None`` when the request fails,
        the response is not HTTP 200 or not a JSON object, the post has no
        ``large_file_url``, or its rating is not among ``selected_ratings``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    acceptable_ratings = {ratings_to_letters[x] for x in selected_ratings}

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"

    # Let requests build the query string so the credentials are
    # percent-encoded instead of being pasted verbatim into the URL.
    params = {}
    if api_username and api_key:
        params = {"api_key": api_key, "login": api_username}

    # A single unreachable or slow post must not abort the whole search, so
    # transport errors are treated like any other "no URL available" case.
    try:
        r = requests.get(post_url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException:
        return None

    if r.status_code != 200:
        return None

    try:
        content = r.json()
    except ValueError:
        return None
    if not isinstance(content, dict):
        return None

    # .get() avoids a KeyError on posts whose payload lacks these fields.
    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
|
|
|
|
|
class SimilaritySearcher:
    """Embed a query image and look up its nearest neighbours in a Faiss index.

    Attributes are documented inline in ``__init__``.
    """

    # Metric name -> (Faiss index file, JSON file whose "index_param" entry is
    # applied to the index after loading).
    _INDEX_FILES = {
        "ip": ("index/ip_knn.index", "index/ip_infos.json"),
        "cosine": ("index/cosine_knn.index", "index/cosine_infos.json"),
    }

    def __init__(self, model, images_ids):
        """Store the embedding model and the index-position -> image-ID table.

        Args:
            model: Callable Keras model producing the embedding of an image
                batch.
            images_ids: Array indexed by the positions returned by the Faiss
                search; each entry is the ID of the indexed image.
        """
        # The index is loaded lazily by change_index() on the first search.
        self.knn_index = None
        self.knn_metric = None

        self.model = model
        self.images_ids = images_ids

    def change_index(self, knn_metric):
        """Load the index for ``knn_metric`` unless it is already active.

        Args:
            knn_metric: Either "ip" or "cosine".

        Raises:
            ValueError: If ``knn_metric`` is not a known metric.
        """
        if knn_metric == self.knn_metric:
            return

        try:
            index_path, infos_path = self._INDEX_FILES[knn_metric]
        except KeyError:
            known = ", ".join(sorted(self._INDEX_FILES))
            raise ValueError(
                f"Unknown knn_metric {knn_metric!r}; expected one of: {known}"
            ) from None

        knn_index = faiss.read_index(index_path)
        with open(infos_path, encoding="utf-8") as infos_file:
            config = json.load(infos_file)["index_param"]
        faiss.ParameterSpace().set_index_parameters(knn_index, config)

        # Publish the new state only once everything loaded, so a failure
        # above leaves the previously active index untouched.
        self.knn_index = knn_index
        self.knn_metric = knn_metric

    @staticmethod
    def _preprocess(image, size):
        """Convert a PIL image into a float32 BGR batch of one image."""
        # Flatten any transparency onto a white background.
        rgba = image.convert("RGBA")
        canvas = PIL.Image.new("RGBA", rgba.size, "WHITE")
        canvas.paste(rgba, mask=rgba)
        pixels = np.asarray(canvas.convert("RGB"))

        # Reverse the channel axis: RGB -> BGR.
        pixels = pixels[:, :, ::-1]

        # NOTE(review): presumably pads to a square and resizes to
        # size x size; the dbimutils helpers are not visible here -- confirm.
        pixels = dbimutils.make_square(pixels, size)
        pixels = dbimutils.smart_resize(pixels, size)
        pixels = pixels.astype(np.float32)
        return np.expand_dims(pixels, 0)

    def predict(
        self, image, selected_ratings, knn_metric, api_username, api_key, n_neighbours
    ):
        """Find indexed images similar to ``image``.

        Args:
            image: Query PIL image, or ``None`` when no image was supplied.
            selected_ratings: Rating names accepted in the results.
            knn_metric: Metric of the index to search ("ip" or "cosine").
            api_username: Danbooru login forwarded to the URL lookup.
            api_key: Danbooru API key forwarded to the URL lookup.
            n_neighbours: Number of neighbours to request from the index.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<image_id>/<distance>"`` for that same image. Neighbours for
            which no URL could be resolved are omitted.
        """
        # Nothing to search for without a query image.
        if image is None:
            return []

        _, height, _, _ = self.model.inputs[0].shape

        self.change_index(knn_metric)

        target = self.model(self._preprocess(image, height)).numpy()

        if self.knn_metric == "cosine":
            faiss.normalize_L2(target)

        # UI sliders may deliver the count as a float; Faiss needs an int.
        dists, indexes = self.knn_index.search(target, k=int(n_neighbours))

        # Build each (url, caption) pair together so that skipping a
        # neighbour can never shift captions onto the wrong image.
        results = []
        for index, dist in zip(indexes[0], dists[0]):
            # Faiss pads the labels with -1 when it finds fewer than k
            # neighbours; -1 would otherwise select the last image ID.
            if index < 0:
                continue
            image_id = int(self.images_ids[index])
            image_url = danbooru_id_to_url(
                image_id, selected_ratings, api_username, api_key
            )
            if image_url is not None:
                results.append((image_url, f"{image_id}/{dist:.2f}"))
        return results
|
|
|
|
|
def main():
    """Build the Gradio interface around a SimilaritySearcher and launch it."""
    cli_args = parse_args()
    embedding_model = load_model(
        CONV_MODEL_REPO, CONV_MODEL_REVISION, CONV_FEXT_LAYER
    )
    indexed_ids = np.load("index/cosine_ids.npy")

    searcher = SimilaritySearcher(model=embedding_model, images_ids=indexed_ids)

    # NOTE(review): the nesting of the Row/Column containers below was
    # reconstructed from a copy of this file that had lost its indentation --
    # confirm the layout renders as intended.
    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            query_image = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    api_key = gr.Textbox(label="Danbooru API Key")
                with gr.Row():
                    selected_ratings = gr.CheckboxGroup(
                        choices=["General", "Sensitive", "Questionable", "Explicit"],
                        value=["General", "Sensitive"],
                        label="Ratings",
                    )
                    # Only one metric is offered, so the selector stays hidden.
                    selected_metric = gr.Radio(
                        choices=["cosine"],
                        value="cosine",
                        label="Metric selection",
                        visible=False,
                    )
                n_neighbours = gr.Slider(
                    minimum=1, maximum=20, value=5, step=1, label="# of images"
                )
                find_btn = gr.Button("Find similar images")
        gallery = gr.Gallery(label="Similar images")

        gallery.style(grid=5)

        # Order must match the positional parameters of searcher.predict.
        predict_inputs = [
            query_image,
            selected_ratings,
            selected_metric,
            api_username,
            api_key,
            n_neighbours,
        ]
        find_btn.click(
            fn=searcher.predict,
            inputs=predict_inputs,
            outputs=[gallery],
        )

    demo.queue()
    demo.launch(share=cli_args.share)
|
|
|
|
|
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|