Spaces:
Sleeping
Sleeping
Commit
•
0c08748
1
Parent(s):
eb9bc66
Set threshold 0.4, multilabel textbox hidden
Browse files
app.py
CHANGED
@@ -144,6 +144,9 @@ def get_img_array(img_path):
|
|
144 |
|
145 |
|
146 |
def get_prediction(img_path):
|
|
|
|
|
|
|
147 |
# check the image path
|
148 |
print(f"Image path: {img_path}")
|
149 |
# also display the original filename for info
|
@@ -156,7 +159,7 @@ def get_prediction(img_path):
|
|
156 |
# binary label
|
157 |
pred_binary = keras_binary_model(img_array, training=False)
|
158 |
print(f"Keras binary label: {pred_binary}")
|
159 |
-
|
160 |
fake = "Fake"
|
161 |
else:
|
162 |
fake = "Real"
|
@@ -165,7 +168,7 @@ def get_prediction(img_path):
|
|
165 |
pred_multi = keras_multi_model(img_array, training=False)
|
166 |
print(f"Keras multi label: {pred_multi}")
|
167 |
# Cut at the sigmoid 0.5 threshold
|
168 |
-
fake_parts = np.where(pred_multi >
|
169 |
print(f"Multi label: {fake_parts}")
|
170 |
# Format each of the fake face parts
|
171 |
parts_message = dict()
|
@@ -255,7 +258,8 @@ with gr.Blocks() as demo:
|
|
255 |
interactive=False, lines=2)
|
256 |
text_3 = gr.Text(
|
257 |
label="Multi label, Efficient net v2 B0",
|
258 |
-
interactive=False, lines=7
|
|
|
259 |
"""
|
260 |
text_3 = gr.Text(label="Sashi's model",
|
261 |
interactive=False, lines=3)
|
|
|
144 |
|
145 |
|
146 |
def get_prediction(img_path):
|
147 |
+
# adjust threshold for accuracy
|
148 |
+
threshold = 0.4
|
149 |
+
|
150 |
# check the image path
|
151 |
print(f"Image path: {img_path}")
|
152 |
# also display the original filename for info
|
|
|
159 |
# binary label
|
160 |
pred_binary = keras_binary_model(img_array, training=False)
|
161 |
print(f"Keras binary label: {pred_binary}")
|
162 |
+
if pred_binary[0][0] > threshold:
|
163 |
fake = "Fake"
|
164 |
else:
|
165 |
fake = "Real"
|
|
|
168 |
pred_multi = keras_multi_model(img_array, training=False)
|
169 |
print(f"Keras multi label: {pred_multi}")
|
170 |
# Cut at the sigmoid 0.5 threshold
|
171 |
+
fake_parts = np.where(pred_multi > threshold, 1, 0)
|
172 |
print(f"Multi label: {fake_parts}")
|
173 |
# Format each of the fake face parts
|
174 |
parts_message = dict()
|
|
|
258 |
interactive=False, lines=2)
|
259 |
text_3 = gr.Text(
|
260 |
label="Multi label, Efficient net v2 B0",
|
261 |
+
interactive=False, lines=7,
|
262 |
+
visible=False)
|
263 |
"""
|
264 |
text_3 = gr.Text(label="Sashi's model",
|
265 |
interactive=False, lines=3)
|