File size: 4,280 Bytes
86f89b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e6f7d7
 
 
 
 
 
 
 
 
86f89b5
9e6f7d7
86f89b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb346e5
86f89b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python

from __future__ import annotations

import argparse
import functools
import io
import os
import pathlib
import tarfile

import gradio as gr
import numpy as np
import PIL.Image
from huggingface_hub import hf_hub_download

# Header text for the Gradio interface (passed to gr.Interface as `title`).
TITLE = 'TADNE (This Anime Does Not Exist) Image Viewer'
# Markdown shown as the interface `description` in main().
DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/.

You can view images generated by the TADNE model with seed 0-99999.
The original images are 512x512 in size, but they are resized to 128x128 here.

Expected execution time on Hugging Face Spaces: 4s

Related Apps:
- [TADNE](https://huggingface.co/spaces/hysts/TADNE)
- [TADNE Image Viewer](https://huggingface.co/spaces/hysts/TADNE-image-viewer)
- [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector)
- [TADNE Interpolation](https://huggingface.co/spaces/hysts/TADNE-interpolation)
- [TADNE Image Search with DeepDanbooru](https://huggingface.co/spaces/hysts/TADNE-image-search-with-DeepDanbooru)
'''
# HTML visitor-counter badge passed to gr.Interface as `article`.
ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.tadne-image-viewer" alt="visitor badge"/></center>'

# Hugging Face access token forwarded to hf_hub_download (use_auth_token)
# when fetching the sample-image dataset. Read with os.environ[...], so
# importing this module raises KeyError if TOKEN is not set.
TOKEN = os.environ['TOKEN']


def parse_args() -> argparse.Namespace:
    """Read the launch options for the Gradio viewer from ``sys.argv``.

    Returns:
        Namespace with attributes ``theme`` (str or None), ``live`` and
        ``share`` (bool flags), ``port`` (int or None), ``enable_queue``
        (True unless ``--disable-queue`` is given) and ``allow_flagging``
        (str, ``'never'`` by default).
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword options) in the order they are registered.
    option_specs = (
        ('--theme', dict(type=str)),
        ('--live', dict(action='store_true')),
        ('--share', dict(action='store_true')),
        ('--port', dict(type=int)),
        ('--disable-queue', dict(dest='enable_queue', action='store_false')),
        ('--allow-flagging', dict(type=str, default='never')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    namespace = cli.parse_args()
    return namespace


def download_image_tarball(size: int, dirname: str) -> pathlib.Path:
    """Fetch one tarball of sample images from the Hugging Face Hub.

    Downloads ``{size}/{dirname}.tar`` from the ``hysts/TADNE-sample-images``
    dataset repository, authenticating with the module-level ``TOKEN``.

    Args:
        size: Image side length; selects the top-level folder in the repo.
        dirname: Base name of the tarball inside that folder
            (``main`` uses ``'0-99999'``).

    Returns:
        Local filesystem path of the downloaded tarball.
    """
    path = hf_hub_download('hysts/TADNE-sample-images',
                           f'{size}/{dirname}.tar',
                           repo_type='dataset',
                           use_auth_token=TOKEN)
    # hf_hub_download returns a plain str; wrap it so the returned value
    # actually matches the annotated pathlib.Path return type.
    return pathlib.Path(path)


def run(start_seed: int, nrows: int, ncols: int, image_size: int,
        min_seed: int, max_seed: int, dirname: str,
        tarball_path: pathlib.Path) -> np.ndarray:
    """Tile the pre-generated images for consecutive seeds into one grid.

    Seeds ``start_seed`` .. ``start_seed + nrows * ncols - 1`` are placed
    row-major (left to right, then top to bottom). A seed inside
    ``[min_seed, max_seed]`` is read from ``{dirname}/{seed:07d}.jpg`` in
    the tarball. Seeds outside that range are drawn as white tiles.

    Args:
        start_seed: First seed to show; truncated with ``int`` because the
            UI's Number input may deliver a float.
        nrows: Number of grid rows (cast to ``int``).
        ncols: Number of grid columns (cast to ``int``).
        image_size: Side length in pixels of every square tile.
        min_seed: Smallest seed stored in the tarball.
        max_seed: Largest seed stored in the tarball.
        dirname: Directory inside the tarball holding the JPEG files.
        tarball_path: Path to the tar archive.

    Returns:
        A ``uint8`` array of shape
        ``(nrows * image_size, ncols * image_size, 3)``.

    Raises:
        KeyError: If a seed inside ``[min_seed, max_seed]`` has no matching
            member in the tarball (raised by ``TarFile.getmember``).
    """
    start_seed = int(start_seed)
    # Slider values may arrive as floats; range() and reshape() need ints.
    nrows = int(nrows)
    ncols = int(ncols)
    num = nrows * ncols
    blank = np.full((image_size, image_size, 3), 255, dtype=np.uint8)
    images = []
    with tarfile.open(tarball_path) as tar_file:
        for seed in range(start_seed, start_seed + num):
            if not min_seed <= seed <= max_seed:
                images.append(blank)
                continue
            member = tar_file.getmember(f'{dirname}/{seed:07d}.jpg')
            with tar_file.extractfile(member) as f:
                data = io.BytesIO(f.read())
            with PIL.Image.open(data) as image:
                # Force 3 channels so every tile fits the (H, W, 3) layout
                # the reshape below depends on.
                images.append(np.asarray(image.convert('RGB')))
    grid = np.asarray(images).reshape(nrows, ncols, image_size, image_size,
                                      3)
    # (row, col, y, x, c) -> (row, y, col, x, c): flattening then places
    # tiles side by side within a row and stacks rows vertically.
    return grid.transpose(0, 2, 1, 3, 4).reshape(nrows * image_size,
                                                 ncols * image_size, 3)


def main():
    """Start the Gradio viewer for the 128x128 TADNE sample images.

    Parses the CLI options and downloads the seed 0-99999 tarball once.
    It then binds the fixed dataset settings into ``run`` and launches an
    interface whose three inputs (start seed, rows, columns) are forwarded
    to it.
    """
    args = parse_args()

    image_size = 128
    min_seed, max_seed = 0, 99999
    dirname = '0-99999'
    tarball_path = download_image_tarball(image_size, dirname)

    # Freeze every argument except the three UI inputs. partial objects have
    # no __name__/__doc__ of their own, so copy run's onto it.
    viewer = functools.update_wrapper(
        functools.partial(run,
                          image_size=image_size,
                          min_seed=min_seed,
                          max_seed=max_seed,
                          dirname=dirname,
                          tarball_path=tarball_path),
        run,
    )

    input_components = [
        gr.inputs.Number(default=0, label='Start Seed'),
        gr.inputs.Slider(1, 10, step=1, default=2, label='Number of Rows'),
        gr.inputs.Slider(1, 10, step=1, default=5,
                         label='Number of Columns'),
    ]
    output_component = gr.outputs.Image(type='numpy', label='Output')

    interface = gr.Interface(
        viewer,
        input_components,
        output_component,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    interface.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    # Launch the app only when executed as a script, not when imported.
    main()