import json

import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests
from imgutils.tagging import wd14

from Models.CLIP import CLIP

def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
    # Sum the positive embeddings, subtract the negative ones: positives pull
    # the query vector toward them, negatives push it away.
    pos = pos_img_embs + pos_tags_embs
    neg = neg_img_embs + neg_tags_embs
    result = pos - neg
    return result
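
# Illustrative sketch (not part of the app flow): a query built from a single
# positive image reduces to that image's embedding, since the unused terms are
# zero vectors. `img_embs` is a hypothetical (1, dim) float32 array here:
#
#   zeros = np.zeros_like(img_embs)
#   query = combine_embeddings(img_embs, zeros, zeros, zeros)  # == img_embs
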
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    if api_username != "" and api_key != "":
        image_url = f"{image_url}?api_key={api_key}&login={api_username}"

    r = requests.get(image_url, headers=headers)
    if r.status_code != 200:
        return None

    # Keep the post only if it exposes a full-size URL and its rating is
    # among the ones the user selected.
    content = json.loads(r.text)
    image_url = content["large_file_url"] if "large_file_url" in content else None
    image_url = image_url if content["rating"] in acceptable_ratings else None
    return image_url
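
# Example (hypothetical post id): danbooru_id_to_url(123456, ["General"])
# returns the post's "large_file_url" when the post is rated "g", and None
# when the request fails, the URL is missing, or the rating is filtered out.
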
class Predictor:
    def __init__(self):
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        # Restore the trained CLIP parameters from the msgpack checkpoint.
        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()
        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()

        self.tags_df = pd.read_csv("data/selected_tags.csv")
        self.images_ids = np.load("index/cosine_ids.npy")

        # Load the FAISS index and apply its recommended search parameters.
        self.knn_index = faiss.read_index("index/cosine_knn.index")
        with open("index/cosine_infos.json") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
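
    # Sketch of a direct index lookup, bypassing predict() (the query shape is
    # hypothetical; the real dimensionality comes from CLIP's out_units):
    #
    #   q = np.ones((1, predictor.model.out_units), dtype=np.float32)
    #   faiss.normalize_L2(q)
    #   dists, idxs = predictor.knn_index.search(q, k=5)
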
    def predict(
        self,
        pos_img_input,
        neg_img_input,
        positive_tags,
        negative_tags,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        tags_df = self.tags_df
        model = self.model

        num_classes = len(tags_df)
        output_shape = model.out_units

        # Inputs left empty contribute zero vectors to the combined query.
        pos_img_embs = np.zeros((1, output_shape), dtype=np.float32)
        neg_img_embs = np.zeros((1, output_shape), dtype=np.float32)
        pos_tags_embs = np.zeros((1, output_shape), dtype=np.float32)
        neg_tags_embs = np.zeros((1, output_shape), dtype=np.float32)

        positive_tags = positive_tags.split(",")
        negative_tags = negative_tags.split(",")

        # Map tag names to their row indices in selected_tags.csv.
        positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
        negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()

        if pos_img_input is not None:
            # Embed the positive image with the WD14 ConvNext tagger.
            pos_img_embs = wd14.get_wd14_tags(
                pos_img_input,
                model_name="ConvNext",
                fmt="embedding",
            )
            pos_img_embs = np.expand_dims(pos_img_embs, 0)
            faiss.normalize_L2(pos_img_embs)

        if neg_img_input is not None:
            neg_img_embs = wd14.get_wd14_tags(
                neg_img_input,
                model_name="ConvNext",
                fmt="embedding",
            )
            neg_img_embs = np.expand_dims(neg_img_embs, 0)
            faiss.normalize_L2(neg_img_embs)

        if len(positive_tags_idxs) > 0:
            # Encode the positive tags as a multi-hot vector through the
            # CLIP text tower.
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0][positive_tags_idxs] = 1
            pos_tags_embs = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            pos_tags_embs = jax.device_get(pos_tags_embs)
            faiss.normalize_L2(pos_tags_embs)

        if len(negative_tags_idxs) > 0:
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0][negative_tags_idxs] = 1
            neg_tags_embs = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            neg_tags_embs = jax.device_get(neg_tags_embs)
            faiss.normalize_L2(neg_tags_embs)
        embeddings = combine_embeddings(
            pos_img_embs,
            pos_tags_embs,
            neg_img_embs,
            neg_tags_embs,
        )

        # k-nearest-neighbour search, then map index rows back to Danbooru ids.
        dists, indexes = self.knn_index.search(embeddings, k=n_neighbours)
        neighbours_ids = self.images_ids[indexes][0]
        neighbours_ids = [int(x) for x in neighbours_ids]

        captions = []
        image_urls = []
        for image_id, dist in zip(neighbours_ids, dists[0]):
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                image_urls.append(current_url)
                captions.append(f"{image_id}/{dist:.2f}")
        return list(zip(image_urls, captions))
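
# Usage sketch (hypothetical tag values): a tags-only query, no images.
#
#   predictor = Predictor()
#   results = predictor.predict(None, None, "1girl,solo", "", ["General"], 5, "", "")
#   # -> list of (image_url, "id/distance") pairs, ready for a gr.Gallery
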
def main():
    predictor = Predictor()

    with gr.Blocks() as demo:
        with gr.Row():
            pos_img_input = gr.Image(type="pil", label="Positive input")
            neg_img_input = gr.Image(type="pil", label="Negative input")
        with gr.Row():
            with gr.Column():
                positive_tags = gr.Textbox(label="Positive tags")
                negative_tags = gr.Textbox(label="Negative tags")
            with gr.Column():
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                n_neighbours = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="# of images",
                )
            with gr.Column():
                api_username = gr.Textbox(label="Danbooru API Username")
                api_key = gr.Textbox(label="Danbooru API Key")
        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])
        examples = gr.Examples(
            [
                [
                    None,
                    None,
                    "marcille_donato",
                    "",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    None,
                    None,
                    "yellow_eyes,red_horns",
                    "",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    None,
                    None,
                    "artoria_pendragon_(fate),solo",
                    "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/60378883_p0.jpg",
                    None,
                    "fujimaru_ritsuka_(female)",
                    "solo",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/DaRlExxUwAAcUOS-orig.jpg",
                    "examples/46657164_p1.jpg",
                    "",
                    "",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
            ],
            inputs=[
                pos_img_input,
                neg_img_input,
                positive_tags,
                negative_tags,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
            fn=predictor.predict,
            run_on_click=True,
            cache_examples=False,
        )

        find_btn.click(
            fn=predictor.predict,
            inputs=[
                pos_img_input,
                neg_img_input,
                positive_tags,
                negative_tags,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()
if __name__ == "__main__":
    main()