hysts HF staff commited on
Commit
01b28b7
1 Parent(s): 5d94de1
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +42 -67
  3. style.css +3 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏃
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.6
8
  app_file: app.py
9
  pinned: false
10
  ---
@@ -27,7 +27,7 @@ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gr
27
  Can be either `gradio`, `streamlit`, or `static`
28
 
29
  `sdk_version` : _string_
30
- Only applicable for `streamlit` SDK.
31
  See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
 
33
  `app_file`: _string_
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.32
8
  app_file: app.py
9
  pinned: false
10
  ---
 
27
  Can be either `gradio`, `streamlit`, or `static`
28
 
29
  `sdk_version` : _string_
30
+ Only applicable for `streamlit` SDK.
31
  See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
 
33
  `app_file`: _string_
app.py CHANGED
@@ -2,8 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
- import functools
7
  import os
8
  import pathlib
9
  import tarfile
@@ -15,56 +13,42 @@ import numpy as np
15
  import PIL.Image
16
  import tensorflow as tf
17
 
18
- TITLE = 'KichangKim/DeepDanbooru'
19
- DESCRIPTION = 'This is an unofficial demo for https://github.com/KichangKim/DeepDanbooru.'
20
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.deepdanbooru" alt="visitor badge"/></center>'
21
-
22
- HF_TOKEN = os.environ['HF_TOKEN']
23
- MODEL_REPO = 'hysts/DeepDanbooru'
24
- MODEL_FILENAME = 'model-resnet_custom_v3.h5'
25
- LABEL_FILENAME = 'tags.txt'
26
-
27
-
28
- def parse_args() -> argparse.Namespace:
29
- parser = argparse.ArgumentParser()
30
- parser.add_argument('--score-slider-step', type=float, default=0.05)
31
- parser.add_argument('--score-threshold', type=float, default=0.5)
32
- parser.add_argument('--share', action='store_true')
33
- return parser.parse_args()
34
 
35
 
36
  def load_sample_image_paths() -> list[pathlib.Path]:
37
  image_dir = pathlib.Path('images')
38
  if not image_dir.exists():
39
- dataset_repo = 'hysts/sample-images-TADNE'
40
- path = huggingface_hub.hf_hub_download(dataset_repo,
41
- 'images.tar.gz',
42
- repo_type='dataset',
43
- use_auth_token=HF_TOKEN)
44
  with tarfile.open(path) as f:
45
  f.extractall()
46
  return sorted(image_dir.glob('*'))
47
 
48
 
49
  def load_model() -> tf.keras.Model:
50
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
51
- MODEL_FILENAME,
52
- use_auth_token=HF_TOKEN)
53
  model = tf.keras.models.load_model(path)
54
  return model
55
 
56
 
57
  def load_labels() -> list[str]:
58
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
59
- LABEL_FILENAME,
60
- use_auth_token=HF_TOKEN)
61
  with open(path) as f:
62
  labels = [line.strip() for line in f.readlines()]
63
  return labels
64
 
65
 
66
- def predict(image: PIL.Image.Image, score_threshold: float,
67
- model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
 
 
 
 
68
  _, height, width, _ = model.input_shape
69
  image = np.asarray(image)
70
  image = tf.image.resize(image,
@@ -84,39 +68,30 @@ def predict(image: PIL.Image.Image, score_threshold: float,
84
  return res
85
 
86
 
87
- def main():
88
- args = parse_args()
89
-
90
- image_paths = load_sample_image_paths()
91
- examples = [[path.as_posix(), args.score_threshold]
92
- for path in image_paths]
93
-
94
- model = load_model()
95
- labels = load_labels()
96
-
97
- func = functools.partial(predict, model=model, labels=labels)
98
-
99
- gr.Interface(
100
- func,
101
- [
102
- gr.Image(type='pil', label='Input'),
103
- gr.Slider(0,
104
- 1,
105
- step=args.score_slider_step,
106
- value=args.score_threshold,
107
- label='Score Threshold'),
108
- ],
109
- gr.Label(label='Output'),
110
- examples=examples,
111
- title=TITLE,
112
- description=DESCRIPTION,
113
- article=ARTICLE,
114
- allow_flagging='never',
115
- ).launch(
116
- enable_queue=True,
117
- share=args.share,
118
- )
119
-
120
-
121
- if __name__ == '__main__':
122
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
 
5
  import os
6
  import pathlib
7
  import tarfile
 
13
  import PIL.Image
14
  import tensorflow as tf
15
 
16
+ DESCRIPTION = '# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def load_sample_image_paths() -> list[pathlib.Path]:
20
  image_dir = pathlib.Path('images')
21
  if not image_dir.exists():
22
+ path = huggingface_hub.hf_hub_download(
23
+ 'public-data/sample-images-TADNE',
24
+ 'images.tar.gz',
25
+ repo_type='dataset')
 
26
  with tarfile.open(path) as f:
27
  f.extractall()
28
  return sorted(image_dir.glob('*'))
29
 
30
 
31
  def load_model() -> tf.keras.Model:
32
+ path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
33
+ 'model-resnet_custom_v3.h5')
 
34
  model = tf.keras.models.load_model(path)
35
  return model
36
 
37
 
38
  def load_labels() -> list[str]:
39
+ path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
40
+ 'tags.txt')
 
41
  with open(path) as f:
42
  labels = [line.strip() for line in f.readlines()]
43
  return labels
44
 
45
 
46
+ model = load_model()
47
+ labels = load_labels()
48
+
49
+
50
+ def predict(image: PIL.Image.Image,
51
+ score_threshold: float) -> dict[str, float]:
52
  _, height, width, _ = model.input_shape
53
  image = np.asarray(image)
54
  image = tf.image.resize(image,
 
68
  return res
69
 
70
 
71
+ image_paths = load_sample_image_paths()
72
+ examples = [[path.as_posix(), 0.5] for path in image_paths]
73
+
74
+ with gr.Blocks(css='style.css') as demo:
75
+ gr.Markdown(DESCRIPTION)
76
+ with gr.Row():
77
+ with gr.Column():
78
+ image = gr.Image(label='Input', type='pil')
79
+ score_threshold = gr.Slider(label='Score threshold',
80
+ minimum=0,
81
+ maximum=1,
82
+ step=0.05,
83
+ value=0.5)
84
+ run_button = gr.Button('Run')
85
+ with gr.Column():
86
+ result = gr.Label(label='Output')
87
+ gr.Examples(examples=examples,
88
+ inputs=[image, score_threshold],
89
+ outputs=result,
90
+ fn=predict,
91
+ cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
92
+
93
+ run_button.click(fn=predict,
94
+ inputs=[image, score_threshold],
95
+ outputs=result,
96
+ api_name='predict')
97
+ demo.queue().launch()
 
 
 
 
 
 
 
 
 
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }