QSBench committed on
Commit
3bf4374
·
verified ·
1 Parent(s): a1152bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -61
app.py CHANGED
@@ -7,10 +7,9 @@ from datasets import load_dataset
7
  from sklearn.ensemble import RandomForestRegressor
8
  from sklearn.metrics import mean_absolute_error, r2_score
9
  from sklearn.model_selection import train_test_split
10
- from pathlib import Path
11
 
12
  # =========================================================
13
- # CONFIG & REPOSITORIES
14
  # =========================================================
15
  DATASET_MAP = {
16
  "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo",
@@ -21,7 +20,6 @@ DATASET_MAP = {
21
 
22
  TARGET_COL = "ideal_expval_Z_global"
23
 
24
- # Non-numeric columns and target columns to exclude from training
25
  EXCLUDE_COLS = {
26
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
27
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
@@ -32,7 +30,7 @@ EXCLUDE_COLS = {
32
  dataset_cache = {}
33
 
34
  # =========================================================
35
- # DATA UTILS
36
  # =========================================================
37
  def get_df(dataset_key):
38
  if dataset_key not in dataset_cache:
@@ -43,37 +41,44 @@ def get_df(dataset_key):
43
 
44
  def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
45
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
46
- return [c for c in numeric_cols if c not in EXCLUDE_COLS and not c.startswith("error_") and not c.startswith("sign_")]
 
47
 
48
  # =========================================================
49
- # TAB FUNCTIONS
50
  # =========================================================
51
  def update_explorer(dataset_name, split_name):
52
  df = get_df(dataset_name)
53
-
54
- # Look up the unique splits; fall back to 'train' when there is no split column
55
  splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
56
  filtered = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)
57
 
58
- # QASM text taken from the dataset's CSV columns
59
- qasm_raw = filtered["qasm_raw"].iloc[0] if "qasm_raw" in filtered.columns else "// No raw QASM found"
60
- qasm_tr = filtered["qasm_transpiled"].iloc[0] if "qasm_transpiled" in filtered.columns else "// No transpiled QASM found"
61
 
62
- # Feature list for the ML tab
63
  features = get_numeric_feature_cols(df)
 
 
64
 
65
- return gr.update(choices=splits), filtered, qasm_raw, qasm_tr, gr.update(choices=features, value=features[:5])
66
 
67
  def run_model_demo(dataset_name, selected_features):
68
- if not selected_features or len(selected_features) == 0:
69
- return None, "### ⚠️ Please select at least one feature from the list."
70
-
71
  df = get_df(dataset_name)
 
 
 
 
 
 
 
72
  target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
73
 
74
- work_df = df.dropna(subset=selected_features + [target]).reset_index(drop=True)
75
- X, y = work_df[selected_features], work_df[target]
 
76
 
 
 
 
77
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
78
 
79
  model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
@@ -83,73 +88,61 @@ def run_model_demo(dataset_name, selected_features):
83
  sns.set_theme(style="whitegrid")
84
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
85
 
86
- # 1. Parity Plot
87
  ax1.scatter(y_test, preds, alpha=0.4, color='#636EFA')
88
  ax1.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
89
- ax1.set_xlabel("Ground Truth")
90
- ax1.set_ylabel("Predictions")
91
- ax1.set_title(f"Prediction Accuracy (R²={r2_score(y_test, preds):.3f})")
92
 
93
- # 2. Feature Importance
94
  importances = model.feature_importances_
95
- indices = np.argsort(importances)
96
  ax2.barh(range(len(indices)), importances[indices], color='#EF553B')
97
  ax2.set_yticks(range(len(indices)))
98
- ax2.set_yticklabels([selected_features[i] for i in indices])
99
- ax2.set_title("Structural Feature Importance")
100
 
101
- # 3. Residuals
102
  sns.histplot(y_test - preds, kde=True, ax=ax3, color='#00CC96')
103
- ax3.set_title("Error Distribution (Residuals)")
104
 
105
  plt.tight_layout()
106
- return fig, f"### Model performance on {dataset_name}\n**MAE:** {mean_absolute_error(y_test, preds):.4f} | **Features used:** {len(selected_features)}"
107
 
108
  # =========================================================
109
- # INTERFACE
110
  # =========================================================
111
- with gr.Blocks(title="QSBench Unified Explorer") as demo:
112
- gr.Markdown("# 🌌 QSBench: Quantum Synthetic Benchmark Explorer")
113
 
114
  with gr.Tabs():
115
- with gr.TabItem("🔎 Dataset Explorer"):
116
  with gr.Row():
117
- ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Select Dataset")
118
- split_selector = gr.Dropdown(choices=["train"], value="train", label="Data Split")
119
 
120
- # The overflow_row_behaviour parameter was removed for compatibility with Gradio 6
121
  data_table = gr.Dataframe(interactive=False)
122
 
123
  with gr.Row():
124
- qasm_raw_view = gr.Code(label="Raw QASM (Source)", language="python", lines=12)
125
- qasm_tr_view = gr.Code(label="Transpiled QASM (Hardware-ready)", language="python", lines=12)
126
 
127
- with gr.TabItem("🤖 ML Baseline Demo"):
128
  with gr.Row():
129
  with gr.Column(scale=1):
130
- gr.Markdown("### Training Settings")
131
- model_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset")
132
- feature_selector = gr.CheckboxGroup(label="Select Structural Metrics", choices=[])
133
- train_btn = gr.Button("Run Training", variant="primary")
134
  with gr.Column(scale=2):
135
- plot_output = gr.Plot()
136
- text_output = gr.Markdown()
137
-
138
- gr.Markdown("""
139
- ---
140
- ### 🔬 Research Resources
141
- - **GitHub**: [QSBench/QSBench-Demo](https://github.com/QSBench/QSBench-Demo)
142
- - **Website**: [qsbench.github.io](https://qsbench.github.io)
143
- - **Hugging Face**: [Explore all datasets](https://huggingface.co/QSBench)
144
- """)
145
-
146
- # Event wiring
147
- ds_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, feature_selector])
148
- split_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, feature_selector])
149
- train_btn.click(run_model_demo, [model_ds_selector, feature_selector], [plot_output, text_output])
150
-
151
- # Initial load when the Space starts
152
- demo.load(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, feature_selector])
153
 
