hysts HF staff commited on
Commit
d38fcf4
1 Parent(s): ccdd7cc
Files changed (2) hide show
  1. app.py +223 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import io
8
+ import os
9
+ import pathlib
10
+ import tarfile
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import PIL.Image
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ TITLE = 'TADNE (This Anime Does Not Exist) Image Selector'
18
+ DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/.
19
+
20
+ You can view images generated by the TADNE model with seed 0-99999.
21
+ You can filter images based on predictions by the [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) model.
22
+ The original images are 512x512 in size, but here they are resized to 128x128.
23
+
24
+ Known issues:
25
+ - The `Seed` table in the output doesn't refresh properly in gradio 2.9.1. https://github.com/gradio-app/gradio/issues/921
26
+ '''
27
+ ARTICLE = None
28
+
29
+ TOKEN = os.environ['TOKEN']
30
+
31
+
32
+ def parse_args() -> argparse.Namespace:
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument('--theme', type=str)
35
+ parser.add_argument('--live', action='store_true')
36
+ parser.add_argument('--share', action='store_true')
37
+ parser.add_argument('--port', type=int)
38
+ parser.add_argument('--disable-queue',
39
+ dest='enable_queue',
40
+ action='store_false')
41
+ parser.add_argument('--allow-flagging', type=str, default='never')
42
+ parser.add_argument('--allow-screenshot', action='store_true')
43
+ return parser.parse_args()
44
+
45
+
46
+ def download_image_tarball(size: int, dirname: str) -> pathlib.Path:
47
+ path = hf_hub_download('hysts/TADNE-sample-images',
48
+ f'{size}/{dirname}.tar',
49
+ repo_type='dataset',
50
+ use_auth_token=TOKEN)
51
+ return path
52
+
53
+
54
+ def load_deepdanbooru_tag_dict() -> dict[str, int]:
55
+ path = hf_hub_download('hysts/DeepDanbooru',
56
+ 'tags.txt',
57
+ use_auth_token=TOKEN)
58
+ with open(path) as f:
59
+ tags = [line.strip() for line in f.readlines()]
60
+ return {tag: i for i, tag in enumerate(tags)}
61
+
62
+
63
+ def load_deepdanbooru_predictions(dirname: str) -> np.ndarray:
64
+ path = hf_hub_download('hysts/TADNE-sample-images',
65
+ f'prediction_results/deepdanbooru/{dirname}.npy',
66
+ repo_type='dataset',
67
+ use_auth_token=TOKEN)
68
+ return np.load(path)
69
+
70
+
71
+ def run(
72
+ general_tags: list[str],
73
+ hair_color_tags: list[str],
74
+ hair_style_tags: list[str],
75
+ image_color_tags: list[str],
76
+ score_threshold: float,
77
+ start_index: int,
78
+ nrows: int,
79
+ ncols: int,
80
+ image_size: int,
81
+ min_seed: int,
82
+ max_seed: int,
83
+ dirname: str,
84
+ tarball_path: pathlib.Path,
85
+ deepdanbooru_tag_dict: dict[str, int],
86
+ deepdanbooru_predictions: np.ndarray,
87
+ ) -> np.ndarray:
88
+ hair_color_tags = [f'{color}_hair' for color in hair_color_tags]
89
+
90
+ tags = general_tags + hair_color_tags + hair_style_tags + image_color_tags
91
+ tag_indices = [deepdanbooru_tag_dict[tag] for tag in tags]
92
+
93
+ conditions = deepdanbooru_predictions[:, tag_indices] > score_threshold
94
+ image_indices = np.arange(len(deepdanbooru_predictions))
95
+ image_indices = image_indices[conditions.all(axis=1)]
96
+
97
+ start_index = int(start_index)
98
+ num = nrows * ncols
99
+ seeds = []
100
+ images = []
101
+ dummy = np.ones((image_size, image_size, 3), dtype=np.uint8) * 255
102
+ with tarfile.TarFile(tarball_path) as tar_file:
103
+ for index in range(start_index, start_index + num):
104
+ if index >= len(image_indices):
105
+ seeds.append(-1)
106
+ images.append(dummy)
107
+ continue
108
+ image_index = image_indices[index]
109
+ seeds.append(image_index)
110
+ member = tar_file.getmember(f'{dirname}/{image_index:07d}.jpg')
111
+ with tar_file.extractfile(member) as f:
112
+ data = io.BytesIO(f.read())
113
+ image = PIL.Image.open(data)
114
+ image = np.asarray(image)
115
+ images.append(image)
116
+ res = np.asarray(images).reshape(nrows, ncols, image_size, image_size,
117
+ 3).transpose(0, 2, 1, 3, 4).reshape(
118
+ nrows * image_size,
119
+ ncols * image_size, 3)
120
+ seeds = np.asarray(seeds).reshape(nrows, ncols)
121
+
122
+ return len(image_indices), res, seeds
123
+
124
+
125
+ def main():
126
+ gr.close_all()
127
+
128
+ args = parse_args()
129
+
130
+ image_size = 128
131
+ min_seed = 0
132
+ max_seed = 99999
133
+ dirname = '0-99999'
134
+ tarball_path = download_image_tarball(image_size, dirname)
135
+
136
+ deepdanbooru_tag_dict = load_deepdanbooru_tag_dict()
137
+ deepdanbooru_predictions = load_deepdanbooru_predictions(dirname)
138
+
139
+ func = functools.partial(
140
+ run,
141
+ image_size=image_size,
142
+ min_seed=min_seed,
143
+ max_seed=max_seed,
144
+ dirname=dirname,
145
+ tarball_path=tarball_path,
146
+ deepdanbooru_tag_dict=deepdanbooru_tag_dict,
147
+ deepdanbooru_predictions=deepdanbooru_predictions,
148
+ )
149
+ func = functools.update_wrapper(func, run)
150
+
151
+ gr.Interface(
152
+ func,
153
+ [
154
+ gr.inputs.CheckboxGroup([
155
+ '1girl',
156
+ '1boy',
157
+ 'multiple_girls',
158
+ 'multiple_boys',
159
+ ],
160
+ label='General'),
161
+ gr.inputs.CheckboxGroup([
162
+ 'aqua',
163
+ 'black',
164
+ 'blonde',
165
+ 'blue',
166
+ 'brown',
167
+ 'green',
168
+ 'grey',
169
+ 'orange',
170
+ 'pink',
171
+ 'purple',
172
+ 'red',
173
+ 'silver',
174
+ 'white',
175
+ ],
176
+ label='Hair Color'),
177
+ gr.inputs.CheckboxGroup([
178
+ 'bangs',
179
+ 'curly_hair',
180
+ 'long_hair',
181
+ 'medium_hair',
182
+ 'messy_hair',
183
+ 'short_hair',
184
+ 'straight_hair',
185
+ 'twintails',
186
+ ],
187
+ label='Hair Style'),
188
+ gr.inputs.CheckboxGroup([
189
+ 'greyscale',
190
+ 'monochrome',
191
+ ],
192
+ label='Image Color'),
193
+ gr.inputs.Slider(0,
194
+ 1,
195
+ step=0.1,
196
+ default=0.5,
197
+ label='DeepDanbooru Score Threshold'),
198
+ gr.inputs.Number(default=0, label='Start Index'),
199
+ gr.inputs.Slider(1, 10, step=1, default=2, label='Number of Rows'),
200
+ gr.inputs.Slider(
201
+ 1, 10, step=1, default=5, label='Number of Columns'),
202
+ ],
203
+ [
204
+ gr.outputs.Textbox(type='number', label='Number of Found Images'),
205
+ gr.outputs.Image(type='numpy', label='Output'),
206
+ gr.outputs.Dataframe(type='numpy', label='Seed'),
207
+ ],
208
+ title=TITLE,
209
+ description=DESCRIPTION,
210
+ article=ARTICLE,
211
+ theme=args.theme,
212
+ allow_screenshot=args.allow_screenshot,
213
+ allow_flagging=args.allow_flagging,
214
+ live=args.live,
215
+ ).launch(
216
+ enable_queue=args.enable_queue,
217
+ server_port=args.port,
218
+ share=args.share,
219
+ )
220
+
221
+
222
+ if __name__ == '__main__':
223
+ main()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1