Spaces:
Sleeping
Sleeping
Commit
·
d21dd33
1
Parent(s):
7c398ad
fixed 95 percent choosing.
Browse files
app.py
CHANGED
@@ -27,7 +27,7 @@ def is_local_machine():
|
|
27 |
|
28 |
|
29 |
if is_local_machine():
|
30 |
-
model_path = os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-
|
31 |
else:
|
32 |
model_path = "Hacker1337/distilbert-arxiv-checkpoint"
|
33 |
|
@@ -59,15 +59,14 @@ def infer(
|
|
59 |
bar_plot_dict = {}
|
60 |
total_prop = sum([prediction["score"] for prediction in predictions])
|
61 |
gained_prob = 0
|
62 |
-
for prediction in predictions:
|
63 |
bar_plot_dict[prediction["label"]] = prediction["score"]
|
64 |
-
if (gained_prob
|
65 |
label_dict[category2human[prediction["label"]]] = (
|
66 |
prediction["score"] / total_prop
|
67 |
)
|
68 |
gained_prob += prediction["score"]
|
69 |
-
|
70 |
-
if gained_prob < total_prop:
|
71 |
label_dict["Other"] = (total_prop - gained_prob) / total_prop
|
72 |
return df, label_dict
|
73 |
|
|
|
27 |
|
28 |
|
29 |
if is_local_machine():
|
30 |
+
model_path = os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv2")
|
31 |
else:
|
32 |
model_path = "Hacker1337/distilbert-arxiv-checkpoint"
|
33 |
|
|
|
59 |
bar_plot_dict = {}
|
60 |
total_prop = sum([prediction["score"] for prediction in predictions])
|
61 |
gained_prob = 0
|
62 |
+
for prediction in sorted(predictions, key=lambda x: x["score"], reverse=True):
|
63 |
bar_plot_dict[prediction["label"]] = prediction["score"]
|
64 |
+
if (gained_prob) / total_prop < target_probs_sum:
|
65 |
label_dict[category2human[prediction["label"]]] = (
|
66 |
prediction["score"] / total_prop
|
67 |
)
|
68 |
gained_prob += prediction["score"]
|
69 |
+
if gained_prob < total_prop + 1e-5:
|
|
|
70 |
label_dict["Other"] = (total_prop - gained_prob) / total_prop
|
71 |
return df, label_dict
|
72 |
|