154
  if __name__ == "__main__":
155
  demo.launch(theme=gr.themes.Soft())
 
7
  from sklearn.ensemble import RandomForestRegressor
8
  from sklearn.metrics import mean_absolute_error, r2_score
9
  from sklearn.model_selection import train_test_split
 
10
 
11
  # =========================================================
12
+ # CONFIG
13
  # =========================================================
14
  DATASET_MAP = {
15
  "Core (Clean)": "QSBench/QSBench-Core-v1.0.0-demo",
 
20
 
21
  TARGET_COL = "ideal_expval_Z_global"
22
 
 
23
  EXCLUDE_COLS = {
24
  "sample_id", "sample_seed", "circuit_hash", "split", "circuit_qasm",
25
  "qasm_raw", "qasm_transpiled", "circuit_type_resolved", "circuit_type_requested",
 
30
  dataset_cache = {}
31
 
32
  # =========================================================
33
+ # UTILS
34
  # =========================================================
35
  def get_df(dataset_key):
36
  if dataset_key not in dataset_cache:
 
41
 
42
  def get_numeric_feature_cols(df: pd.DataFrame) -> list[str]:
43
  numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
44
+ # Убираем все таргеты и нерелевантные колонки
45
+ return [c for c in numeric_cols if c not in EXCLUDE_COLS and not c.startswith("error_") and "expval" not in c]
46
 
47
  # =========================================================
48
+ # LOGIC
49
  # =========================================================
50
  def update_explorer(dataset_name, split_name):
51
  df = get_df(dataset_name)
 
 
52
  splits = df["split"].unique().tolist() if "split" in df.columns else ["train"]
53
  filtered = df[df["split"] == split_name].head(10) if "split" in df.columns else df.head(10)
54
 
55
+ qasm_raw = filtered["qasm_raw"].iloc[0] if "qasm_raw" in filtered.columns else "// N/A"
56
+ qasm_tr = filtered["qasm_transpiled"].iloc[0] if "qasm_transpiled" in filtered.columns else "// N/A"
 
57
 
 
58
  features = get_numeric_feature_cols(df)
59
+ # По умолчанию выбираем первые 8 признаков (обычно это n_qubits, depth и базовые гейты)
60
+ default_features = features[:8]
61
 
62
+ return gr.update(choices=splits), filtered, qasm_raw, qasm_tr, gr.update(choices=features, value=default_features)
63
 
64
  def run_model_demo(dataset_name, selected_features):
 
 
 
65
  df = get_df(dataset_name)
66
+
67
+ # CRITICAL FIX: keep only the features that actually exist in this dataset
68
+ valid_features = [f for f in selected_features if f in df.columns]
69
+
70
+ if not valid_features:
71
+ return None, "### ⚠️ No valid features selected for this dataset."
72
+
73
  target = TARGET_COL if TARGET_COL in df.columns else df.filter(like="expval").columns[0]
74
 
75
+ # Data preparation
76
+ work_df = df.dropna(subset=valid_features + [target]).reset_index(drop=True)
77
+ X, y = work_df[valid_features], work_df[target]
78
 
79
+ if len(work_df) < 50:
80
+ return None, "### ⚠️ Not enough data rows to train."
81
+
82
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
83
 
84
  model = RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1, random_state=42)
 
88
  sns.set_theme(style="whitegrid")
89
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
90
 
91
+ # Parity
92
  ax1.scatter(y_test, preds, alpha=0.4, color='#636EFA')
93
  ax1.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
94
+ ax1.set_title(f" = {r2_score(y_test, preds):.3f}")
95
+ ax1.set_xlabel("Actual")
96
+ ax1.set_ylabel("Predicted")
97
 
98
+ # Importance
99
  importances = model.feature_importances_
100
+ indices = np.argsort(importances)[-10:] # Только топ-10 для красоты
101
  ax2.barh(range(len(indices)), importances[indices], color='#EF553B')
102
  ax2.set_yticks(range(len(indices)))
103
+ ax2.set_yticklabels([valid_features[i] for i in indices])
104
+ ax2.set_title("Top Feature Importance")
105
 
106
+ # Residuals
107
  sns.histplot(y_test - preds, kde=True, ax=ax3, color='#00CC96')
108
+ ax3.set_title("Error Distribution")
109
 
110
  plt.tight_layout()
111
+ return fig, f"### Train Stats: {dataset_name}\n**MAE:** {mean_absolute_error(y_test, preds):.4f}"
112
 
113
  # =========================================================
114
+ # UI
115
  # =========================================================
116
+ with gr.Blocks() as demo:
117
+ gr.Markdown("# 🌌 QSBench Unified Explorer")
118
 
119
  with gr.Tabs():
120
+ with gr.TabItem("🔎 Explorer"):
121
  with gr.Row():
122
+ ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Dataset")
123
+ split_selector = gr.Dropdown(choices=["train"], value="train", label="Split")
124
 
 
125
  data_table = gr.Dataframe(interactive=False)
126
 
127
  with gr.Row():
128
+ qasm_raw_view = gr.Code(label="Raw QASM", language="python", lines=10)
129
+ qasm_tr_view = gr.Code(label="Transpiled QASM", language="python", lines=10)
130
 
131
+ with gr.TabItem("🤖 ML Demo"):
132
  with gr.Row():
133
  with gr.Column(scale=1):
134
+ m_ds_selector = gr.Dropdown(choices=list(DATASET_MAP.keys()), value="Core (Clean)", label="Target Dataset")
135
+ f_selector = gr.CheckboxGroup(label="Features", choices=[])
136
+ train_btn = gr.Button("Train", variant="primary")
 
137
  with gr.Column(scale=2):
138
+ plot_out = gr.Plot()
139
+ text_out = gr.Markdown()
140
+
141
+ # Ссылки
142
+ ds_selector.change(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, f_selector])
143
+ train_btn.click(run_model_demo, [m_ds_selector, f_selector], [plot_out, text_out])
144
+
145
+ demo.load(update_explorer, [ds_selector, split_selector], [split_selector, data_table, qasm_raw_view, qasm_tr_view, f_selector])
 
 
 
 
 
 
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  demo.launch(theme=gr.themes.Soft())