taxfree_python committed on
Commit
b29fd2d
1 Parent(s): 7d9cce6

Add functions to submit models

Browse files
app.py CHANGED
@@ -1,37 +1,38 @@
1
  import gradio as gr
2
 
3
- from leaderboard.dataset import load_or_initialize_leaderboard
4
  from leaderboard.submission import submit_model
5
 
6
 
7
- # リーダーボード表示
8
  def display_leaderboard():
9
- dataset = load_or_initialize_leaderboard()
10
- return dataset.to_pandas()
11
 
12
 
13
- # Gradio のコンポーネント
14
- leaderboard_component = gr.DataFrame(
15
- display_leaderboard, headers=["Model Name", "Score", "Rank"], interactive=False, label="Leaderboard"
16
- )
17
-
18
- submit_form = gr.Interface(
19
- submit_model,
20
- inputs=[gr.Textbox(label="Model Name"), gr.File(label="Model File")],
21
- outputs=gr.DataFrame(headers=["Model Name", "Score", "Rank"], interactive=False),
22
- )
23
-
24
- # Gradio アプリケーション
25
- app = gr.Blocks()
26
-
27
- with app:
28
- gr.Markdown("# human_methylation_bench_ver1")
29
 
30
  with gr.Tab("Leaderboard"):
31
- leaderboard_component.render()
 
 
 
 
 
32
 
33
  with gr.Tab("Submit Model"):
34
- submit_form.render()
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  if __name__ == "__main__":
37
  app.launch()
 
1
  import gradio as gr
2
 
3
+ from leaderboard.dataset import get_leaderboard_df
4
  from leaderboard.submission import submit_model
5
 
6
 
 
7
  def display_leaderboard():
8
+ df = get_leaderboard_df()
9
+ return df
10
 
11
 
12
# Gradio application: one tab shows the leaderboard, the other accepts submissions.
with gr.Blocks() as app:
    gr.Markdown("# human_methylation_bench_ver1 Leaderboard")

    with gr.Tab("Leaderboard"):
        # Pass the callable itself rather than its result: Gradio then calls it
        # on every page load, so the table reflects submissions made after the
        # server started instead of a snapshot taken once at import time.
        leaderboard_df = gr.DataFrame(
            value=display_leaderboard,
            headers=["Model Name", "Score (relative_error_loss)", "Rank"],
            interactive=False,
            label="Leaderboard",
        )

    with gr.Tab("Submit Model"):
        model_name_input = gr.Textbox(label="Model Name", placeholder="e.g. My Great Model")
        model_url_input = gr.Textbox(
            label="Hugging Face Model Joblib URL",
            placeholder="e.g. https://huggingface.co/username/model/resolve/main/model.joblib",
        )
        submit_button = gr.Button("Submit")

        # Shows the re-ranked leaderboard returned by submit_model.
        submission_output = gr.DataFrame(
            headers=["Model Name", "Score (relative_error_loss)", "Rank"],
            interactive=False,
            label="Updated Leaderboard",
        )

        submit_button.click(
            submit_model,
            inputs=[model_name_input, model_url_input],
            outputs=submission_output,
        )

if __name__ == "__main__":
    app.launch()
leaderboard/dataset.py CHANGED
@@ -1,25 +1,27 @@
 
 
 
1
  from datasets import Dataset, load_dataset
 
 
2
 
3
- DATASET_PATH = "leaderboard_dataset"
4
 
5
- # 初期データ
6
- INITIAL_DATA = {
7
- "Model Name": ["Baseline Model"],
8
- "Score": [0.8],
9
- "Rank": [1],
10
- }
11
 
12
 
13
- # データセットを初期化またはロード
14
- def load_or_initialize_leaderboard():
15
- try:
16
- dataset = Dataset.load_from_disk(DATASET_PATH)
17
- except FileNotFoundError:
18
- dataset = Dataset.from_dict(INITIAL_DATA)
19
- dataset.save_to_disk(DATASET_PATH)
20
- return dataset
21
 
22
 
23
- # データセットを保存
24
- def save_leaderboard(dataset):
25
- dataset.save_to_disk(DATASET_PATH)
 
 
 
1
+ import os
2
+
3
+ import pandas as pd
4
  from datasets import Dataset, load_dataset
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import login
7
 
8
# Merge variables from a local .env file (if present) into the environment.
load_dotenv()

