hysts HF staff commited on
Commit
fb59cb8
1 Parent(s): 95971a5
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. DeepDanbooru +1 -0
  3. app.py +152 -0
  4. requirements.txt +2 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "DeepDanbooru"]
2
+ path = DeepDanbooru
3
+ url = https://github.com/KichangKim/DeepDanbooru
DeepDanbooru ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 92ba0b56be5eed0037e3f067bb9867f5ac691647
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import subprocess
10
+ import sys
11
+ import urllib
12
+ import zipfile
13
+ from typing import Callable
14
+
15
+ # workaround for https://github.com/gradio-app/gradio/issues/483
16
+ command = 'pip install -U gradio==2.7.0'
17
+ subprocess.call(command.split())
18
+
19
+ command = 'pip install -r DeepDanbooru/requirements.txt'
20
+ subprocess.call(command.split())
21
+ sys.path.insert(0, 'DeepDanbooru')
22
+
23
+ import deepdanbooru as dd
24
+ import gradio as gr
25
+ import huggingface_hub
26
+ import numpy as np
27
+ import PIL.Image
28
+ import tensorflow as tf
29
+
30
+ TOKEN = os.environ['TOKEN']
31
+
32
+ ZIP_PATH = 'data.zip'
33
+ TAG_PATH = 'tags.txt'
34
+ MODEL_PATH = 'model-resnet_custom_v3.h5'
35
+
36
+
37
+ def parse_args() -> argparse.Namespace:
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument('--score-slider-step', type=float, default=0.05)
40
+ parser.add_argument('--score-threshold', type=float, default=0.5)
41
+ parser.add_argument('--theme', type=str, default='dark-grass')
42
+ parser.add_argument('--live', action='store_true')
43
+ parser.add_argument('--share', action='store_true')
44
+ parser.add_argument('--port', type=int)
45
+ parser.add_argument('--disable-queue',
46
+ dest='enable_queue',
47
+ action='store_false')
48
+ parser.add_argument('--allow-flagging', type=str, default='never')
49
+ parser.add_argument('--allow-screenshot', action='store_true')
50
+ return parser.parse_args()
51
+
52
+
53
+ def download_sample_images() -> list[pathlib.Path]:
54
+ image_dir = pathlib.Path('samples')
55
+ image_dir.mkdir(exist_ok=True)
56
+
57
+ dataset_repo = 'hysts/sample-images-TADNE'
58
+ n_images = 36
59
+ paths = []
60
+ for index in range(n_images):
61
+ path = huggingface_hub.hf_hub_download(dataset_repo,
62
+ f'{index:02d}.jpg',
63
+ repo_type='dataset',
64
+ cache_dir=image_dir.as_posix(),
65
+ use_auth_token=TOKEN)
66
+ paths.append(pathlib.Path(path))
67
+ return paths
68
+
69
+
70
+ def download_model_data() -> None:
71
+ url = 'https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20200915-sgd-e30/deepdanbooru-v3-20200915-sgd-e30.zip'
72
+ urllib.request.urlretrieve(url, ZIP_PATH)
73
+ with zipfile.ZipFile(ZIP_PATH) as f:
74
+ f.extract(TAG_PATH)
75
+ f.extract(MODEL_PATH)
76
+
77
+
78
+ def predict(image: PIL.Image.Image, score_threshold: float, model,
79
+ labels: list[str]) -> dict[str, float]:
80
+ _, height, width, _ = model.input_shape
81
+ image = np.asarray(image)
82
+ image = tf.image.resize(image,
83
+ size=(height, width),
84
+ method=tf.image.ResizeMethod.AREA,
85
+ preserve_aspect_ratio=True)
86
+ image = image.numpy()
87
+ image = dd.image.transform_and_pad_image(image, width, height)
88
+ image = image / 255.
89
+ probs = model.predict(image[None, ...])[0]
90
+ probs = probs.astype(float)
91
+ res = dict()
92
+ for prob, label in zip(probs, labels):
93
+ if prob < score_threshold:
94
+ continue
95
+ res[label] = prob
96
+ return res
97
+
98
+
99
+ def main():
100
+ gr.close_all()
101
+
102
+ args = parse_args()
103
+
104
+ image_paths = download_sample_images()
105
+ examples = [[path.as_posix(), args.score_threshold]
106
+ for path in image_paths]
107
+
108
+ zip_path = pathlib.Path(ZIP_PATH)
109
+ if not zip_path.exists():
110
+ download_model_data()
111
+
112
+ model = tf.keras.models.load_model(MODEL_PATH)
113
+
114
+ with open(TAG_PATH) as f:
115
+ labels = [line.strip() for line in f.readlines()]
116
+
117
+ func = functools.partial(predict, model=model, labels=labels)
118
+ func = functools.update_wrapper(func, predict)
119
+
120
+ repo_url = 'https://github.com/KichangKim/DeepDanbooru'
121
+ title = 'KichangKim/DeepDanbooru'
122
+ description = f'A demo for {repo_url}'
123
+ article = None
124
+
125
+ gr.Interface(
126
+ func,
127
+ [
128
+ gr.inputs.Image(type='pil', label='Input'),
129
+ gr.inputs.Slider(0,
130
+ 1,
131
+ step=args.score_slider_step,
132
+ default=args.score_threshold,
133
+ label='Score Threshold'),
134
+ ],
135
+ gr.outputs.Label(label='Output'),
136
+ theme=args.theme,
137
+ title=title,
138
+ description=description,
139
+ article=article,
140
+ examples=examples,
141
+ allow_screenshot=args.allow_screenshot,
142
+ allow_flagging=args.allow_flagging,
143
+ live=args.live,
144
+ ).launch(
145
+ enable_queue=args.enable_queue,
146
+ server_port=args.port,
147
+ share=args.share,
148
+ )
149
+
150
+
151
+ if __name__ == '__main__':
152
+ main()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pillow>=9.0.0
2
+ tensorflow>=2.7.0