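# NOTE: the lines below are page residue from the file viewer this source was
# captured from (author, commit message, commit hash, viewer links, file size).
# They are not part of the program and are kept as comments so the module parses.
# SmilingWolf's picture
# Update CLIP-style model
# 74bd9c8
# raw
# history blame
# No virus
# 9.61 kB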
import json
import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests
from imgutils.tagging import wd14
from Models.CLIP import CLIP
def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
    """Merge positive and negative image/tag embeddings into one query vector.

    The image and tag embeddings of each polarity are summed and L2-normalized,
    the normalized negative sum is subtracted from the normalized positive sum,
    and the difference is L2-normalized again before being returned.

    Normalization is done in place by ``faiss.normalize_L2`` on the freshly
    created sums, so the four input arrays themselves are not modified here.
    """
    positive = pos_img_embs + pos_tags_embs
    negative = neg_img_embs + neg_tags_embs
    for side in (positive, negative):
        faiss.normalize_L2(side)

    query = positive - negative
    faiss.normalize_L2(query)
    return query
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    """Resolve a Danbooru post id to its large-file URL, filtered by rating.

    Args:
        image_id: Post id, interpolated into the ``/posts/{id}.json`` endpoint.
        selected_ratings: Iterable of acceptable rating names ("General",
            "Sensitive", "Questionable", "Explicit").
        api_username: Optional Danbooru login; only sent when ``api_key`` is
            also non-empty.
        api_key: Optional Danbooru API key; only sent when ``api_username`` is
            also non-empty.

    Returns:
        The post's ``large_file_url``, or ``None`` when the request fails, the
        response is not HTTP 200 or not valid JSON, the post has no
        ``large_file_url``, or its rating is not among ``selected_ratings``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = {ratings_to_letters[x] for x in selected_ratings}

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"

    # Let requests build the query string so the credentials get URL-encoded.
    params = None
    if api_username != "" and api_key != "":
        params = {"api_key": api_key, "login": api_username}

    try:
        # Without a timeout a single stalled request would block the search.
        r = requests.get(post_url, headers=headers, params=params, timeout=10)
    except requests.RequestException:
        # Treat network failures like any other unavailable post.
        return None
    if r.status_code != 200:
        return None

    try:
        content = r.json()
    except ValueError:
        return None

    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
class Predictor:
    """Turn image/tag queries into embeddings and search a FAISS index with them.

    ``__init__`` loads everything from fixed relative paths: the tag table
    (``data/selected_tags.csv``), the array mapping index rows to image ids
    (``index/cosine_ids.npy``), the FAISS index (``index/cosine_knn.index``)
    and its search parameters (``index/cosine_infos.json``). Text-encoder
    parameters are loaded lazily by ``load_params``.
    """

    # Serialized parameter file for each supported text-encoder variant.
    _VARIANT_FILES = {"CLIP": "clip.msgpack", "SigLIP": "siglip.msgpack"}

    def __init__(self):
        self.loaded_variant = None
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        self.model = CLIP()
        self.tags_df = pd.read_csv("data/selected_tags.csv")

        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Use a context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def load_params(self, variant: str) -> None:
        """Load the parameters for ``variant`` unless they are already loaded.

        Args:
            variant: Either ``"CLIP"`` or ``"SigLIP"``.

        Raises:
            ValueError: If ``variant`` is not a supported variant name.
        """
        if self.loaded_variant == variant:
            return

        try:
            filename = self._VARIANT_FILES[variant]
        except KeyError:
            supported = ", ".join(sorted(self._VARIANT_FILES))
            raise ValueError(
                f"Unknown model variant {variant!r}; expected one of: {supported}"
            ) from None

        with open(f"data/{self.base_model}/{filename}", "rb") as f:
            data = f.read()

        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.loaded_variant = variant

    def _tag_indexes(self, tags_text) -> list:
        """Return the ``tags_df`` row labels whose name appears in ``tags_text``."""
        # Strip whitespace so "a, b" matches the same tags as "a,b".
        names = [name.strip() for name in (tags_text or "").split(",")]
        names = [name for name in names if name]
        tags_df = self.tags_df
        return tags_df.index[tags_df["name"].isin(names)].tolist()

    @staticmethod
    def _embed_image(image):
        """Return the L2-normalized embedding of ``image`` with a leading batch axis."""
        # NOTE(review): assumes get_wd14_tags(fmt="embedding") yields a 1-D
        # float32 vector, as faiss.normalize_L2 needs float32 - confirm.
        embs = wd14.get_wd14_tags(image, model_name="ConvNext", fmt="embedding")
        embs = np.expand_dims(embs, 0)
        faiss.normalize_L2(embs)
        return embs

    def _embed_tags(self, tag_idxs):
        """Return the L2-normalized text embedding of a multi-hot tag vector."""
        tags = np.zeros((1, len(self.tags_df)), dtype=np.float32)
        tags[0][tag_idxs] = 1
        embs = self.model.apply(
            {"params": self.params},
            tags,
            method=self.model.encode_text,
        )
        embs = jax.device_get(embs)
        faiss.normalize_L2(embs)
        return embs

    def predict(
        self,
        pos_img_input,
        neg_img_input,
        positive_tags,
        negative_tags,
        selected_model,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        """Find images similar to the positive inputs and dissimilar to the negative ones.

        Args:
            pos_img_input: Image to search towards, or ``None``.
            neg_img_input: Image to search away from, or ``None``.
            positive_tags: Comma-separated tag names to search towards.
            negative_tags: Comma-separated tag names to search away from.
            selected_model: Variant name passed to ``load_params``.
            selected_ratings: Rating names passed to ``danbooru_id_to_url``.
            n_neighbours: Number of nearest neighbours requested from the index.
            api_username: Danbooru login passed to ``danbooru_id_to_url``.
            api_key: Danbooru API key passed to ``danbooru_id_to_url``.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<image id>/<distance>"``. Neighbours whose URL cannot be
            resolved are skipped. The list is empty when no image is given
            and no tag matches the tag table.
        """
        self.load_params(selected_model)

        positive_tags_idxs = self._tag_indexes(positive_tags)
        negative_tags_idxs = self._tag_indexes(negative_tags)

        has_query = (
            pos_img_input is not None
            or neg_img_input is not None
            or bool(positive_tags_idxs)
            or bool(negative_tags_idxs)
        )
        if not has_query:
            # An all-zero query carries no information; skip the search and
            # the per-result HTTP lookups it would trigger.
            return []

        output_shape = self.model.out_units

        def empty_embedding():
            # Absent inputs contribute a zero vector to the combination.
            return np.zeros((1, output_shape), dtype=np.float32)

        pos_img_embs = (
            self._embed_image(pos_img_input)
            if pos_img_input is not None
            else empty_embedding()
        )
        neg_img_embs = (
            self._embed_image(neg_img_input)
            if neg_img_input is not None
            else empty_embedding()
        )
        pos_tags_embs = (
            self._embed_tags(positive_tags_idxs)
            if positive_tags_idxs
            else empty_embedding()
        )
        neg_tags_embs = (
            self._embed_tags(negative_tags_idxs)
            if negative_tags_idxs
            else empty_embedding()
        )

        embeddings = combine_embeddings(
            pos_img_embs,
            pos_tags_embs,
            neg_img_embs,
            neg_tags_embs,
        )

        # The UI slider may deliver a float; FAISS needs an integer k.
        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))

        results = []
        for index, dist in zip(indexes[0], dists[0]):
            if index < 0:
                # FAISS pads with -1 when it finds fewer than k neighbours;
                # indexing images_ids with it would return the wrong post.
                continue
            image_id = int(self.images_ids[index])
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                results.append((current_url, f"{image_id}/{dist:.2f}"))
        return results
def main():
    """Build the Gradio interface around a ``Predictor`` and launch it."""
    predictor = Predictor()

    # NOTE(review): the row/column nesting below is inferred from the order in
    # which the widgets are created - confirm against the intended layout.
    with gr.Blocks() as demo:
        with gr.Row():
            pos_img_input = gr.Image(type="pil", label="Positive input")
            neg_img_input = gr.Image(type="pil", label="Negative input")

        with gr.Row():
            with gr.Column():
                positive_tags = gr.Textbox(label="Positive tags")
                negative_tags = gr.Textbox(label="Negative tags")
            with gr.Column():
                selected_model = gr.Radio(
                    choices=["CLIP", "SigLIP"],
                    value="CLIP",
                    label="Model",
                )
                n_neighbours = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="# of images",
                )
            with gr.Column():
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    # Mask the key so the secret is not shown in clear text.
                    api_key = gr.Textbox(label="Danbooru API Key", type="password")

        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        # Shared by the examples and the button so both always call
        # predictor.predict with the same argument order.
        query_inputs = [
            pos_img_input,
            neg_img_input,
            positive_tags,
            negative_tags,
            selected_model,
            selected_ratings,
            n_neighbours,
            api_username,
            api_key,
        ]

        # Each entry: positive image, negative image, positive tags, negative tags.
        example_queries = [
            [None, None, "marcille_donato", ""],
            [None, None, "yellow_eyes,red_horns", ""],
            [None, None, "artoria_pendragon_(fate),solo", "green_eyes"],
            ["examples/60378883_p0.jpg", None, "fujimaru_ritsuka_(female)", "solo"],
            [
                "examples/DaRlExxUwAAcUOS-orig.jpg",
                "examples/46657164_p1.jpg",
                "",
                "",
            ],
        ]
        gr.Examples(
            # Every example uses the same model, ratings, neighbour count and
            # empty credentials; the tail is rebuilt per row so no list is shared.
            [
                [*query, "CLIP", ["General", "Sensitive"], 5, "", ""]
                for query in example_queries
            ],
            inputs=query_inputs,
            outputs=[similar_images],
            fn=predictor.predict,
            run_on_click=True,
            cache_examples=False,
        )

        find_btn.click(
            fn=predictor.predict,
            inputs=query_inputs,
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()