Spaces:
Running
Running
yangzhitao
commited on
Commit
·
1e0ed83
1
Parent(s):
c94136a
refactor: improve benchmark handling in create_submit_tab function by restructuring input processing and enhancing data validation
Browse files
app.py
CHANGED
|
@@ -1,4 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import threading
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
import gradio.components as grc
|
|
@@ -139,8 +145,6 @@ def search_models_in_dataframe(search_text: str, df: pd.DataFrame) -> pd.DataFra
|
|
| 139 |
return df
|
| 140 |
|
| 141 |
# 分割逗号,去除空白并转换为小写用于匹配
|
| 142 |
-
import re
|
| 143 |
-
|
| 144 |
keywords = [keyword.strip().lower() for keyword in search_text.split(',') if keyword.strip()]
|
| 145 |
if not keywords:
|
| 146 |
return df
|
|
@@ -493,8 +497,6 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
|
|
| 493 |
if file is None:
|
| 494 |
return ""
|
| 495 |
try:
|
| 496 |
-
import json
|
| 497 |
-
|
| 498 |
# file 是文件路径字符串(当 type="filepath" 时)
|
| 499 |
file_path = file if isinstance(file, str) else file.name
|
| 500 |
with open(file_path, encoding='utf-8') as f:
|
|
@@ -532,12 +534,10 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
|
|
| 532 |
model_name: str,
|
| 533 |
revision: str,
|
| 534 |
precision: str,
|
| 535 |
-
benchmark_checkbox_values: list,
|
| 536 |
-
benchmark_result_values: list,
|
| 537 |
) -> str:
|
| 538 |
"""Build JSON from form inputs"""
|
| 539 |
-
import json
|
| 540 |
-
|
| 541 |
if not model_name or not model_name.strip():
|
| 542 |
raise ValueError("Model name is required")
|
| 543 |
|
|
@@ -549,7 +549,7 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
|
|
| 549 |
"model_name": model_name,
|
| 550 |
"model_key": model_key,
|
| 551 |
"model_dtype": f"torch.{precision}" if precision else None,
|
| 552 |
-
"model_sha": revision or "main"
|
| 553 |
"model_args": None,
|
| 554 |
}
|
| 555 |
|
|
@@ -568,6 +568,21 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
|
|
| 568 |
|
| 569 |
return json.dumps({"config": config, "results": results}, indent=2, ensure_ascii=False)
|
| 570 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
def submit_with_form_or_json(
|
| 572 |
model: str,
|
| 573 |
base_model: str,
|
|
@@ -577,12 +592,11 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
|
|
| 577 |
model_type: str,
|
| 578 |
json_str: str,
|
| 579 |
commit_message: str,
|
| 580 |
-
oauth_profile: gr.OAuthProfile
|
| 581 |
-
|
|
|
|
| 582 |
):
|
| 583 |
"""Submit with either form data or JSON"""
|
| 584 |
-
import json
|
| 585 |
-
|
| 586 |
# Check if user is logged in
|
| 587 |
if oauth_profile is None:
|
| 588 |
return styled_error("Please log in before submitting.")
|
|
@@ -604,12 +618,25 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
|
|
| 604 |
# Build JSON from form
|
| 605 |
# benchmark_values contains pairs of (checkbox_value, result_value) for each benchmark
|
| 606 |
benchmarks_list = get_benchmarks()
|
| 607 |
-
if len(benchmark_values) != len(benchmarks_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
return styled_error("Invalid benchmark form data. Please check your inputs.")
|
| 609 |
|
| 610 |
# Split into checkbox values and result values
|
| 611 |
-
benchmark_checkbox_values
|
| 612 |
-
benchmark_result_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
|
| 614 |
try:
|
| 615 |
final_json = build_json_from_form(
|
|
@@ -664,24 +691,28 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
|
|
| 664 |
submission_result = gr.Markdown()
|
| 665 |
|
| 666 |
# Collect all inputs for submission
|
| 667 |
-
all_inputs = [
|
| 668 |
-
model_name_textbox,
|
| 669 |
-
base_model_name_textbox,
|
| 670 |
-
revision_name_textbox,
|
| 671 |
-
precision,
|
| 672 |
-
weight_type,
|
| 673 |
-
model_type,
|
| 674 |
-
json_str,
|
| 675 |
-
commit_textbox,
|
| 676 |
-
login_button, # oauth_profile must be before *benchmark_values
|
| 677 |
-
]
|
| 678 |
# Add benchmark form inputs (these will be captured by *benchmark_values)
|
|
|
|
| 679 |
for _, checkbox, result_input in benchmark_results_form:
|
| 680 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
|
| 682 |
submit_button.click(
|
| 683 |
-
fn=submit_with_form_or_json,
|
| 684 |
-
inputs=all_inputs,
|
| 685 |
outputs=submission_result,
|
| 686 |
)
|
| 687 |
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import sys
|
| 4 |
import threading
|
| 5 |
+
from collections import namedtuple
|
| 6 |
+
from functools import partial
|
| 7 |
+
from textwrap import dedent
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
import gradio.components as grc
|
|
|
|
| 145 |
return df
|
| 146 |
|
| 147 |
# 分割逗号,去除空白并转换为小写用于匹配
|
|
|
|
|
|
|
| 148 |
keywords = [keyword.strip().lower() for keyword in search_text.split(',') if keyword.strip()]
|
| 149 |
if not keywords:
|
| 150 |
return df
|
|
|
|
| 497 |
if file is None:
|
| 498 |
return ""
|
| 499 |
try:
|
|
|
|
|
|
|
| 500 |
# file 是文件路径字符串(当 type="filepath" 时)
|
| 501 |
file_path = file if isinstance(file, str) else file.name
|
| 502 |
with open(file_path, encoding='utf-8') as f:
|
|
|
|
| 534 |
model_name: str,
|
| 535 |
revision: str,
|
| 536 |
precision: str,
|
| 537 |
+
benchmark_checkbox_values: list[bool],
|
| 538 |
+
benchmark_result_values: list[float],
|
| 539 |
) -> str:
|
| 540 |
"""Build JSON from form inputs"""
|
|
|
|
|
|
|
| 541 |
if not model_name or not model_name.strip():
|
| 542 |
raise ValueError("Model name is required")
|
| 543 |
|
|
|
|
| 549 |
"model_name": model_name,
|
| 550 |
"model_key": model_key,
|
| 551 |
"model_dtype": f"torch.{precision}" if precision else None,
|
| 552 |
+
"model_sha": revision or None, # None means "main"
|
| 553 |
"model_args": None,
|
| 554 |
}
|
| 555 |
|
|
|
|
| 568 |
|
| 569 |
return json.dumps({"config": config, "results": results}, indent=2, ensure_ascii=False)
|
| 570 |
|
| 571 |
+
SubmitWithFormOrJsonInputs = namedtuple(
|
| 572 |
+
"SubmitWithFormOrJsonInputs",
|
| 573 |
+
[
|
| 574 |
+
"model",
|
| 575 |
+
"base_model",
|
| 576 |
+
"revision",
|
| 577 |
+
"precision",
|
| 578 |
+
"weight_type",
|
| 579 |
+
"model_type",
|
| 580 |
+
"json_str",
|
| 581 |
+
"commit_message",
|
| 582 |
+
# "oauth_profile",
|
| 583 |
+
],
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
def submit_with_form_or_json(
|
| 587 |
model: str,
|
| 588 |
base_model: str,
|
|
|
|
| 592 |
model_type: str,
|
| 593 |
json_str: str,
|
| 594 |
commit_message: str,
|
| 595 |
+
oauth_profile: gr.OAuthProfile,
|
| 596 |
+
*,
|
| 597 |
+
benchmark_values: list[bool | float],
|
| 598 |
):
|
| 599 |
"""Submit with either form data or JSON"""
|
|
|
|
|
|
|
| 600 |
# Check if user is logged in
|
| 601 |
if oauth_profile is None:
|
| 602 |
return styled_error("Please log in before submitting.")
|
|
|
|
| 618 |
# Build JSON from form
|
| 619 |
# benchmark_values contains pairs of (checkbox_value, result_value) for each benchmark
|
| 620 |
benchmarks_list = get_benchmarks()
|
| 621 |
+
if len(benchmark_values) != len(benchmarks_list):
|
| 622 |
+
print(
|
| 623 |
+
dedent(f"""
|
| 624 |
+
Invalid benchmark form data. Please check your inputs.
|
| 625 |
+
* benchmarks_list: {benchmarks_list!r}
|
| 626 |
+
* benchmark_values: {benchmark_values!r}
|
| 627 |
+
"""),
|
| 628 |
+
file=sys.stderr,
|
| 629 |
+
)
|
| 630 |
return styled_error("Invalid benchmark form data. Please check your inputs.")
|
| 631 |
|
| 632 |
# Split into checkbox values and result values
|
| 633 |
+
benchmark_checkbox_values: list[bool] = []
|
| 634 |
+
benchmark_result_values: list[float] = []
|
| 635 |
+
for i, val in enumerate(benchmark_values):
|
| 636 |
+
if i % 2 == 0:
|
| 637 |
+
benchmark_checkbox_values.append(bool(val))
|
| 638 |
+
else:
|
| 639 |
+
benchmark_result_values.append(float(val)) # pyright: ignore[reportArgumentType]
|
| 640 |
|
| 641 |
try:
|
| 642 |
final_json = build_json_from_form(
|
|
|
|
| 691 |
submission_result = gr.Markdown()
|
| 692 |
|
| 693 |
# Collect all inputs for submission
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
# Add benchmark form inputs (these will be captured by *benchmark_values)
|
| 695 |
+
benchmark_values = []
|
| 696 |
for _, checkbox, result_input in benchmark_results_form:
|
| 697 |
+
benchmark_values.extend([checkbox.value, result_input.value])
|
| 698 |
+
|
| 699 |
+
all_inputs = list(
|
| 700 |
+
SubmitWithFormOrJsonInputs(
|
| 701 |
+
model=model_name_textbox,
|
| 702 |
+
base_model=base_model_name_textbox,
|
| 703 |
+
revision=revision_name_textbox,
|
| 704 |
+
precision=precision,
|
| 705 |
+
weight_type=weight_type,
|
| 706 |
+
model_type=model_type,
|
| 707 |
+
json_str=json_str,
|
| 708 |
+
commit_message=commit_textbox,
|
| 709 |
+
# oauth_profile=login_button,
|
| 710 |
+
)
|
| 711 |
+
)
|
| 712 |
|
| 713 |
submit_button.click(
|
| 714 |
+
fn=partial(submit_with_form_or_json, benchmark_values=benchmark_values),
|
| 715 |
+
inputs=list(all_inputs),
|
| 716 |
outputs=submission_result,
|
| 717 |
)
|
| 718 |
|