Commit 7e38559
Parent(s): 5cf40e5
adding option to show all submissions
app.py CHANGED
@@ -107,14 +107,23 @@ def make_heatmap(results, label="generated", symbol="π€"):
 
     return chart
 
+@st.cache_data
+def load_roc_file(task, submission_ids):
+    rocs = pd.read_csv(f"{results_path}/{task}_rocs.csv")
+    # if best_only:
+    rocs = rocs[rocs["submission_id"].isin(submission_ids)]
+    return rocs
+
+@st.cache_data
+def get_unique_teams(teams):
+    return teams.unique().tolist()
 
 @st.cache_data
-def make_roc_curves(task, submission_cols):
-
-    rocs = pd.read_csv(f"{results_path}/{task}_rocs.csv")
-    rocs = rocs[rocs["submission_id"].isin(submission_cols)]
+def filter_teams(temp, selected_team):
+    return temp.query(f"team=='{selected_team}'")
+
+def make_roc_curves(task, submission_ids):
+    rocs = load_roc_file(task, submission_ids)
 
     roc_chart = alt.Chart(rocs).mark_line().encode(x="fpr", y="tpr", color="team:N", detail="submission_id:N")
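Note: the three new `@st.cache_data` helpers are memoized on their argument values, so Streamlit reruns with the same task and submission list skip the CSV read entirely. A minimal sketch of what each helper computes, on an invented toy frame:

```python
import pandas as pd

# Toy stand-in for {task}_rocs.csv (invented data).
rocs = pd.DataFrame({
    "submission_id": ["s1", "s1", "s2", "s3"],
    "team": ["a", "a", "a", "b"],
    "fpr": [0.0, 1.0, 0.5, 0.2],
    "tpr": [0.0, 1.0, 0.7, 0.4],
})

# load_roc_file: keep only the requested submissions.
subset = rocs[rocs["submission_id"].isin(["s1", "s3"])]

# get_unique_teams: collapse a Series to a plain, cheap-to-cache list.
teams = rocs["team"].unique().tolist()  # ['a', 'b']

# filter_teams: restrict to one team via a query string.
one_team = rocs.query("team=='a'")
```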
@@ -380,6 +389,7 @@ def make_roc(results):
         size=alt.Size(
             "total_time:Q", title="π Inference Time", scale=alt.Scale(rangeMin=100)
         ),  # Size by quantitative field
+        detail=["submission_id", "auc", "balanced_accuracy"],
     )
     .properties(width=400, height=400, title="Detection vs False Alarm vs Inference Time")
 )
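Note: the new `detail` encoding attaches fields to each mark without giving them a visual channel, so distinct submissions stay distinct even when color alone cannot separate them. A toy example of the same idea on a line chart (invented data):

```python
import altair as alt
import pandas as pd

# Two submissions from one team; color=team alone would merge them.
df = pd.DataFrame({
    "fpr": [0.0, 1.0, 0.0, 1.0],
    "tpr": [0.0, 1.0, 0.2, 0.9],
    "team": ["a", "a", "a", "a"],
    "submission_id": ["s1", "s1", "s2", "s2"],
})

chart = alt.Chart(df).mark_line().encode(
    x="fpr", y="tpr", color="team:N",
    detail="submission_id:N",  # one line per submission
)
```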
@@ -400,7 +410,7 @@ def make_acc(results):
         alt.Chart(results)
         .mark_circle(size=200)
         .encode(
-            x=alt.X("total_time:Q", title="π Inference Time", scale=alt.Scale(
+            x=alt.X("total_time:Q", title="π Inference Time", scale=alt.Scale(type="log")),
             y=alt.Y(
                 "balanced_accuracy:Q",
                 title="Balanced Accuracy",
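Note: switching the x scale to `type="log"` spreads out fast submissions that would otherwise pile up near zero; log axes need strictly positive values, so this assumes `total_time > 0`. In isolation (invented timings):

```python
import altair as alt
import pandas as pd

df = pd.DataFrame({"total_time": [120, 900, 14000],
                   "balanced_accuracy": [0.62, 0.71, 0.68]})

chart = alt.Chart(df).mark_circle(size=200).encode(
    x=alt.X("total_time:Q", scale=alt.Scale(type="log")),
    y="balanced_accuracy:Q",
)
```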
@@ -411,7 +421,7 @@ def make_acc(results):
         .properties(width=400, height=400, title="Inference Time vs Balanced Accuracy")
     )
     diag_line = (
-        alt.Chart(pd.DataFrame(dict(t=[
+        alt.Chart(pd.DataFrame(dict(t=[100, results["total_time"].max()], y=[0.5, 0.5])))
         .mark_line(color="lightgray", strokeDash=[8, 4])
         .encode(x="t", y="y")
     )
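Note: `diag_line` is now a dashed horizontal rule at y=0.5, i.e. chance-level balanced accuracy; starting it at t=100 rather than 0 keeps it valid on the log axis above. The same two-point-DataFrame pattern standalone (max time invented):

```python
import altair as alt
import pandas as pd

rule = (
    alt.Chart(pd.DataFrame(dict(t=[100, 10_000], y=[0.5, 0.5])))
    .mark_line(color="lightgray", strokeDash=[8, 4])
    .encode(x="t", y="y")
)
```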
@@ -431,15 +441,31 @@ def get_heatmaps(temp):
 
 
 def make_plots_for_task(task, split, best_only):
+
+
     results = load_results(task, best_only=best_only)
     temp = results[f"{split}_score"].reset_index()
+    teams = get_unique_teams(temp["team"])
+
 
     t1, t2 = st.tabs(["Tables", "Charts"])
     with t1:
         show_leaderboard(results, task)
 
     with t2:
-
+        # st.write(temp)
+        if split == "private":
+            best_only = st.toggle("Best Only", value=True, key=f"best only {task}")
+            if not best_only:
+                results = load_results(task, best_only=best_only)
+                temp = results[f"{split}_score"].reset_index()
+        selected_team = st.pills("Team", ["ALL"] + teams, key=f"teams {task}", default="ALL")
+        if not selected_team:
+            selected_team = "ALL"
+        if selected_team != "ALL":
+            temp = filter_teams(temp, selected_team)
+
+        # with st.spinner("making plots...", show_time=True):
         roc_scatter = make_roc(temp)
         acc_vs_time = make_acc(temp)
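Note: `st.pills` returns the selected option, or `None` once the user taps the active pill off, hence the fallback to "ALL" before filtering. A minimal sketch of the pattern with an invented team list:

```python
import streamlit as st

teams = ["red", "blue"]  # invented
selected = st.pills("Team", ["ALL"] + teams, default="ALL")
if not selected:          # deselection returns None
    selected = "ALL"
if selected != "ALL":
    st.write(f"filtering to team {selected}")
```

Since `filter_teams` builds its query string with an f-string, a team name containing quotes would break it; the boolean-indexing equivalent `temp[temp["team"] == selected_team]` avoids that.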
@@ -447,12 +473,15 @@ def make_plots_for_task(task, split, best_only):
         full_curves = st.toggle("Full curve", value=True, key=f"all curves {task}")
 
         if full_curves:
-            roc_scatter = make_roc_curves(task, temp["submission_id"].values.tolist())
+            roc_scatter = make_roc_curves(task, temp["submission_id"].values.tolist()) + roc_scatter
 
             st.altair_chart(roc_scatter | acc_vs_time, use_container_width=False)
         else:
             st.altair_chart(roc_scatter | acc_vs_time, use_container_width=False)
 
+        st.info(f"loading {temp['submission_id'].nunique()} submissions")
+
+
 updated = get_updated_time()
 st.markdown(updated)
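Note: the charts compose with Altair's operators: `+` layers the full ROC curves and the scatter onto shared axes (the left operand draws first, underneath), and `|` concatenates panels horizontally. Standalone (invented points):

```python
import altair as alt
import pandas as pd

pts = alt.Chart(pd.DataFrame({"x": [0.2, 0.6], "y": [0.5, 0.9]})).mark_circle().encode(x="x", y="y")
crv = alt.Chart(pd.DataFrame({"x": [0.0, 1.0], "y": [0.0, 1.0]})).mark_line().encode(x="x", y="y")

layered = crv + pts   # curves behind, scatter on top
row = layered | pts   # two panels side by side
```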
test.sh CHANGED

@@ -1 +1,2 @@
-HF_TOKEN=$(cat ~/.cache/huggingface/token) streamlit run app.py
+# HF_TOKEN=$(cat ~/.cache/huggingface/token) streamlit run app.py
+HF_TOKEN=test streamlit run app.py
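Note: the launcher now exports a stub token instead of reading the real one, which the app presumably consumes from the environment; the reading side is not part of this commit, but it would look along these lines:

```python
import os

hf_token = os.environ.get("HF_TOKEN")  # "test" under the new test.sh
```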
utils.py CHANGED
@@ -137,8 +137,12 @@ def extract_roc(results: Dict[str, Any]) -> Dict[str, Any]:
 if __name__ == "__main__":
 
     ## Download data
-    spaces: List[str] = ["safe-challenge/video-challenge-pilot-config", "safe-challenge/video-challenge-task-1-config"]
-
+    # spaces: List[str] = ["safe-challenge/video-challenge-pilot-config", "safe-challenge/video-challenge-task-1-config"]
+    spaces: List[str] = ["safe-challenge/video-challenge-task-1-config"]
+
+    # download_competition_data(competition_names=spaces)
+
+
 
     ## Loop
     for space in spaces:
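Note: the download is now toggled by commenting the call in and out. An alternative (not in this commit) is to guard on the cache directory, reusing the `download_competition_data` helper referenced above and assuming it populates `competition_cache`:

```python
from pathlib import Path

def maybe_download(spaces, cache=Path("competition_cache")):
    # Hypothetical guard: only hit the network when no local cache exists.
    if not cache.exists():
        download_competition_data(competition_names=spaces)
```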
@@ -155,7 +159,26 @@ if __name__ == "__main__":
 
     ## Loop and save by team
     public, private, rocs = [], [], []
-    for team_id, submission_set in submissions.items():
+    # for team_id, submission_set in submissions.items():
+    for team_id, submission_set_ids in submission_summaries.query("status_reason=='SUCCESS'").groupby("team_id")["submission_id"]:
+        ### let's check if we have the solution csvs
+        submission_set = submissions[team_id]
+        submission_set_ids_from_csvs = set(submission_set.keys())
+        submission_set_ids = set(submission_set_ids)
+        union = submission_set_ids | submission_set_ids_from_csvs
+
+        if not submission_set_ids.issubset(submission_set_ids_from_csvs):
+            # intersection = set(submission_set_ids) & submission_set_ids_from_csvs
+            missing = union - submission_set_ids_from_csvs
+            print(f"not all submission csv files found for {team_id}, missing {len(missing)}")
+
+        if submission_set_ids != submission_set_ids_from_csvs:
+            extra = union - submission_set_ids
+            print(f"{len(extra)} extra submissions in csvs but not in the summary file for team {team_id}")
+            print(f"dropping {extra}")
+            for submission_id in extra:
+                submission_set.pop(submission_id)
+
         results = compute_metric_per_team(solution_df=solutions_df, team_submissions=submission_set)
         public_results = {
             key: prep_public(value["public_score"]) for key, value in results.items() if key in team_submissions
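Note: the reconciliation above reduces to three set operations; on invented ids:

```python
summary_ids = {"a", "b", "c"}   # successful submissions per the summary file
csv_ids = {"b", "c", "d"}       # submissions with a solution csv on disk

union = summary_ids | csv_ids
missing = union - csv_ids       # reported but no csv found: {'a'}
extra = union - summary_ids     # csv on disk but never reported: {'d'}

assert missing == {"a"} and extra == {"d"}
```

The loop then pops the `extra` ids from `submission_set`, so scoring only sees submissions that are both recorded and present on disk.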
@@ -289,9 +312,9 @@ if __name__ == "__main__":
         private.append(private_df)
 
     ## Save as csvs
-    public = pd.concat(public, axis=0).sort_values(by="balanced_accuracy", ascending=False)
-    private = pd.concat(private, axis=0).sort_values(by="balanced_accuracy", ascending=False)
-    rocs = pd.concat(rocs, axis=0).explode(["tpr", "fpr", "threshold"], ignore_index=True)
+    public = pd.concat(public, axis=0, ignore_index=True).sort_values(by="balanced_accuracy", ascending=False)
+    private = pd.concat(private, axis=0, ignore_index=True).sort_values(by="balanced_accuracy", ascending=False)
+    rocs = pd.concat(rocs, axis=0, ignore_index=True).explode(["tpr", "fpr", "threshold"], ignore_index=True)
     public.to_csv(
         Path("competition_cache") / "cached_results" / f"{str(local_dir).split('/')[-1]}_public_score.csv",
         index=False,
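Note: `ignore_index=True` renumbers rows after stacking the per-team frames (without it, `concat` keeps each frame's original index and produces duplicate labels), and `explode` then unnests the per-curve lists so every (tpr, fpr, threshold) triple becomes its own row. Toy frames:

```python
import pandas as pd

a = pd.DataFrame({"tpr": [[0.1, 0.9]], "fpr": [[0.0, 0.5]], "threshold": [[0.9, 0.1]]})
b = pd.DataFrame({"tpr": [[0.2, 1.0]], "fpr": [[0.1, 0.6]], "threshold": [[0.8, 0.2]]})

rocs = pd.concat([a, b], axis=0, ignore_index=True)                  # 2 rows, index 0..1
rocs = rocs.explode(["tpr", "fpr", "threshold"], ignore_index=True)  # 4 rows, index 0..3
```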
|