#!/usr/bin/env python from __future__ import annotations import io import pathlib import tarfile import deepdanbooru as dd import gradio as gr import huggingface_hub import numpy as np import PIL.Image import tensorflow as tf from huggingface_hub import hf_hub_download TITLE = "TADNE Image Search with DeepDanbooru" DESCRIPTION = """The original TADNE site is https://thisanimedoesnotexist.ai/. This app shows images similar to the query image from images generated by the TADNE model with seed 0-99999. Here, image similarity is measured by the L2 distance of the intermediate features by the [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) model. The resolution of the output images in this app is 128x128, but you can check the original 512x512 images from URLs like https://thisanimedoesnotexist.ai/slider.html?seed=10000 using the output seeds. Expected execution time on Hugging Face Spaces: 7s Related Apps: - [TADNE](https://huggingface.co/spaces/hysts/TADNE) - [TADNE Image Viewer](https://huggingface.co/spaces/hysts/TADNE-image-viewer) - [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector) - [TADNE Interpolation](https://huggingface.co/spaces/hysts/TADNE-interpolation) - [DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru) """ def load_deepdanbooru_predictions(dirname: str) -> np.ndarray: path = hf_hub_download( "hysts/TADNE-sample-images", f"prediction_results/deepdanbooru/intermediate_features/{dirname}.npy", repo_type="dataset", ) return np.load(path) def load_sample_image_paths() -> list[pathlib.Path]: image_dir = pathlib.Path("images") if not image_dir.exists(): dataset_repo = "hysts/sample-images-TADNE" path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset") with tarfile.open(path) as f: f.extractall() return sorted(image_dir.glob("*")) def create_model() -> tf.keras.Model: path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "model-resnet_custom_v3.h5") model = tf.keras.models.load_model(path) model = tf.keras.Model(model.input, model.layers[-4].output) layer = tf.keras.layers.GlobalAveragePooling2D() model = tf.keras.Sequential([model, layer]) return model image_size = 128 dirname = "0-99999" tarball_path = hf_hub_download("hysts/TADNE-sample-images", f"{image_size}/{dirname}.tar", repo_type="dataset") deepdanbooru_predictions = load_deepdanbooru_predictions(dirname) model = create_model() def predict(image: PIL.Image.Image) -> np.ndarray: _, height, width, _ = model.input_shape image = np.asarray(image) image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True) image = image.numpy() image = dd.image.transform_and_pad_image(image, width, height) image = image / 255.0 features = model.predict(image[None, ...])[0] features = features.astype(float) return features def run( image: PIL.Image.Image, nrows: int, ncols: int, ) -> tuple[np.ndarray, np.ndarray]: features = predict(image) distances = ((deepdanbooru_predictions - features) ** 2).sum(axis=1) image_indices = np.argsort(distances) seeds = [] images = [] with tarfile.TarFile(tarball_path) as tar_file: for index in range(nrows * ncols): image_index = image_indices[index] seeds.append(image_index) member = tar_file.getmember(f"{dirname}/{image_index:07d}.jpg") with tar_file.extractfile(member) as f: # type: ignore data = io.BytesIO(f.read()) image = PIL.Image.open(data) image = np.asarray(image) images.append(image) res = ( np.asarray(images) .reshape(nrows, ncols, image_size, image_size, 3) .transpose(0, 2, 1, 3, 4) .reshape(nrows * image_size, ncols * image_size, 3) ) seeds = np.asarray(seeds).reshape(nrows, ncols) return res, seeds image_paths = load_sample_image_paths() examples = [[path.as_posix(), 2, 5] for path in image_paths] demo = gr.Interface( fn=run, inputs=[ gr.Image(label="Input", type="pil"), gr.Slider(label="Number of Rows", minimum=1, maximum=10, step=1, value=2), gr.Slider(label="Number of Columns", minimum=1, maximum=10, step=1, value=2), ], outputs=[ gr.Image(label="Output"), gr.Dataframe(label="Seed"), ], examples=examples, title=TITLE, description=DESCRIPTION, ) if __name__ == "__main__": demo.queue().launch()