|
import json |
|
|
|
import faiss |
|
import flax |
|
import gradio as gr |
|
import jax |
|
import numpy as np |
|
import pandas as pd |
|
import requests |
|
from imgutils.tagging import wd14 |
|
|
|
from Models.CLIP import CLIP |
|
|
|
|
|
def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
    """Fuse the positive and negative query embeddings into one search vector.

    For each polarity the image and tag embeddings are added together and the
    sum is L2-normalised. The negative result is then subtracted from the
    positive one and the difference is L2-normalised once more.

    Args:
        pos_img_embs: Embedding of the positive image.
        pos_tags_embs: Embedding of the positive tags.
        neg_img_embs: Embedding of the negative image.
        neg_tags_embs: Embedding of the negative tags.

    Returns:
        The normalised ``positive - negative`` array.
    """

    def _normalised_sum(img_embs, tags_embs):
        # The sum is a fresh array, so the in-place normalisation below never
        # touches the caller's inputs.
        total = img_embs + tags_embs
        faiss.normalize_L2(total)
        return total

    positive = _normalised_sum(pos_img_embs, pos_tags_embs)
    negative = _normalised_sum(neg_img_embs, neg_tags_embs)

    direction = positive - negative
    faiss.normalize_L2(direction)
    return direction
|
|
|
|
|
def danbooru_id_to_url(
    image_id,
    selected_ratings,
    api_username="",
    api_key="",
    timeout=10,
):
    """Resolve a Danbooru post id to its image URL, filtered by rating.

    The lookup is best-effort: any failure to obtain a usable answer yields
    ``None`` instead of an exception, so one bad post cannot abort a whole
    batch of lookups.

    Args:
        image_id: Id of the Danbooru post to look up.
        selected_ratings: Rating labels the caller accepts; each must be one
            of "General", "Sensitive", "Questionable" or "Explicit".
        api_username: Danbooru login. Credentials are only sent when both
            ``api_username`` and ``api_key`` are non-empty.
        api_key: Danbooru API key (see ``api_username``).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The post's ``large_file_url``, or ``None`` when the request fails or
        times out, the status is not 200, the body is not a JSON object, the
        post has no ``large_file_url``, or its rating is not accepted.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown label.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"

    # Let requests build the query string so that credentials containing
    # reserved characters are percent-encoded correctly.
    params = None
    if api_username != "" and api_key != "":
        params = {"api_key": api_key, "login": api_username}

    try:
        r = requests.get(post_url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException:
        # Connection errors and timeouts are treated like a non-200 answer.
        return None

    if r.status_code != 200:
        return None

    try:
        content = r.json()
    except ValueError:
        return None

    if not isinstance(content, dict):
        return None

    # ``.get`` keeps posts that lack a field from raising KeyError; a missing
    # rating simply never matches an accepted one.
    if content.get("rating") not in acceptable_ratings:
        return None

    return content.get("large_file_url")
|
|
|
|
|
class Predictor:
    """Nearest-neighbour image search driven by image and tag queries.

    On construction the tag encoder (``CLIP``), the tag table
    (``data/selected_tags.csv``) and the faiss index with its id table
    (``index/cosine_knn.index`` / ``index/cosine_ids.npy``) are loaded.
    Encoder weights are loaded lazily by ``load_params``.
    """

    # Weight file, relative to ``data/<base_model>/``, for each supported
    # tag-embedding variant.
    _PARAMS_FILES = {
        "CLIP": "clip.msgpack",
        "SigLIP": "siglip.msgpack",
    }

    def __init__(self):
        """Load the encoder, the tag table and the search index from disk."""
        self.loaded_variant = None
        self.params = None
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        self.model = CLIP()

        self.tags_df = pd.read_csv("data/selected_tags.csv")

        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Use a context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def load_params(self, variant):
        """Load the encoder weights for ``variant`` unless already loaded.

        Args:
            variant: Name of the tag-embedding model, "CLIP" or "SigLIP".

        Raises:
            ValueError: If ``variant`` is not a supported model name.
        """
        filename = self._PARAMS_FILES.get(variant)
        if filename is None:
            supported = ", ".join(sorted(self._PARAMS_FILES))
            raise ValueError(
                f"Unknown tags embedding model {variant!r}; expected one of: {supported}"
            )

        if self.loaded_variant == variant:
            return

        with open(f"data/{self.base_model}/{filename}", "rb") as f:
            data = f.read()

        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.loaded_variant = variant

    @staticmethod
    def _split_tags(raw_tags):
        """Split a comma-separated tag string into clean, non-empty names."""
        if not raw_tags:
            return []
        return [name for name in (part.strip() for part in raw_tags.split(",")) if name]

    def _zero_embedding(self):
        """Return the all-zero embedding used when a query part is absent."""
        return np.zeros((1, self.model.out_units), dtype=np.float32)

    def _embed_image(self, image):
        """Return the normalised embedding of ``image``, or zeros for ``None``."""
        if image is None:
            return self._zero_embedding()

        embs = wd14.get_wd14_tags(
            image,
            model_name="ConvNext",
            fmt="embedding",
        )
        embs = np.expand_dims(embs, 0)
        faiss.normalize_L2(embs)
        return embs

    def _embed_tags(self, raw_tags):
        """Return the normalised embedding of the known tags in ``raw_tags``.

        Tags that are not in the tag table are ignored; if none remain the
        all-zero embedding is returned.
        """
        tags_df = self.tags_df
        names = self._split_tags(raw_tags)
        tag_idxs = tags_df.index[tags_df["name"].isin(names)].tolist()
        if not tag_idxs:
            return self._zero_embedding()

        # Multi-hot vector over the tag table: 1 for every requested tag.
        tags = np.zeros((1, len(tags_df)), dtype=np.float32)
        tags[0][tag_idxs] = 1

        embs = self.model.apply(
            {"params": self.params},
            tags,
            method=self.model.encode_text,
        )
        # normalize_L2 works in place, so hand it a private float32 copy.
        # NOTE(review): the array returned by jax.device_get may be a
        # read-only view of JAX-owned memory -- confirm.
        embs = np.array(jax.device_get(embs), dtype=np.float32)
        faiss.normalize_L2(embs)
        return embs

    def predict(
        self,
        pos_img_input,
        neg_img_input,
        positive_tags,
        negative_tags,
        selected_model,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        """Find the indexed images closest to the combined query.

        Args:
            pos_img_input: Image to search towards, or ``None``.
            neg_img_input: Image to search away from, or ``None``.
            positive_tags: Comma-separated tags to search towards.
            negative_tags: Comma-separated tags to search away from.
            selected_model: Tag-embedding model name, "CLIP" or "SigLIP".
            selected_ratings: Rating labels accepted in the results.
            n_neighbours: Number of neighbours to request from the index.
            api_username: Danbooru login, forwarded to the URL lookup.
            api_key: Danbooru API key, forwarded to the URL lookup.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<post id>/<distance>"``. Neighbours whose URL cannot be
            resolved or whose rating is not accepted are left out, so the
            list may be shorter than ``n_neighbours``.

        Raises:
            ValueError: If ``selected_model`` is not a supported model name.
        """
        self.load_params(selected_model)

        pos_img_embs = self._embed_image(pos_img_input)
        neg_img_embs = self._embed_image(neg_img_input)
        pos_tags_embs = self._embed_tags(positive_tags)
        neg_tags_embs = self._embed_tags(negative_tags)

        embeddings = combine_embeddings(
            pos_img_embs,
            pos_tags_embs,
            neg_img_embs,
            neg_tags_embs,
        )

        # The UI slider may deliver a float; faiss needs an integer k.
        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))

        results = []
        for index, dist in zip(indexes[0], dists[0]):
            if index < 0:
                # faiss pads missing neighbours with -1, which would
                # otherwise silently select the last entry of images_ids.
                continue

            image_id = int(self.images_ids[index])
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                results.append((current_url, f"{image_id}/{dist:.2f}"))
        return results
|
|
|
|
|
def main():
    """Assemble the Gradio interface around a ``Predictor`` and launch it."""
    predictor = Predictor()

    def example_row(pos_image, neg_image, pos_tags, neg_tags):
        # All examples share the same model, ratings, result count and blank
        # API credentials; only the images and tags differ.
        return [
            pos_image,
            neg_image,
            pos_tags,
            neg_tags,
            "CLIP",
            ["General", "Sensitive"],
            5,
            "",
            "",
        ]

    with gr.Blocks() as demo:
        # Query images.
        with gr.Row():
            pos_img_input = gr.Image(type="pil", label="Positive input")
            neg_img_input = gr.Image(type="pil", label="Negative input")

        # Query tags on the left, result filters on the right.
        with gr.Row():
            with gr.Column():
                positive_tags = gr.Textbox(label="Positive tags")
                negative_tags = gr.Textbox(label="Negative tags")
                selected_model = gr.Radio(
                    choices=["CLIP", "SigLIP"],
                    value="CLIP",
                    label="Tags embedding model",
                )
            with gr.Column():
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                n_neighbours = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="# of images",
                )

        # Optional Danbooru credentials.
        with gr.Row():
            api_username = gr.Textbox(label="Danbooru API Username")
            api_key = gr.Textbox(label="Danbooru API Key")

        find_btn = gr.Button("Find similar images")

        similar_images = gr.Gallery(label="Similar images", columns=[5])

        # Same component order as the parameters of ``Predictor.predict``.
        query_inputs = [
            pos_img_input,
            neg_img_input,
            positive_tags,
            negative_tags,
            selected_model,
            selected_ratings,
            n_neighbours,
            api_username,
            api_key,
        ]

        gr.Examples(
            [
                example_row(None, None, "marcille_donato", ""),
                example_row(None, None, "yellow_eyes,red_horns", ""),
                example_row(None, None, "artoria_pendragon_(fate),solo", "green_eyes"),
                example_row(
                    "examples/60378883_p0.jpg",
                    None,
                    "fujimaru_ritsuka_(female)",
                    "solo",
                ),
                example_row(
                    "examples/DaRlExxUwAAcUOS-orig.jpg",
                    "examples/46657164_p1.jpg",
                    "",
                    "",
                ),
            ],
            inputs=query_inputs,
            outputs=[similar_images],
            fn=predictor.predict,
            run_on_click=True,
            cache_examples=False,
        )

        find_btn.click(
            fn=predictor.predict,
            inputs=query_inputs,
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()
|
|
|
|
|
if __name__ == "__main__":
    # Start the demo only when this file is executed as a script, not when it
    # is imported as a module.
    main()
|
|