liewchooichin commited on
Commit
0c08748
1 Parent(s): eb9bc66

Set threshold 0.4, multilabel textbox hidden

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -144,6 +144,9 @@ def get_img_array(img_path):
144
 
145
 
146
  def get_prediction(img_path):
 
 
 
147
  # check the image path
148
  print(f"Image path: {img_path}")
149
  # also display the original filename for info
@@ -156,7 +159,7 @@ def get_prediction(img_path):
156
  # binary label
157
  pred_binary = keras_binary_model(img_array, training=False)
158
  print(f"Keras binary label: {pred_binary}")
159
- if pred_binary[0][0] > 0.5:
160
  fake = "Fake"
161
  else:
162
  fake = "Real"
@@ -165,7 +168,7 @@ def get_prediction(img_path):
165
  pred_multi = keras_multi_model(img_array, training=False)
166
  print(f"Keras multi label: {pred_multi}")
167
  # Cut at the sigmoid 0.5 threshold
168
- fake_parts = np.where(pred_multi > 0.5, 1, 0)
169
  print(f"Multi label: {fake_parts}")
170
  # Format each of the fake face parts
171
  parts_message = dict()
@@ -255,7 +258,8 @@ with gr.Blocks() as demo:
255
  interactive=False, lines=2)
256
  text_3 = gr.Text(
257
  label="Multi label, Efficient net v2 B0",
258
- interactive=False, lines=7)
 
259
  """
260
  text_3 = gr.Text(label="Sashi's model",
261
  interactive=False, lines=3)
 
144
 
145
 
146
  def get_prediction(img_path):
147
+ # adjust threshold for accuracy
148
+ threshold = 0.4
149
+
150
  # check the image path
151
  print(f"Image path: {img_path}")
152
  # also display the original filename for info
 
159
  # binary label
160
  pred_binary = keras_binary_model(img_array, training=False)
161
  print(f"Keras binary label: {pred_binary}")
162
+ if pred_binary[0][0] > threshold:
163
  fake = "Fake"
164
  else:
165
  fake = "Real"
 
168
  pred_multi = keras_multi_model(img_array, training=False)
169
  print(f"Keras multi label: {pred_multi}")
170
  # Cut at the sigmoid 0.5 threshold
171
+ fake_parts = np.where(pred_multi > threshold, 1, 0)
172
  print(f"Multi label: {fake_parts}")
173
  # Format each of the fake face parts
174
  parts_message = dict()
 
258
  interactive=False, lines=2)
259
  text_3 = gr.Text(
260
  label="Multi label, Efficient net v2 B0",
261
+ interactive=False, lines=7,
262
+ visible=False)
263
  """
264
  text_3 = gr.Text(label="Sashi's model",
265
  interactive=False, lines=3)