Hacker1337 commited on
Commit
d21dd33
·
1 Parent(s): 7c398ad

fixed 95 percent choosing.

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -27,7 +27,7 @@ def is_local_machine():
27
 
28
 
29
  if is_local_machine():
30
- model_path = os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv")
31
  else:
32
  model_path = "Hacker1337/distilbert-arxiv-checkpoint"
33
 
@@ -59,15 +59,14 @@ def infer(
59
  bar_plot_dict = {}
60
  total_prop = sum([prediction["score"] for prediction in predictions])
61
  gained_prob = 0
62
- for prediction in predictions:
63
  bar_plot_dict[prediction["label"]] = prediction["score"]
64
- if (gained_prob + prediction["score"]) / total_prop < target_probs_sum:
65
  label_dict[category2human[prediction["label"]]] = (
66
  prediction["score"] / total_prop
67
  )
68
  gained_prob += prediction["score"]
69
-
70
- if gained_prob < total_prop:
71
  label_dict["Other"] = (total_prop - gained_prob) / total_prop
72
  return df, label_dict
73
 
 
27
 
28
 
29
  if is_local_machine():
30
+ model_path = os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv2")
31
  else:
32
  model_path = "Hacker1337/distilbert-arxiv-checkpoint"
33
 
 
59
  bar_plot_dict = {}
60
  total_prop = sum([prediction["score"] for prediction in predictions])
61
  gained_prob = 0
62
+ for prediction in sorted(predictions, key=lambda x: x["score"], reverse=True):
63
  bar_plot_dict[prediction["label"]] = prediction["score"]
64
+ if (gained_prob) / total_prop < target_probs_sum:
65
  label_dict[category2human[prediction["label"]]] = (
66
  prediction["score"] / total_prop
67
  )
68
  gained_prob += prediction["score"]
69
+ if gained_prob < total_prop + 1e-5:
 
70
  label_dict["Other"] = (total_prop - gained_prob) / total_prop
71
  return df, label_dict
72