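"""Tag-based image similarity search over Danbooru posts.

Encodes comma-separated tags with the CLIP text encoder bundled with
wd-v1-4-convnext-tagger-v2, searches a FAISS index of image embeddings,
and shows the nearest Danbooru posts in a Gradio gallery.
"""
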
import json

import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests

from Models.CLIP import CLIP


def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    """Resolve a Danbooru post ID to its large image URL, filtered by rating."""
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    if api_username != "" and api_key != "":
        image_url = f"{image_url}?api_key={api_key}&login={api_username}"

    r = requests.get(image_url, headers=headers)
    if r.status_code != 200:
        return None

    content = json.loads(r.text)
    image_url = content["large_file_url"] if "large_file_url" in content else None
    image_url = image_url if content["rating"] in acceptable_ratings else None
    return image_url


class Predictor:
    def __init__(self):
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        # Restore the CLIP parameters from the msgpack checkpoint.
        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()
        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()

        # Tag vocabulary and FAISS index of image embeddings;
        # cosine_ids.npy maps index rows back to Danbooru post IDs.
        self.tags_df = pd.read_csv("data/selected_tags.csv")
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")

        with open("index/cosine_infos.json") as f:
            config = json.loads(f.read())["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def predict(self, positive_tags, negative_tags, n_neighbours=5):
        tags_df = self.tags_df
        model = self.model
        num_classes = len(tags_df)

        # Map the comma-separated tag names to indices in the tag vocabulary.
        positive_tags = positive_tags.split(",")
        negative_tags = negative_tags.split(",")
        positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
        negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()

        # Encode the positive tags as a multi-hot vector through the text encoder.
        tags = np.zeros((1, num_classes), dtype=np.float32)
        tags[0][positive_tags_idxs] = 1
        emb_from_logits = model.apply(
            {"params": self.params},
            tags,
            method=model.encode_text,
        )
        emb_from_logits = jax.device_get(emb_from_logits)

        # If negative tags were given, encode them too and subtract their
        # embedding from the positive one.
        if len(negative_tags_idxs) > 0:
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0][negative_tags_idxs] = 1
            neg_emb_from_logits = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            neg_emb_from_logits = jax.device_get(neg_emb_from_logits)
            emb_from_logits = emb_from_logits - neg_emb_from_logits

        # Cosine similarity search: L2-normalize, then query the kNN index.
        faiss.normalize_L2(emb_from_logits)
        dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
        neighbours_ids = self.images_ids[indexes][0]
        neighbours_ids = [int(x) for x in neighbours_ids]

        # Resolve each neighbour to an image URL; skip posts that cannot be
        # fetched or whose rating is filtered out.
        captions = []
        image_urls = []
        for image_id, dist in zip(neighbours_ids, dists[0]):
            current_url = danbooru_id_to_url(
                image_id,
                [
                    "General",
                    "Sensitive",
                    "Questionable",
                    "Explicit",
                ],
            )
            if current_url is not None:
                image_urls.append(current_url)
                captions.append(f"{image_id}/{dist:.2f}")
        return list(zip(image_urls, captions))
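

# Example usage outside Gradio (a sketch; assumes the data/ and index/ files
# shipped with this Space are available locally):
#
#   predictor = Predictor()
#   for url, caption in predictor.predict("1girl,solo", "", n_neighbours=5):
#       print(caption, url)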


def main():
    predictor = Predictor()

    with gr.Blocks() as demo:
        with gr.Row():
            positive_tags = gr.Textbox(label="Positive tags")
            negative_tags = gr.Textbox(label="Negative tags")
        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        find_btn.click(
            fn=predictor.predict,
            inputs=[positive_tags, negative_tags],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()