QSBench commited on
Commit
a63cf6b
·
verified ·
1 Parent(s): 170aab6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -28
app.py CHANGED
@@ -75,55 +75,90 @@ def sync_ml_metrics(ds_name: str):
75
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
76
  return gr.update(choices=valid_features, value=defaults)
77
 
 
 
 
 
 
 
 
 
78
  def train_classifier(ds_name: str, features: List[str]):
79
- if not features: return None, "### ❌ Select features first."
 
 
80
  assets = load_all_assets(ds_name)
81
  df = assets["df"]
82
 
83
- # Automatically determine available classes in the dataset, excluding empty values
84
- available_in_df = df['circuit_type_requested'].dropna().unique()
85
 
86
- # Filter: keep only those that are in our list of interests (case-insensitive)
87
- # Or simply take all available types if we want universality
88
- train_df = df[df['circuit_type_requested'].isin(available_in_df)].dropna(subset=features)
89
-
90
- if train_df.empty:
91
- return None, f"### ❌ Error: No data found for features {features}. Check if these columns are empty in the dataset."
92
-
93
- X, y = train_df[features], train_df['circuit_type_requested']
 
 
94
 
95
- # Check number of classes
96
- if len(y.unique()) < 2:
97
- return None, f"### ❌ Error: Need at least 2 classes to train. Found only: {y.unique()}"
 
 
98
 
 
99
  le = LabelEncoder()
100
  y_encoded = le.fit_transform(y)
101
-
102
- try:
103
- X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
104
- except ValueError as e:
105
- return None, f"### ❌ Split Error: {str(e)}"
106
 
107
- clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1).fit(X_train, y_train)
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  preds = clf.predict(X_test)
109
 
 
110
  sns.set_theme(style="whitegrid")
111
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
112
 
 
113
  cm = confusion_matrix(y_test, preds)
114
- sns.heatmap(cm, annot=True, fmt='d', cmap='magma',
115
- xticklabels=le.classes_, yticklabels=le.classes_,
116
- ax=axes[0], cbar=False)
117
  axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
 
 
118
 
 
119
  importances = clf.feature_importances_
120
- idx = np.argsort(importances)[-10:]
121
- axes[1].barh([features[i] for i in idx], importances[idx], color='#3498db')
122
- axes[1].set_title("Feature Importance")
123
 
124
  plt.tight_layout()
125
- report = classification_report(y_test, preds, target_names=le.classes_)
126
- return fig, f"### 🏆 Results for {ds_name}\n```\n{report}\n```"
 
 
 
 
 
 
 
127
 
128
  def update_explorer(ds_name: str, split_name: str):
129
  assets = load_all_assets(ds_name)
 
75
  defaults = [f for f in ["gate_entropy", "meyer_wallach", "adjacency", "depth", "cx_count"] if f in valid_features]
76
  return gr.update(choices=valid_features, value=defaults)
77
 
78
+ Судя по ошибке Found only: ['mixed'], в вашем столбце circuit_type_requested вместо конкретных названий семейств (QFT, HEA и т.д.) записано значение 'mixed'. Это часто случается в демонстрационных подмножествах, где данные уже перемешаны и помечены общим тегом.
79
+
80
+ Для классификации нам нужны исходные метки. В датасетах QSBench они обычно находятся в столбце circuit_type_resolved.
81
+
82
+ Вот обновленный код функции train_classifier с исправленной логикой выбора столбца и более надежной обработкой ошибок.
83
+ Исправленный код (App Code)
84
+ Python
85
+
86
  def train_classifier(ds_name: str, features: List[str]):
87
+ if not features:
88
+ return None, "### ❌ Error: No features selected. Please pick structural metrics."
89
+
90
  assets = load_all_assets(ds_name)
91
  df = assets["df"]
92
 
93
+ # Try 'resolved' column first as 'requested' might contain 'mixed' in demo shards
94
+ target_col = 'circuit_type_resolved' if 'circuit_type_resolved' in df.columns else 'circuit_type_requested'
95
 
96
+ # Clean data: remove NaNs and ensure we have valid target strings
97
+ train_df = df.dropna(subset=features + [target_col])
98
+
99
+ # Filter out rows where the target might be 'mixed' or generic if others are available
100
+ unique_types = train_df[target_col].unique()
101
+ if 'mixed' in unique_types and len(unique_types) > 1:
102
+ train_df = train_df[train_df[target_col] != 'mixed']
103
+
104
+ X = train_df[features]
105
+ y = train_df[target_col]
106
 
107
+ # Verification: Do we have at least 2 distinct classes to perform classification?
108
+ current_classes = y.unique()
109
+ if len(current_classes) < 2:
110
+ return None, f"### ❌ Classification Error\nFound only one class: `{current_classes}` in column `{target_col}`. " \
111
+ "Try a different dataset or check if the source file has labels."
112
 
113
+ # Encode labels to integers
114
  le = LabelEncoder()
115
  y_encoded = le.fit_transform(y)
116
+ class_names = le.classes_
 
 
 
 
117
 
118
+ # Split dataset
119
+ try:
120
+ X_train, X_test, y_train, y_test = train_test_split(
121
+ X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
122
+ )
123
+ except ValueError:
124
+ # Fallback if stratify fails due to very small class sizes
125
+ X_train, X_test, y_train, y_test = train_test_split(
126
+ X, y_encoded, test_size=0.2, random_state=42
127
+ )
128
+
129
+ # Train Random Forest Classifier
130
+ clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=42)
131
+ clf.fit(X_train, y_train)
132
  preds = clf.predict(X_test)
133
 
134
+ # Visuals
135
  sns.set_theme(style="whitegrid")
136
  fig, axes = plt.subplots(1, 2, figsize=(20, 8))
137
 
138
+ # Plot 1: Confusion Matrix
139
  cm = confusion_matrix(y_test, preds)
140
+ sns.heatmap(cm, annot=True, fmt='d', cmap='viridis',
141
+ xticklabels=class_names, yticklabels=class_names, ax=axes[0], cbar=False)
 
142
  axes[0].set_title(f"Confusion Matrix (Acc: {accuracy_score(y_test, preds):.2%})")
143
+ axes[0].set_xlabel("Predicted Label")
144
+ axes[0].set_ylabel("True Label")
145
 
146
+ # Plot 2: Feature Importance
147
  importances = clf.feature_importances_
148
+ indices = np.argsort(importances)[-10:]
149
+ axes[1].barh([features[i] for i in indices], importances[indices], color='#2ecc71')
150
+ axes[1].set_title("Top-10 Discriminative Features")
151
 
152
  plt.tight_layout()
153
+
154
+ # Generate text report
155
+ report_dict = classification_report(y_test, preds, target_names=class_names)
156
+ summary = f"### 🏆 Classifier Results: {ds_name}\n" \
157
+ f"**Target Column used:** `{target_col}`\n" \
158
+ f"**Accuracy:** {accuracy_score(y_test, preds):.2%}\n\n" \
159
+ f"**Report:**\n```\n{report_dict}\n```"
160
+
161
+ return fig, summary
162
 
163
  def update_explorer(ds_name: str, split_name: str):
164
  assets = load_all_assets(ds_name)