Spaces:
Runtime error
Runtime error
lowered threshold
Browse files
backend/disentangle_concepts.py
CHANGED
@@ -25,8 +25,8 @@ def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=
|
|
25 |
clf = LogisticRegression(random_state=0, C=C)
|
26 |
clf.fit(x_train, y_train)
|
27 |
print('Val performance logistic regression', clf.score(x_val, y_val))
|
28 |
-
imp_features = (np.abs(clf.coef_) > 0.
|
29 |
-
imp_nodes = np.where(np.abs(clf.coef_) > 0.
|
30 |
return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes
|
31 |
|
32 |
|
|
|
25 |
clf = LogisticRegression(random_state=0, C=C)
|
26 |
clf.fit(x_train, y_train)
|
27 |
print('Val performance logistic regression', clf.score(x_val, y_val))
|
28 |
+
imp_features = (np.abs(clf.coef_) > 0.15).sum()
|
29 |
+
imp_nodes = np.where(np.abs(clf.coef_) > 0.15)[1]
|
30 |
return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes
|
31 |
|
32 |
|