# Hub credentials and the dataset repository that stores the leaderboard.
HF_TOKEN = os.getenv("HF_TOKEN")
RESULT_DATASET_ID = os.getenv("RESULT_DATASET_ID")

# Only authenticate when a token was actually provided.
if HF_TOKEN:
    login(token=HF_TOKEN)
 
 
14
 
15
 
16
def get_leaderboard_df() -> pd.DataFrame:
    """Load the leaderboard from the Hugging Face Hub as a DataFrame.

    Reads the ``train`` split of the dataset named by ``RESULT_DATASET_ID``
    and converts it to pandas.
    """
    return load_dataset(RESULT_DATASET_ID, split="train").to_pandas()
 
 
 
21
 
22
 
23
def save_leaderboard_df(df: pd.DataFrame):
    """Publish ``df`` as the new leaderboard dataset on the Hugging Face Hub.

    The DataFrame is converted to a ``Dataset`` without its index and pushed
    to ``RESULT_DATASET_ID`` using ``HF_TOKEN``.
    """
    leaderboard = Dataset.from_pandas(df, preserve_index=False)
    leaderboard.push_to_hub(
        RESULT_DATASET_ID,
        token=HF_TOKEN,
        commit_message="Update leaderboard",
    )
leaderboard/evaluation.py CHANGED
@@ -1,8 +1,63 @@
1
- # ダミーの評価関数
2
- def evaluate_model(model_path):
3
- """
4
- 提出モデルを評価してスコアを返す関数。
5
- 本番ではモデルをロードしてテストデータに基づくスコアを計算する。
6
- """
7
- # TODO: 実際の評価ロジックを実装する
8
- return 0.75 # 仮のスコア
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import tempfile
4
+ from urllib.parse import urlparse
5
+
6
+ import joblib
7
+ import numpy as np
8
+ from datasets import load_dataset
9
+ from dotenv import load_dotenv
10
+ from huggingface_hub import hf_hub_download
11
+
12
# Merge variables from a local .env file (if present) into the environment.
load_dotenv()

# Hub dataset that holds the benchmark test split; None when not configured.
TEST_DATA_ID = os.getenv("TEST_DATA_ID")
15
+
16
+
17
def relative_error_loss(predicted_age, true_age):
    """Return the mean relative error between predicted and true ages.

    Each sample contributes ``|true - predicted| / true``. A true age of
    exactly 0 uses 0.1 as the denominator to avoid division by zero.

    Args:
        predicted_age: Array-like of predictions. Flattened before use, so a
            column vector of shape ``(n, 1)`` is accepted.
        true_age: Array-like of ground-truth ages with the same number of
            elements as ``predicted_age``.

    Returns:
        The mean relative error as a float.

    Raises:
        ValueError: If the two inputs do not contain the same number of values.
    """
    # Flatten both sides: subtracting an (n, 1) array from an (n,) array would
    # otherwise broadcast to an (n, n) matrix and silently give a wrong mean.
    predicted = np.asarray(predicted_age, dtype=float).reshape(-1)
    actual = np.asarray(true_age, dtype=float).reshape(-1)
    if predicted.shape != actual.shape:
        raise ValueError(
            f"predicted_age has {predicted.size} values but true_age has {actual.size}."
        )

    denominator = np.where(actual == 0, 0.1, actual)
    relative_error = np.abs((actual - predicted) / denominator)
    return float(np.mean(relative_error))
21
+
22
+
23
def parse_model_url(model_url: str):
    """Split a Hugging Face Hub file URL into its download coordinates.

    Expected path layout:
    ``/<owner>/<repo>/(resolve|blob)/<revision>/<path/to/file>.joblib``.

    Args:
        model_url: URL of the model file on the Hub. Any query string is ignored.

    Returns:
        A ``(repo_id, revision, filename)`` tuple, where ``filename`` is the
        path of the file inside the repository (it may contain ``/``).

    Raises:
        ValueError: If the path does not match the expected layout or the
            file is not a ``.joblib`` file.
    """
    # Local import: only ``urlparse`` is imported at module level.
    from urllib.parse import unquote

    parsed = urlparse(model_url)
    # Split first and decode afterwards so an encoded slash inside a segment
    # (e.g. revision ``refs%2Fpr%2F1``) stays part of that one segment.
    path_parts = [unquote(part) for part in parsed.path.strip("/").split("/")]

    if (
        len(path_parts) < 5
        or path_parts[2] not in ("resolve", "blob")
        or not all(path_parts)
    ):
        raise ValueError("Unexpected URL format. Make sure it's a Hub URL with /resolve/main/ or /blob/main/")

    repo_id = "/".join(path_parts[:2])
    revision = path_parts[3]
    # Everything after the revision is the in-repo path, so files stored in
    # sub-directories are kept whole instead of being cut at the first segment.
    filename = "/".join(path_parts[4:])

    if not filename.endswith(".joblib"):
        raise ValueError("The file must be a .joblib file.")

    return repo_id, revision, filename
