hysts HF staff commited on
Commit
e9c5f95
1 Parent(s): fb59cb8
Files changed (4) hide show
  1. .gitmodules +0 -3
  2. DeepDanbooru +0 -1
  3. app.py +32 -40
  4. requirements.txt +1 -0
.gitmodules DELETED
@@ -1,3 +0,0 @@
1
- [submodule "DeepDanbooru"]
2
- path = DeepDanbooru
3
- url = https://github.com/KichangKim/DeepDanbooru
 
 
 
 
DeepDanbooru DELETED
@@ -1 +0,0 @@
1
- Subproject commit 92ba0b56be5eed0037e3f067bb9867f5ac691647
 
 
app.py CHANGED
@@ -7,19 +7,12 @@ import functools
7
  import os
8
  import pathlib
9
  import subprocess
10
- import sys
11
- import urllib
12
- import zipfile
13
- from typing import Callable
14
 
15
  # workaround for https://github.com/gradio-app/gradio/issues/483
16
  command = 'pip install -U gradio==2.7.0'
17
  subprocess.call(command.split())
18
 
19
- command = 'pip install -r DeepDanbooru/requirements.txt'
20
- subprocess.call(command.split())
21
- sys.path.insert(0, 'DeepDanbooru')
22
-
23
  import deepdanbooru as dd
24
  import gradio as gr
25
  import huggingface_hub
@@ -29,9 +22,9 @@ import tensorflow as tf
29
 
30
  TOKEN = os.environ['TOKEN']
31
 
32
- ZIP_PATH = 'data.zip'
33
- TAG_PATH = 'tags.txt'
34
- MODEL_PATH = 'model-resnet_custom_v3.h5'
35
 
36
 
37
  def parse_args() -> argparse.Namespace:
@@ -50,33 +43,38 @@ def parse_args() -> argparse.Namespace:
50
  return parser.parse_args()
51
 
52
 
53
- def download_sample_images() -> list[pathlib.Path]:
54
- image_dir = pathlib.Path('samples')
55
- image_dir.mkdir(exist_ok=True)
56
-
57
- dataset_repo = 'hysts/sample-images-TADNE'
58
- n_images = 36
59
- paths = []
60
- for index in range(n_images):
61
  path = huggingface_hub.hf_hub_download(dataset_repo,
62
- f'{index:02d}.jpg',
63
  repo_type='dataset',
64
- cache_dir=image_dir.as_posix(),
65
  use_auth_token=TOKEN)
66
- paths.append(pathlib.Path(path))
67
- return paths
 
68
 
69
 
70
- def download_model_data() -> None:
71
- url = 'https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20200915-sgd-e30/deepdanbooru-v3-20200915-sgd-e30.zip'
72
- urllib.request.urlretrieve(url, ZIP_PATH)
73
- with zipfile.ZipFile(ZIP_PATH) as f:
74
- f.extract(TAG_PATH)
75
- f.extract(MODEL_PATH)
76
 
77
 
78
- def predict(image: PIL.Image.Image, score_threshold: float, model,
79
- labels: list[str]) -> dict[str, float]:
 
 
 
 
 
 
 
 
 
80
  _, height, width, _ = model.input_shape
81
  image = np.asarray(image)
82
  image = tf.image.resize(image,
@@ -101,18 +99,12 @@ def main():
101
 
102
  args = parse_args()
103
 
104
- image_paths = download_sample_images()
105
  examples = [[path.as_posix(), args.score_threshold]
106
  for path in image_paths]
107
 
108
- zip_path = pathlib.Path(ZIP_PATH)
109
- if not zip_path.exists():
110
- download_model_data()
111
-
112
- model = tf.keras.models.load_model(MODEL_PATH)
113
-
114
- with open(TAG_PATH) as f:
115
- labels = [line.strip() for line in f.readlines()]
116
 
117
  func = functools.partial(predict, model=model, labels=labels)
118
  func = functools.update_wrapper(func, predict)
 
7
  import os
8
  import pathlib
9
  import subprocess
10
+ import tarfile
 
 
 
11
 
12
  # workaround for https://github.com/gradio-app/gradio/issues/483
13
  command = 'pip install -U gradio==2.7.0'
14
  subprocess.call(command.split())
15
 
 
 
 
 
16
  import deepdanbooru as dd
17
  import gradio as gr
18
  import huggingface_hub
 
22
 
23
  TOKEN = os.environ['TOKEN']
24
 
25
+ MODEL_REPO = 'hysts/DeepDanbooru'
26
+ MODEL_FILENAME = 'model-resnet_custom_v3.h5'
27
+ LABEL_FILENAME = 'tags.txt'
28
 
29
 
30
  def parse_args() -> argparse.Namespace:
 
43
  return parser.parse_args()
44
 
45
 
46
+ def load_sample_image_paths() -> list[pathlib.Path]:
47
+ image_dir = pathlib.Path('images')
48
+ if not image_dir.exists():
49
+ dataset_repo = 'hysts/sample-images-TADNE'
 
 
 
 
50
  path = huggingface_hub.hf_hub_download(dataset_repo,
51
+ 'images.tar.gz',
52
  repo_type='dataset',
 
53
  use_auth_token=TOKEN)
54
+ with tarfile.open(path) as f:
55
+ f.extractall()
56
+ return sorted(image_dir.glob('*'))
57
 
58
 
59
+ def load_model() -> tf.keras.Model:
60
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
61
+ MODEL_FILENAME,
62
+ use_auth_token=TOKEN)
63
+ model = tf.keras.models.load_model(path)
64
+ return model
65
 
66
 
67
+ def load_labels() -> list[str]:
68
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
69
+ LABEL_FILENAME,
70
+ use_auth_token=TOKEN)
71
+ with open(path) as f:
72
+ labels = [line.strip() for line in f.readlines()]
73
+ return labels
74
+
75
+
76
+ def predict(image: PIL.Image.Image, score_threshold: float,
77
+ model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
78
  _, height, width, _ = model.input_shape
79
  image = np.asarray(image)
80
  image = tf.image.resize(image,
 
99
 
100
  args = parse_args()
101
 
102
+ image_paths = load_sample_image_paths()
103
  examples = [[path.as_posix(), args.score_threshold]
104
  for path in image_paths]
105
 
106
+ model = load_model()
107
+ labels = load_labels()
 
 
 
 
 
 
108
 
109
  func = functools.partial(predict, model=model, labels=labels)
110
  func = functools.update_wrapper(func, predict)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  pillow>=9.0.0
2
  tensorflow>=2.7.0
 
 
1
  pillow>=9.0.0
2
  tensorflow>=2.7.0
3
+ git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru