Spaces:
Running
Running
Revert "Return a list of predicted labels"
Browse filesThis reverts commit 72d940356accfca18261d089a669501e0c0ff778.
app.py
CHANGED
@@ -7,7 +7,6 @@ import functools
|
|
7 |
import os
|
8 |
import pathlib
|
9 |
import tarfile
|
10 |
-
import tempfile
|
11 |
|
12 |
import deepdanbooru as dd
|
13 |
import gradio as gr
|
@@ -65,8 +64,7 @@ def load_labels() -> list[str]:
|
|
65 |
|
66 |
|
67 |
def predict(image: PIL.Image.Image, score_threshold: float,
|
68 |
-
model: tf.keras.Model,
|
69 |
-
labels: list[str]) -> tuple[dict[str, float], str]:
|
70 |
_, height, width, _ = model.input_shape
|
71 |
image = np.asarray(image)
|
72 |
image = tf.image.resize(image,
|
@@ -83,14 +81,7 @@ def predict(image: PIL.Image.Image, score_threshold: float,
|
|
83 |
if prob < score_threshold:
|
84 |
continue
|
85 |
res[label] = prob
|
86 |
-
|
87 |
-
sorted_preds = sorted(res.items(), key=lambda x: -x[1])
|
88 |
-
out_path = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
|
89 |
-
with open(out_path.name, 'w') as f:
|
90 |
-
for key, _ in sorted_preds:
|
91 |
-
f.write(f'{key}\n')
|
92 |
-
|
93 |
-
return res, out_path.name
|
94 |
|
95 |
|
96 |
def main():
|
@@ -115,10 +106,7 @@ def main():
|
|
115 |
value=args.score_threshold,
|
116 |
label='Score Threshold'),
|
117 |
],
|
118 |
-
|
119 |
-
gr.Label(label='Output'),
|
120 |
-
gr.File(label='Tag List'),
|
121 |
-
],
|
122 |
examples=examples,
|
123 |
title=TITLE,
|
124 |
description=DESCRIPTION,
|
|
|
7 |
import os
|
8 |
import pathlib
|
9 |
import tarfile
|
|
|
10 |
|
11 |
import deepdanbooru as dd
|
12 |
import gradio as gr
|
|
|
64 |
|
65 |
|
66 |
def predict(image: PIL.Image.Image, score_threshold: float,
|
67 |
+
model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
|
|
|
68 |
_, height, width, _ = model.input_shape
|
69 |
image = np.asarray(image)
|
70 |
image = tf.image.resize(image,
|
|
|
81 |
if prob < score_threshold:
|
82 |
continue
|
83 |
res[label] = prob
|
84 |
+
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
|
87 |
def main():
|
|
|
106 |
value=args.score_threshold,
|
107 |
label='Score Threshold'),
|
108 |
],
|
109 |
+
gr.Label(label='Output'),
|
|
|
|
|
|
|
110 |
examples=examples,
|
111 |
title=TITLE,
|
112 |
description=DESCRIPTION,
|