Spaces:
Running
Running
ZeroCommand
commited on
Commit
•
d81d6fd
1
Parent(s):
d2ff920
add conditions to extract labels from dataset
Browse files- text_classification.py +9 -1
text_classification.py
CHANGED
@@ -15,7 +15,15 @@ def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
|
|
15 |
try:
|
16 |
ds = datasets.load_dataset(dataset_id, dataset_config)[split]
|
17 |
dataset_features = ds.features
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
features = [f for f in dataset_features.keys() if f != "label"]
|
20 |
return labels, features
|
21 |
except Exception as e:
|
|
|
15 |
try:
|
16 |
ds = datasets.load_dataset(dataset_id, dataset_config)[split]
|
17 |
dataset_features = ds.features
|
18 |
+
label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
|
19 |
+
if len(label_keys) == 0:
|
20 |
+
raise ValueError("Dataset does not have label column")
|
21 |
+
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
22 |
+
if hasattr(dataset_features[label_keys[0]], 'feature'):
|
23 |
+
label_feat = dataset_features[label_keys[0]].feature
|
24 |
+
labels = label_feat.names
|
25 |
+
else:
|
26 |
+
labels = [dataset_features[label_keys[0]].names]
|
27 |
features = [f for f in dataset_features.keys() if f != "label"]
|
28 |
return labels, features
|
29 |
except Exception as e:
|