38
+
39
+
40
def evaluate_model(model_url: str) -> float:
    """Download a ``.joblib`` model from the Hub and score it on the test set.

    The model is fetched into a temporary directory, loaded with joblib, and
    its ``predict`` output on the test features is compared with the test
    ages using ``relative_error_loss``.

    Args:
        model_url: ``https://huggingface.co/...`` URL of the ``.joblib`` file.
            Surrounding whitespace is ignored.

    Returns:
        The relative error loss of the model as a plain float.

    Raises:
        ValueError: If the URL is not a huggingface.co URL, the file cannot be
            loaded, or the resulting score is not a finite number.
    """
    model_url = model_url.strip()
    if not model_url.startswith("https://huggingface.co/"):
        raise ValueError("Invalid model URL. Must start with https://huggingface.co/")

    repo_id, revision, filename = parse_model_url(model_url)

    ds_test_meta = load_dataset(TEST_DATA_ID, "meta")
    ds_test_main = load_dataset(TEST_DATA_ID, "main")

    # NOTE(review): features ("main") and ages ("meta") come from separate
    # configs; this assumes both test splits list samples in the same order.
    X_test = ds_test_main["test"].to_pandas().drop(columns=["SampleID"])
    X_test = X_test.values.astype(np.float32)
    y_test = np.array(ds_test_meta["test"]["Age"])

    with tempfile.TemporaryDirectory() as tmpdir:
        local_model_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=tmpdir)
        try:
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code from an untrusted upload. Run this only in a
            # sandboxed environment.
            model = joblib.load(local_model_path)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"Failed to load the model. Please check the .joblib file. Error: {e}") from e

    # Flatten so a column-vector prediction cannot broadcast against y_test.
    predicted_age = np.asarray(model.predict(X_test)).reshape(-1)
    score = float(relative_error_loss(predicted_age, y_test))

    # A NaN/inf score cannot be ranked meaningfully; reject it explicitly.
    if not np.isfinite(score):
        raise ValueError("Model produced a non-finite score. Please check its predictions.")

    return score
leaderboard/submission.py CHANGED
@@ -1,32 +1,30 @@
1
- from datasets import Dataset
2
 
3
- from .dataset import load_or_initialize_leaderboard, save_leaderboard
4
  from .evaluation import evaluate_model
5
 
6
 
7
- def submit_model(model_name, model_file):
8
- """
9
- モデルの提出を処理する関数。
10
- 1. モデルを評価する。
11
- 2. リーダーボードにデータを追加。
12
- 3. ランクを計算して保存。
13
- """
14
- dataset = load_or_initialize_leaderboard()
15
 
16
- # モデル評価
17
- score = evaluate_model(model_file.name)
18
 
19
- # データに新しいモデルを追加
20
- new_entry = {"Model Name": model_name, "Score": score}
21
- dataset = dataset.add_item(new_entry)
22
 
23
- # ランク付け
24
- df = dataset.to_pandas()
25
- df = df.sort_values(by="Score", ascending=False).reset_index(drop=True)
 
 
 
26
  df["Rank"] = range(1, len(df) + 1)
27
 
28
- # データセットを更新・保存
29
- updated_dataset = Dataset.from_pandas(df)
30
- save_leaderboard(updated_dataset)
31
 
32
  return df
 
1
+ import pandas as pd
2
 
3
+ from .dataset import get_leaderboard_df, save_leaderboard_df
4
  from .evaluation import evaluate_model
5
 
6
 
