#!/usr/bin/env python
"""Gradio app that filters TADNE sample images by DeepDanbooru tag scores."""

from __future__ import annotations

import argparse
import functools
import io
import os
import pathlib
import tarfile

import gradio as gr
import numpy as np
import PIL.Image
from huggingface_hub import hf_hub_download

TITLE = 'TADNE (This Anime Does Not Exist) Image Selector'
DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/.

You can view images generated by the TADNE model with seed 0-99999.
You can filter images based on predictions by the [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) model.

The resolution of the output images in this app is 128x128, but you can check the original 512x512 images from URLs like https://thisanimedoesnotexist.ai/slider.html?seed=10000 using the output seeds.

Expected execution time on Hugging Face Spaces: 4s

Related Apps:
- [TADNE](https://huggingface.co/spaces/hysts/TADNE)
- [TADNE Image Viewer](https://huggingface.co/spaces/hysts/TADNE-image-viewer)
- [TADNE Interpolation](https://huggingface.co/spaces/hysts/TADNE-interpolation)
- [TADNE Image Search with DeepDanbooru](https://huggingface.co/spaces/hysts/TADNE-image-search-with-DeepDanbooru)
- [DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
'''
# NOTE(review): only the text "visitor badge" survives here; the original
# article most likely wrapped it in HTML markup (e.g. an <img> badge) that
# was lost. Restore the original markup if it is available.
ARTICLE = '\nvisitor badge\n'

# Hugging Face access token used for every hub download below.
TOKEN = os.environ['TOKEN']


def parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the Gradio app."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    return parser.parse_args()


def download_image_tarball(size: int, dirname: str) -> pathlib.Path:
    """Fetch the tarball of ``size``x``size`` sample images for ``dirname``.

    Returns:
        Local path of the downloaded file. ``hf_hub_download`` returns a
        ``str``; it is wrapped so the value matches the annotation.
    """
    path = hf_hub_download('hysts/TADNE-sample-images',
                           f'{size}/{dirname}.tar',
                           repo_type='dataset',
                           use_auth_token=TOKEN)
    return pathlib.Path(path)


def load_deepdanbooru_tag_dict() -> dict[str, int]:
    """Map each DeepDanbooru tag to its line number in ``tags.txt``.

    ``run`` uses these numbers as column indices into the prediction matrix,
    so every line (even a blank one) must keep its position.
    """
    path = hf_hub_download('hysts/DeepDanbooru',
                           'tags.txt',
                           use_auth_token=TOKEN)
    with open(path, encoding='utf-8') as f:
        tags = [line.strip() for line in f]
    return {tag: i for i, tag in enumerate(tags)}


def load_deepdanbooru_predictions(dirname: str) -> np.ndarray:
    """Load the precomputed DeepDanbooru scores for the images in ``dirname``.

    ``run`` indexes the result as ``[image_index, tag_index]``.
    """
    path = hf_hub_download(
        'hysts/TADNE-sample-images',
        f'prediction_results/deepdanbooru/{dirname}.npy',
        repo_type='dataset',
        use_auth_token=TOKEN)
    return np.load(path)


def _build_query_tags(
    general_tags: list[str],
    hair_color_tags: list[str],
    hair_style_tags: list[str],
    eye_color_tags: list[str],
    image_color_tags: list[str],
    other_tags: list[str],
    additional_tags: str,
) -> list[str]:
    """Merge the checkbox selections and the free-text tags into one list."""
    hair_colors = [f'{color}_hair' for color in hair_color_tags]
    eye_colors = [f'{color}_eyes' for color in eye_color_tags]
    # Accept "a, b" as well as "a,b" and drop empty entries, so an empty
    # textbox does not contribute a bogus '' tag.
    extra = [
        tag.strip() for tag in (additional_tags or '').split(',')
        if tag.strip()
    ]
    return [
        *general_tags,
        *hair_colors,
        *hair_style_tags,
        *eye_colors,
        *image_color_tags,
        *other_tags,
        *extra,
    ]


def _read_image(tar_file: tarfile.TarFile, member_name: str) -> np.ndarray:
    """Decode one image stored in ``tar_file`` into an RGB uint8 array."""
    member = tar_file.getmember(member_name)
    with tar_file.extractfile(member) as f:
        data = io.BytesIO(f.read())
    # Force three channels so every tile stacks into the (..., 3) grid even
    # if an individual JPEG was saved as grayscale.
    image = PIL.Image.open(data).convert('RGB')
    return np.asarray(image)


def run(
    general_tags: list[str],
    hair_color_tags: list[str],
    hair_style_tags: list[str],
    eye_color_tags: list[str],
    image_color_tags: list[str],
    other_tags: list[str],
    additional_tags: str,
    score_threshold: float,
    start_index: int,
    nrows: int,
    ncols: int,
    image_size: int,
    min_seed: int,
    max_seed: int,
    dirname: str,
    tarball_path: pathlib.Path,
    deepdanbooru_tag_dict: dict[str, int],
    deepdanbooru_predictions: np.ndarray,
) -> tuple[int, np.ndarray, np.ndarray, str]:
    """Select images whose tag scores all exceed ``score_threshold``.

    Tags unknown to DeepDanbooru are ignored for filtering and reported back.
    A page of ``nrows * ncols`` matches starting at ``start_index`` is tiled
    into one image. Slots past the end of the matches, or before index 0,
    are white tiles with a NaN seed.

    ``min_seed`` and ``max_seed`` are accepted for interface compatibility
    but are not used here.

    Returns:
        A tuple of:
        - the number of matching images;
        - the tiled RGB image, of shape
          ``(nrows * image_size, ncols * image_size, 3)``;
        - an ``(nrows, ncols)`` float array of seeds (the image index into
          the predictions; NaN for empty slots);
        - the comma-joined list of tags missing from the tag dictionary.
    """
    tags = _build_query_tags(general_tags, hair_color_tags, hair_style_tags,
                             eye_color_tags, image_color_tags, other_tags,
                             additional_tags)
    missing_tags = [tag for tag in tags if tag not in deepdanbooru_tag_dict]
    tag_indices = [
        deepdanbooru_tag_dict[tag] for tag in tags
        if tag in deepdanbooru_tag_dict
    ]

    # An image matches only if every requested tag scores above the
    # threshold. With no usable tags every image matches.
    conditions = deepdanbooru_predictions[:, tag_indices] > score_threshold
    image_indices = np.flatnonzero(conditions.all(axis=1))
    num_found = len(image_indices)

    start_index = int(start_index)
    nrows = int(nrows)
    ncols = int(ncols)

    seeds = []
    images = []
    blank = np.full((image_size, image_size, 3), 255, dtype=np.uint8)
    with tarfile.TarFile(tarball_path) as tar_file:
        for index in range(start_index, start_index + nrows * ncols):
            # Also reject negative positions. Otherwise numpy's negative
            # indexing would silently wrap around to the last matches.
            if not 0 <= index < num_found:
                seeds.append(np.nan)
                images.append(blank)
                continue
            image_index = image_indices[index]
            seeds.append(image_index)
            images.append(
                _read_image(tar_file, f'{dirname}/{image_index:07d}.jpg'))

    # (rows, cols, H, W, 3) -> (rows, H, cols, W, 3) -> one big mosaic.
    grid = np.asarray(images).reshape(
        nrows, ncols, image_size, image_size, 3).transpose(
            0, 2, 1, 3, 4).reshape(nrows * image_size, ncols * image_size, 3)
    seed_grid = np.asarray(seeds).reshape(nrows, ncols)
    return num_found, grid, seed_grid, ','.join(missing_tags)


def main():
    """Download the assets and launch the Gradio interface."""
    args = parse_args()

    image_size = 128
    min_seed = 0
    max_seed = 99999
    dirname = '0-99999'
    tarball_path = download_image_tarball(image_size, dirname)

    deepdanbooru_tag_dict = load_deepdanbooru_tag_dict()
    deepdanbooru_predictions = load_deepdanbooru_predictions(dirname)

    func = functools.partial(
        run,
        image_size=image_size,
        min_seed=min_seed,
        max_seed=max_seed,
        dirname=dirname,
        tarball_path=tarball_path,
        deepdanbooru_tag_dict=deepdanbooru_tag_dict,
        deepdanbooru_predictions=deepdanbooru_predictions,
    )
    # Copy run's name/docstring onto the partial for Gradio's display.
    func = functools.update_wrapper(func, run)

    gr.Interface(
        func,
        [
            gr.inputs.CheckboxGroup([
                '1girl',
                '1boy',
                'multiple_girls',
                'multiple_boys',
                'looking_at_viewer',
            ],
                                    label='General'),
            gr.inputs.CheckboxGroup([
                'aqua',
                'black',
                'blonde',
                'blue',
                'brown',
                'green',
                'grey',
                'orange',
                'pink',
                'purple',
                'red',
                'silver',
                'white',
            ],
                                    label='Hair Color'),
            gr.inputs.CheckboxGroup([
                'bangs',
                'curly_hair',
                'long_hair',
                'medium_hair',
                'messy_hair',
                'ponytail',
                'short_hair',
                'straight_hair',
                'twintails',
            ],
                                    label='Hair Style'),
            gr.inputs.CheckboxGroup([
                'aqua',
                'black',
                'blue',
                'brown',
                'green',
                'grey',
                'orange',
                'pink',
                'purple',
                'red',
                'white',
                'yellow',
            ],
                                    label='Eye Color'),
            gr.inputs.CheckboxGroup([
                'greyscale',
                'monochrome',
            ],
                                    label='Image Color'),
            gr.inputs.CheckboxGroup([
                'animal_ears',
                'closed_eyes',
                'full_body',
                'hat',
                'smile',
            ],
                                    label='Others'),
            gr.inputs.Textbox(label='Additional Tags'),
            gr.inputs.Slider(0,
                             1,
                             step=0.1,
                             default=0.5,
                             label='DeepDanbooru Score Threshold'),
            gr.inputs.Number(default=0, label='Start Index'),
            gr.inputs.Slider(1,
                             10,
                             step=1,
                             default=2,
                             label='Number of Rows'),
            gr.inputs.Slider(1,
                             10,
                             step=1,
                             default=5,
                             label='Number of Columns'),
        ],
        [
            gr.outputs.Textbox(type='number', label='Number of Found Images'),
            gr.outputs.Image(type='numpy', label='Output'),
            gr.outputs.Dataframe(type='numpy', label='Seed'),
            gr.outputs.Textbox(type='str', label='Missing Tags'),
        ],
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()