hysts HF staff committed on
Commit
11dad4e
β€’
1 Parent(s): e3f5d75
Files changed (4) hide show
  1. .gitignore +1 -0
  2. README.md +1 -1
  3. app.py +197 -0
  4. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ images
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸŒ–
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 2.9.2
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 2.9.3
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import io
8
+ import os
9
+ import pathlib
10
+ import tarfile
11
+
12
+ import deepdanbooru as dd
13
+ import gradio as gr
14
+ import huggingface_hub
15
+ import numpy as np
16
+ import PIL.Image
17
+ import tensorflow as tf
18
+ from huggingface_hub import hf_hub_download
19
+
20
# UI strings shown by the Gradio interface.
TITLE = 'TADNE Image Search with DeepDanbooru'
DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/.

This app shows images similar to the query image from images generated
by the TADNE model with seed 0-99999.
Here, image similarity is measured by the L2 distance of the intermediate
features by the [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)
model.

Known issues:
- The `Seed` table in the output doesn't refresh properly in gradio 2.9.1.
  https://github.com/gradio-app/gradio/issues/921
'''
# No footer article is rendered below the interface.
ARTICLE = None

# Hugging Face auth token for the private dataset/model repos.
# Raises KeyError at import time if the TOKEN env var is unset.
TOKEN = os.environ['TOKEN']
36
+
37
+
38
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line flags that control the app launch."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--theme', type=str, default='dark-grass')
    cli.add_argument('--live', action='store_true')
    cli.add_argument('--share', action='store_true')
    cli.add_argument('--port', type=int)
    # --disable-queue stores False into `enable_queue` (queueing on by default).
    cli.add_argument('--disable-queue', dest='enable_queue', action='store_false')
    cli.add_argument('--allow-flagging', type=str, default='never')
    cli.add_argument('--allow-screenshot', action='store_true')
    return cli.parse_args()
50
+
51
+
52
def download_image_tarball(size: int, dirname: str) -> pathlib.Path:
    """Download the tar archive of sample images from the HF dataset repo.

    Args:
        size: Edge length in pixels of the images in the archive (e.g. 128).
        dirname: Seed-range directory name inside the repo, e.g. '0-99999'.

    Returns:
        Local cache path of the downloaded tar archive.
    """
    path = hf_hub_download('hysts/TADNE-sample-images',
                           f'{size}/{dirname}.tar',
                           repo_type='dataset',
                           use_auth_token=TOKEN)
    # hf_hub_download returns a str; convert so the annotated return type
    # (pathlib.Path) is actually honored. Downstream tarfile.open accepts both.
    return pathlib.Path(path)
58
+
59
+
60
def load_deepdanbooru_predictions(dirname: str) -> np.ndarray:
    """Fetch the precomputed DeepDanbooru intermediate features for `dirname`."""
    npy_path = hf_hub_download(
        'hysts/TADNE-sample-images',
        f'prediction_results/deepdanbooru/intermediate_features/{dirname}.npy',
        repo_type='dataset',
        use_auth_token=TOKEN)
    # One feature vector per generated image, loaded as a single array.
    return np.load(npy_path)
67
+
68
+
69
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return sorted paths of the bundled example images.

    Downloads and extracts the `images.tar.gz` archive into ./images on
    first use; subsequent calls reuse the extracted directory.

    Returns:
        Sorted list of paths under the local 'images' directory.
    """
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset',
                                               use_auth_token=TOKEN)
        with tarfile.open(path) as f:
            # Guard against path-traversal members (CVE-2007-4559): refuse
            # any entry that would resolve outside the working directory.
            base = os.path.realpath('.')
            for member in f.getmembers():
                target = os.path.realpath(os.path.join(base, member.name))
                if os.path.commonpath([base, target]) != base:
                    raise RuntimeError(
                        f'Unsafe path in archive: {member.name}')
            f.extractall()
    return sorted(image_dir.glob('*'))
80
+
81
+
82
def create_model() -> tf.keras.Model:
    """Load DeepDanbooru and expose pooled intermediate features.

    The classifier head is dropped (the output of the 4th-from-last layer
    is kept) and global average pooling turns the remaining feature map
    into a flat feature vector.
    """
    weights_path = huggingface_hub.hf_hub_download('hysts/DeepDanbooru',
                                                   'model-resnet_custom_v3.h5',
                                                   use_auth_token=TOKEN)
    base = tf.keras.models.load_model(weights_path)
    truncated = tf.keras.Model(base.input, base.layers[-4].output)
    pooling = tf.keras.layers.GlobalAveragePooling2D()
    return tf.keras.Sequential([truncated, pooling])
91
+
92
+
93
def predict(image: PIL.Image.Image, model: tf.keras.Model) -> np.ndarray:
    """Compute the pooled DeepDanbooru feature vector for one image.

    The image is resized (aspect-ratio preserving), padded to the model's
    input resolution, scaled to [0, 1], and pushed through the model.
    """
    _, height, width, _ = model.input_shape
    arr = np.asarray(image)
    resized = tf.image.resize(arr,
                              size=(height, width),
                              method=tf.image.ResizeMethod.AREA,
                              preserve_aspect_ratio=True).numpy()
    padded = dd.image.transform_and_pad_image(resized, width, height)
    scaled = padded / 255.
    # Add a batch dimension for the model call, then drop it again.
    features = model.predict(scaled[None, ...])[0]
    return features.astype(float)
106
+
107
+
108
def run(
    image: PIL.Image.Image,
    nrows: int,
    ncols: int,
    image_size: int,
    dirname: str,
    tarball_path: pathlib.Path,
    deepdanbooru_predictions: np.ndarray,
    model: tf.keras.Model,
) -> tuple[np.ndarray, np.ndarray]:
    """Find the nrows*ncols generated images most similar to `image`.

    Similarity is the squared L2 distance between DeepDanbooru
    intermediate feature vectors.

    Returns:
        A (nrows*image_size, ncols*image_size, 3) montage of the nearest
        images and the matching (nrows, ncols) array of their seeds.
    """
    features = predict(image, model)
    # Squared L2 distance to every precomputed feature vector.
    distances = ((deepdanbooru_predictions - features)**2).sum(axis=1)

    # Seeds sorted by ascending distance; seed == index into the archive.
    image_indices = np.argsort(distances)

    seeds = []
    images = []
    # tarfile.open (rather than the TarFile constructor) is the documented
    # entry point and auto-detects the archive format.
    with tarfile.open(tarball_path) as tar_file:
        for index in range(nrows * ncols):
            image_index = image_indices[index]
            seeds.append(image_index)
            member = tar_file.getmember(f'{dirname}/{image_index:07d}.jpg')
            with tar_file.extractfile(member) as f:
                data = io.BytesIO(f.read())
            image = PIL.Image.open(data)
            image = np.asarray(image)
            images.append(image)
    # Tile the flat list of images into an nrows x ncols montage.
    res = np.asarray(images).reshape(nrows, ncols, image_size, image_size,
                                     3).transpose(0, 2, 1, 3, 4).reshape(
                                         nrows * image_size,
                                         ncols * image_size, 3)
    seeds = np.asarray(seeds).reshape(nrows, ncols)

    return res, seeds
142
+
143
+
144
def main():
    """Assemble the model, data, and UI, then launch the Gradio app."""
    gr.close_all()

    args = parse_args()

    image_size = 128
    dirname = '0-99999'
    tarball_path = download_image_tarball(image_size, dirname)
    deepdanbooru_predictions = load_deepdanbooru_predictions(dirname)

    model = create_model()

    examples = [[p.as_posix(), 2, 5] for p in load_sample_image_paths()]

    # Bind everything except the per-request inputs (image, nrows, ncols),
    # and copy run's metadata so Gradio displays sensible names.
    fn = functools.update_wrapper(
        functools.partial(
            run,
            image_size=image_size,
            dirname=dirname,
            tarball_path=tarball_path,
            deepdanbooru_predictions=deepdanbooru_predictions,
            model=model,
        ),
        run,
    )

    inputs = [
        gr.inputs.Image(type='pil', label='Input'),
        gr.inputs.Slider(1, 10, step=1, default=2, label='Number of Rows'),
        gr.inputs.Slider(1, 10, step=1, default=5, label='Number of Columns'),
    ]
    outputs = [
        gr.outputs.Image(type='numpy', label='Output'),
        gr.outputs.Dataframe(type='numpy', label='Seed'),
    ]
    gr.Interface(
        fn,
        inputs,
        outputs,
        examples=examples,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pillow==9.1.0
2
+ tensorflow==2.8.0
3
+ git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru