hysts HF staff commited on
Commit
46d0a90
1 Parent(s): ef20843

Return a list of predicted labels

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -7,6 +7,7 @@ import functools
7
  import os
8
  import pathlib
9
  import tarfile
 
10
 
11
  import deepdanbooru as dd
12
  import gradio as gr
@@ -64,7 +65,8 @@ def load_labels() -> list[str]:
64
 
65
 
66
  def predict(image: PIL.Image.Image, score_threshold: float,
67
- model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
 
68
  _, height, width, _ = model.input_shape
69
  image = np.asarray(image)
70
  image = tf.image.resize(image,
@@ -81,7 +83,14 @@ def predict(image: PIL.Image.Image, score_threshold: float,
81
  if prob < score_threshold:
82
  continue
83
  res[label] = prob
84
- return res
 
 
 
 
 
 
 
85
 
86
 
87
  def main():
@@ -106,7 +115,10 @@ def main():
106
  value=args.score_threshold,
107
  label='Score Threshold'),
108
  ],
109
- gr.Label(label='Output'),
 
 
 
110
  examples=examples,
111
  title=TITLE,
112
  description=DESCRIPTION,
 
7
  import os
8
  import pathlib
9
  import tarfile
10
+ import tempfile
11
 
12
  import deepdanbooru as dd
13
  import gradio as gr
 
65
 
66
 
67
  def predict(image: PIL.Image.Image, score_threshold: float,
68
+ model: tf.keras.Model,
69
+ labels: list[str]) -> tuple[dict[str, float], str]:
70
  _, height, width, _ = model.input_shape
71
  image = np.asarray(image)
72
  image = tf.image.resize(image,
 
83
  if prob < score_threshold:
84
  continue
85
  res[label] = prob
86
+
87
+ sorted_preds = sorted(res.items(), key=lambda x: -x[1])
88
+ out_path = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
89
+ with open(out_path.name, 'w') as f:
90
+ for key, _ in sorted_preds:
91
+ f.write(f'{key}\n')
92
+
93
+ return res, out_path.name
94
 
95
 
96
  def main():
 
115
  value=args.score_threshold,
116
  label='Score Threshold'),
117
  ],
118
+ [
119
+ gr.Label(label='Output'),
120
+ gr.File(label='Tag List'),
121
+ ],
122
  examples=examples,
123
  title=TITLE,
124
  description=DESCRIPTION,