7
def submit_model(model_name: str, model_url: str):
    """Evaluate a submitted model and add it to the leaderboard.

    Steps: validate the inputs, load the current leaderboard, score the model
    with ``evaluate_model``, append the new row, re-rank by ascending score,
    and save the result with ``save_leaderboard_df``.

    Args:
        model_name: Display name for the model; must not be blank.
        model_url: URL of the model file; must not be blank.

    Returns:
        The updated, re-ranked leaderboard DataFrame.

    Raises:
        ValueError: If either input is missing or blank.
    """
    # Normalize once so the cleaned values are both validated and used
    # downstream; ``or ""`` turns a missing (None) field into a validation
    # error instead of an AttributeError.
    model_name = (model_name or "").strip()
    model_url = (model_url or "").strip()
    if not model_name:
        raise ValueError("Model name cannot be empty.")
    if not model_url:
        raise ValueError("Model URL cannot be empty.")

    # Fetch the current leaderboard.
    df = get_leaderboard_df()

    # Score the new submission.
    score = evaluate_model(model_url)

    # Append the new row; its Rank is filled in below.
    new_entry = {"Model Name": model_name, "Score (relative_error_loss)": score}
    df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)

    # Lower scores are better. A stable sort keeps tied rows in their existing
    # order, so the row appended last ranks after rows with the same score.
    df = df.sort_values(
        by="Score (relative_error_loss)", ascending=True, kind="stable"
    ).reset_index(drop=True)
    df["Rank"] = range(1, len(df) + 1)

    # Persist the updated leaderboard to the Hugging Face Hub.
    save_leaderboard_df(df)

    return df
pyproject.toml CHANGED
@@ -9,7 +9,12 @@ readme = "README.md"
9
  python = "^3.12"
10
  gradio = "^5.6.0"
11
  pandas = "^2.2.3"
12
- datasets = "^3.1.0"
 
 
 
 
 
13
 
14
  [build-system]
15
  requires = ["poetry-core"]
 
9
  python = "^3.12"
10
  gradio = "^5.6.0"
11
  pandas = "^2.2.3"
12
joblib = "^1.4.2"
scikit-learn = "^1.6.0"
datasets = "^3.2.0"
# Imported at module level by leaderboard/dataset.py and
# leaderboard/evaluation.py, so it is a runtime dependency, not a dev-only one.
python-dotenv = "^1.0.1"
18
 
19
  [build-system]
20
  requires = ["poetry-core"]
