MonkeyJuice commited on
Commit
7c078a3
1 Parent(s): 8cdd8e4

split code

Browse files
Files changed (2) hide show
  1. app.py +4 -46
  2. genTag.py +55 -0
app.py CHANGED
@@ -2,60 +2,18 @@
2
 
3
  from __future__ import annotations
4
 
5
- import deepdanbooru as dd
6
  import gradio as gr
7
- import huggingface_hub
8
- import numpy as np
9
  import PIL.Image
10
- import tensorflow as tf
11
-
12
- def load_model() -> tf.keras.Model:
13
- path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
14
- 'model-resnet_custom_v3.h5')
15
- model = tf.keras.models.load_model(path)
16
- return model
17
-
18
-
19
- def load_labels() -> list[str]:
20
- path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
21
- 'tags.txt')
22
- with open(path) as f:
23
- labels = [line.strip() for line in f.readlines()]
24
- return labels
25
-
26
-
27
- model = load_model()
28
- labels = load_labels()
29
-
30
 
31
  def predict(image: PIL.Image.Image, score_threshold: float):
32
- _, height, width, _ = model.input_shape
33
- image = np.asarray(image)
34
- image = tf.image.resize(image,
35
- size=(height, width),
36
- method=tf.image.ResizeMethod.AREA,
37
- preserve_aspect_ratio=True)
38
- image = image.numpy()
39
- image = dd.image.transform_and_pad_image(image, width, height)
40
- image = image / 255.
41
- probs = model.predict(image[None, ...])[0]
42
- probs = probs.astype(float)
43
-
44
- indices = np.argsort(probs)[::-1]
45
- result_all = dict()
46
- result_threshold = dict()
47
  result_html = ''
48
- for index in indices:
49
- label = labels[index]
50
- prob = probs[index]
51
- result_all[label] = prob
52
- if prob < score_threshold:
53
- break
54
  result_html = result_html + '<p class="m5dd_list use"><span>' + str(label) + '</span><span>' + str(round(prob, 3)) + '</span></p>'
55
- result_threshold[label] = prob
56
  result_text = ', '.join(result_threshold.keys())
57
  result_text = '<div id="m5dd_result">' + str(result_text) + '</div>'
58
- result_html = '<div>' + str(result_html) + '</div>'
59
  return result_html, result_text
60
 
61
  js = """
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import gradio as gr
 
 
6
  import PIL.Image
7
+ from genTag import genTag
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def predict(image: PIL.Image.Image, score_threshold: float):
10
+ result_threshold = genTag(image, score_threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  result_html = ''
12
+ for label, prob in result_threshold.items():
 
 
 
 
 
13
  result_html = result_html + '<p class="m5dd_list use"><span>' + str(label) + '</span><span>' + str(round(prob, 3)) + '</span></p>'
14
+ result_html = '<div>' + str(result_html) + '</div>'
15
  result_text = ', '.join(result_threshold.keys())
16
  result_text = '<div id="m5dd_result">' + str(result_text) + '</div>'
 
17
  return result_html, result_text
18
 
19
  js = """
genTag.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import deepdanbooru as dd
6
+ import huggingface_hub
7
+ import numpy as np
8
+ import PIL.Image
9
+ import tensorflow as tf
10
+
11
+ def load_model() -> tf.keras.Model:
12
+ path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
13
+ 'model-resnet_custom_v3.h5')
14
+ model = tf.keras.models.load_model(path)
15
+ return model
16
+
17
+
18
+ def load_labels() -> list[str]:
19
+ path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
20
+ 'tags.txt')
21
+ with open(path) as f:
22
+ labels = [line.strip() for line in f.readlines()]
23
+ return labels
24
+
25
+
26
+ model = load_model()
27
+ labels = load_labels()
28
+
29
+
30
+ def genTag(image: PIL.Image.Image, score_threshold: float):
31
+ _, height, width, _ = model.input_shape
32
+ image = np.asarray(image)
33
+ image = tf.image.resize(image,
34
+ size=(height, width),
35
+ method=tf.image.ResizeMethod.AREA,
36
+ preserve_aspect_ratio=True)
37
+ image = image.numpy()
38
+ image = dd.image.transform_and_pad_image(image, width, height)
39
+ image = image / 255.
40
+ probs = model.predict(image[None, ...])[0]
41
+ probs = probs.astype(float)
42
+
43
+ indices = np.argsort(probs)[::-1]
44
+ result_all = dict()
45
+ result_threshold = dict()
46
+ result_html = ''
47
+ for index in indices:
48
+ label = labels[index]
49
+ prob = probs[index]
50
+ result_all[label] = prob
51
+ if prob < score_threshold:
52
+ break
53
+ result_threshold[label] = prob
54
+
55
+ return result_threshold