requirements.txt CHANGED
@@ -1,16 +1,71 @@
1
- APScheduler
2
- black
3
- datasets
4
- gradio
5
- gradio[oauth]
6
- gradio_leaderboard==0.0.9
7
- gradio_client
8
- huggingface-hub>=0.18.0
9
- matplotlib
10
- numpy
11
- pandas
12
- python-dateutil
13
- tqdm
14
- transformers
15
- tokenizers>=0.15.0
16
- sentencepiece
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.12" and python_version < "4.0"
2
+ aiohappyeyeballs==2.4.4 ; python_version >= "3.12" and python_version < "4.0"
3
+ aiohttp==3.11.10 ; python_version >= "3.12" and python_version < "4.0"
4
+ aiosignal==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
5
+ annotated-types==0.7.0 ; python_version >= "3.12" and python_version < "4.0"
6
+ anyio==4.6.2.post1 ; python_version >= "3.12" and python_version < "4.0"
7
+ attrs==24.2.0 ; python_version >= "3.12" and python_version < "4.0"
8
+ audioop-lts==0.2.1 ; python_version >= "3.13" and python_version < "4.0"
9
+ certifi==2024.8.30 ; python_version >= "3.12" and python_version < "4.0"
10
+ charset-normalizer==3.4.0 ; python_version >= "3.12" and python_version < "4.0"
11
+ click==8.1.7 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
12
+ colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
13
+ datasets==3.2.0 ; python_version >= "3.12" and python_version < "4.0"
14
+ dill==0.3.8 ; python_version >= "3.12" and python_version < "4.0"
15
+ fastapi==0.115.5 ; python_version >= "3.12" and python_version < "4.0"
16
+ ffmpy==0.4.0 ; python_version >= "3.12" and python_version < "4.0"
17
+ filelock==3.16.1 ; python_version >= "3.12" and python_version < "4.0"
18
+ frozenlist==1.5.0 ; python_version >= "3.12" and python_version < "4.0"
19
+ fsspec==2024.9.0 ; python_version >= "3.12" and python_version < "4.0"
20
+ fsspec[http]==2024.9.0 ; python_version >= "3.12" and python_version < "4.0"
21
+ gradio-client==1.4.3 ; python_version >= "3.12" and python_version < "4.0"
22
+ gradio==5.6.0 ; python_version >= "3.12" and python_version < "4.0"
23
+ h11==0.14.0 ; python_version >= "3.12" and python_version < "4.0"
24
+ httpcore==1.0.7 ; python_version >= "3.12" and python_version < "4.0"
25
+ httpx==0.27.2 ; python_version >= "3.12" and python_version < "4.0"
26
+ huggingface-hub==0.26.2 ; python_version >= "3.12" and python_version < "4.0"
27
+ idna==3.10 ; python_version >= "3.12" and python_version < "4.0"
28
+ jinja2==3.1.4 ; python_version >= "3.12" and python_version < "4.0"
29
+ joblib==1.4.2 ; python_version >= "3.12" and python_version < "4.0"
30
+ markdown-it-py==3.0.0 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
31
+ markupsafe==2.1.5 ; python_version >= "3.12" and python_version < "4.0"
32
+ mdurl==0.1.2 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
33
+ multidict==6.1.0 ; python_version >= "3.12" and python_version < "4.0"
34
+ multiprocess==0.70.16 ; python_version >= "3.12" and python_version < "4.0"
35
+ numpy==2.1.3 ; python_version >= "3.12" and python_version < "4.0"
36
+ orjson==3.10.12 ; python_version >= "3.12" and python_version < "4.0"
37
+ packaging==24.2 ; python_version >= "3.12" and python_version < "4.0"
38
+ pandas==2.2.3 ; python_version >= "3.12" and python_version < "4.0"
39
+ pillow==11.0.0 ; python_version >= "3.12" and python_version < "4.0"
40
+ propcache==0.2.1 ; python_version >= "3.12" and python_version < "4.0"
41
+ pyarrow==18.1.0 ; python_version >= "3.12" and python_version < "4.0"
42
+ pydantic-core==2.27.1 ; python_version >= "3.12" and python_version < "4.0"
43
+ pydantic==2.10.1 ; python_version >= "3.12" and python_version < "4.0"
44
+ pydub==0.25.1 ; python_version >= "3.12" and python_version < "4.0"
45
+ pygments==2.18.0 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
46
+ python-dateutil==2.9.0.post0 ; python_version >= "3.12" and python_version < "4.0"
47
+ python-multipart==0.0.12 ; python_version >= "3.12" and python_version < "4.0"
48
+ pytz==2024.2 ; python_version >= "3.12" and python_version < "4.0"
49
+ pyyaml==6.0.2 ; python_version >= "3.12" and python_version < "4.0"
50
+ requests==2.32.3 ; python_version >= "3.12" and python_version < "4.0"
51
+ rich==13.9.4 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
52
+ ruff==0.8.0 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
53
+ safehttpx==0.1.1 ; python_version >= "3.12" and python_version < "4.0"
54
+ scikit-learn==1.6.0 ; python_version >= "3.12" and python_version < "4.0"
55
+ scipy==1.14.1 ; python_version >= "3.12" and python_version < "4.0"
56
+ semantic-version==2.10.0 ; python_version >= "3.12" and python_version < "4.0"
57
+ shellingham==1.5.4 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
58
+ six==1.16.0 ; python_version >= "3.12" and python_version < "4.0"
59
+ sniffio==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
60
+ starlette==0.41.3 ; python_version >= "3.12" and python_version < "4.0"
61
+ threadpoolctl==3.5.0 ; python_version >= "3.12" and python_version < "4.0"
62
+ tomlkit==0.12.0 ; python_version >= "3.12" and python_version < "4.0"
63
+ tqdm==4.67.1 ; python_version >= "3.12" and python_version < "4.0"
64
+ typer==0.13.1 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
65
+ typing-extensions==4.12.2 ; python_version >= "3.12" and python_version < "4.0"
66
+ tzdata==2024.2 ; python_version >= "3.12" and python_version < "4.0"
67
+ urllib3==2.2.3 ; python_version >= "3.12" and python_version < "4.0"
68
+ uvicorn==0.32.1 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "emscripten"
69
+ websockets==12.0 ; python_version >= "3.12" and python_version < "4.0"
70
+ xxhash==3.5.0 ; python_version >= "3.12" and python_version < "4.0"
71
+ yarl==1.18.3 ; python_version >= "3.12" and python_version < "4.0"