Spaces:
Running
Running
ZeroCommand
committed on
Commit
•
21e0bb3
1
Parent(s):
bedf925
Fix for flattened raw config
Browse files
- app.py +1 -1
- app_text_classification.py +5 -8
- io_utils.py +2 -5
- text_classification.py +4 -3
- text_classification_ui_helpers.py +44 -24
app.py
CHANGED
@@ -10,7 +10,7 @@ from run_jobs import start_process_run_job, stop_thread
|
|
10 |
try:
|
11 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
|
12 |
with gr.Tab("Text Classification"):
|
13 |
-
get_demo_text_classification(
|
14 |
with gr.Tab("Leaderboard"):
|
15 |
get_demo_leaderboard()
|
16 |
with gr.Tab("Logs(Debug)"):
|
|
|
10 |
try:
|
11 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
|
12 |
with gr.Tab("Text Classification"):
|
13 |
+
get_demo_text_classification()
|
14 |
with gr.Tab("Leaderboard"):
|
15 |
get_demo_leaderboard()
|
16 |
with gr.Tab("Logs(Debug)"):
|
app_text_classification.py
CHANGED
@@ -12,7 +12,7 @@ from text_classification_ui_helpers import (check_dataset_and_get_config,
|
|
12 |
write_column_mapping_to_config)
|
13 |
from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
|
14 |
|
15 |
-
MAX_LABELS =
|
16 |
MAX_FEATURES = 20
|
17 |
|
18 |
EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
|
@@ -20,7 +20,7 @@ EXAMPLE_DATA_ID = "tweet_eval"
|
|
20 |
CONFIG_PATH = "./config.yaml"
|
21 |
|
22 |
|
23 |
-
def get_demo(
|
24 |
with gr.Row():
|
25 |
gr.Markdown(INTRODUCTION_MD)
|
26 |
uid_label = gr.Textbox(
|
@@ -55,9 +55,11 @@ def get_demo(demo):
|
|
55 |
column_mappings = []
|
56 |
with gr.Row():
|
57 |
with gr.Column():
|
|
|
58 |
for _ in range(MAX_LABELS):
|
59 |
column_mappings.append(gr.Dropdown(visible=False))
|
60 |
with gr.Column():
|
|
|
61 |
for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
|
62 |
column_mappings.append(gr.Dropdown(visible=False))
|
63 |
|
@@ -138,9 +140,6 @@ def get_demo(demo):
|
|
138 |
triggers=[label.change for label in column_mappings],
|
139 |
fn=write_column_mapping_to_config,
|
140 |
inputs=[
|
141 |
-
dataset_id_input,
|
142 |
-
dataset_config_input,
|
143 |
-
dataset_split_input,
|
144 |
uid_label,
|
145 |
*column_mappings,
|
146 |
],
|
@@ -151,9 +150,6 @@ def get_demo(demo):
|
|
151 |
triggers=[label.input for label in column_mappings],
|
152 |
fn=write_column_mapping_to_config,
|
153 |
inputs=[
|
154 |
-
dataset_id_input,
|
155 |
-
dataset_config_input,
|
156 |
-
dataset_split_input,
|
157 |
uid_label,
|
158 |
*column_mappings,
|
159 |
],
|
@@ -172,6 +168,7 @@ def get_demo(demo):
|
|
172 |
dataset_id_input,
|
173 |
dataset_config_input,
|
174 |
dataset_split_input,
|
|
|
175 |
],
|
176 |
outputs=[
|
177 |
example_input,
|
|
|
12 |
write_column_mapping_to_config)
|
13 |
from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
|
14 |
|
15 |
+
MAX_LABELS = 40
|
16 |
MAX_FEATURES = 20
|
17 |
|
18 |
EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
|
|
|
20 |
CONFIG_PATH = "./config.yaml"
|
21 |
|
22 |
|
23 |
+
def get_demo():
|
24 |
with gr.Row():
|
25 |
gr.Markdown(INTRODUCTION_MD)
|
26 |
uid_label = gr.Textbox(
|
|
|
55 |
column_mappings = []
|
56 |
with gr.Row():
|
57 |
with gr.Column():
|
58 |
+
gr.Markdown("# Label Mapping")
|
59 |
for _ in range(MAX_LABELS):
|
60 |
column_mappings.append(gr.Dropdown(visible=False))
|
61 |
with gr.Column():
|
62 |
+
gr.Markdown("# Feature Mapping")
|
63 |
for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
|
64 |
column_mappings.append(gr.Dropdown(visible=False))
|
65 |
|
|
|
140 |
triggers=[label.change for label in column_mappings],
|
141 |
fn=write_column_mapping_to_config,
|
142 |
inputs=[
|
|
|
|
|
|
|
143 |
uid_label,
|
144 |
*column_mappings,
|
145 |
],
|
|
|
150 |
triggers=[label.input for label in column_mappings],
|
151 |
fn=write_column_mapping_to_config,
|
152 |
inputs=[
|
|
|
|
|
|
|
153 |
uid_label,
|
154 |
*column_mappings,
|
155 |
],
|
|
|
168 |
dataset_id_input,
|
169 |
dataset_config_input,
|
170 |
dataset_split_input,
|
171 |
+
uid_label,
|
172 |
],
|
173 |
outputs=[
|
174 |
example_input,
|
io_utils.py
CHANGED
@@ -76,7 +76,6 @@ def read_column_mapping(uid):
|
|
76 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
77 |
if config:
|
78 |
column_mapping = config.get("column_mapping", dict())
|
79 |
-
f.close()
|
80 |
return column_mapping
|
81 |
|
82 |
|
@@ -84,7 +83,6 @@ def read_column_mapping(uid):
|
|
84 |
def write_column_mapping(mapping, uid):
|
85 |
with open(get_yaml_path(uid), "r") as f:
|
86 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
87 |
-
f.close()
|
88 |
|
89 |
if config is None:
|
90 |
return
|
@@ -92,10 +90,9 @@ def write_column_mapping(mapping, uid):
|
|
92 |
del config["column_mapping"]
|
93 |
else:
|
94 |
config["column_mapping"] = mapping
|
95 |
-
|
96 |
with open(get_yaml_path(uid), "w") as f:
|
97 |
-
yaml
|
98 |
-
|
99 |
|
100 |
|
101 |
# convert column mapping dataframe to json
|
|
|
76 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
77 |
if config:
|
78 |
column_mapping = config.get("column_mapping", dict())
|
|
|
79 |
return column_mapping
|
80 |
|
81 |
|
|
|
83 |
def write_column_mapping(mapping, uid):
|
84 |
with open(get_yaml_path(uid), "r") as f:
|
85 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
|
|
86 |
|
87 |
if config is None:
|
88 |
return
|
|
|
90 |
del config["column_mapping"]
|
91 |
else:
|
92 |
config["column_mapping"] = mapping
|
|
|
93 |
with open(get_yaml_path(uid), "w") as f:
|
94 |
+
# yaml Dumper will by default sort the keys
|
95 |
+
yaml.dump(config, f, Dumper=Dumper, sort_keys=False)
|
96 |
|
97 |
|
98 |
# convert column mapping dataframe to json
|
text_classification.py
CHANGED
@@ -16,15 +16,16 @@ def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
|
|
16 |
ds = datasets.load_dataset(dataset_id, dataset_config)[split]
|
17 |
dataset_features = ds.features
|
18 |
label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
|
19 |
-
if len(label_keys) == 0:
|
20 |
-
|
|
|
21 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
22 |
if hasattr(dataset_features[label_keys[0]], 'feature'):
|
23 |
label_feat = dataset_features[label_keys[0]].feature
|
24 |
labels = label_feat.names
|
25 |
else:
|
26 |
labels = [dataset_features[label_keys[0]].names]
|
27 |
-
features = [f for f in dataset_features.keys() if f
|
28 |
return labels, features
|
29 |
except Exception as e:
|
30 |
logging.warning(
|
|
|
16 |
ds = datasets.load_dataset(dataset_id, dataset_config)[split]
|
17 |
dataset_features = ds.features
|
18 |
label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
|
19 |
+
if len(label_keys) == 0: # no labels found
|
20 |
+
# return everything for post processing
|
21 |
+
return list(dataset_features.keys()), list(dataset_features.keys())
|
22 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
23 |
if hasattr(dataset_features[label_keys[0]], 'feature'):
|
24 |
label_feat = dataset_features[label_keys[0]].feature
|
25 |
labels = label_feat.names
|
26 |
else:
|
27 |
labels = [dataset_features[label_keys[0]].names]
|
28 |
+
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
29 |
return labels, features
|
30 |
except Exception as e:
|
31 |
logging.warning(
|
text_classification_ui_helpers.py
CHANGED
@@ -18,7 +18,7 @@ from wordings import (CHECK_CONFIG_OR_SPLIT_RAW,
|
|
18 |
CONFIRM_MAPPING_DETAILS_FAIL_RAW,
|
19 |
MAPPING_STYLED_ERROR_WARNING)
|
20 |
|
21 |
-
MAX_LABELS =
|
22 |
MAX_FEATURES = 20
|
23 |
|
24 |
HF_REPO_ID = "HF_REPO_ID"
|
@@ -68,46 +68,62 @@ def deselect_run_inference(run_local):
|
|
68 |
|
69 |
|
70 |
def write_column_mapping_to_config(
|
71 |
-
|
72 |
):
|
73 |
# TODO: Substitute 'text' with more features for zero-shot
|
74 |
# we are not using ds features because we only support "text" for now
|
75 |
-
|
76 |
-
|
77 |
-
)
|
78 |
if labels is None:
|
79 |
return
|
|
|
|
|
80 |
|
81 |
-
all_mappings = dict()
|
82 |
-
|
83 |
-
if "labels" not in all_mappings.keys():
|
84 |
-
all_mappings["labels"] = dict()
|
85 |
-
for i, label in enumerate(labels[:MAX_LABELS]):
|
86 |
-
if label:
|
87 |
-
all_mappings["labels"][label] = ds_labels[i % len(ds_labels)]
|
88 |
-
if "features" not in all_mappings.keys():
|
89 |
-
all_mappings["features"] = dict()
|
90 |
-
for _, feat in enumerate(labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]):
|
91 |
-
if feat:
|
92 |
-
# TODO: Substitute 'text' with more features for zero-shot
|
93 |
-
all_mappings["features"]["text"] = feat
|
94 |
write_column_mapping(all_mappings, uid)
|
95 |
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
model_labels = list(model_id2label.values())
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
lables = [
|
101 |
gr.Dropdown(
|
102 |
label=f"{label}",
|
103 |
choices=model_labels,
|
104 |
-
value=model_id2label[i %
|
105 |
interactive=True,
|
106 |
visible=True,
|
107 |
)
|
108 |
-
for i, label in enumerate(ds_labels
|
109 |
]
|
110 |
lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
|
|
|
|
|
111 |
# TODO: Substitute 'text' with more features for zero-shot
|
112 |
features = [
|
113 |
gr.Dropdown(
|
@@ -122,11 +138,14 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label
|
|
122 |
features += [
|
123 |
gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
|
124 |
]
|
|
|
|
|
|
|
125 |
return lables + features
|
126 |
|
127 |
|
128 |
def check_model_and_show_prediction(
|
129 |
-
model_id, dataset_id, dataset_config, dataset_split
|
130 |
):
|
131 |
ppl = check_model(model_id)
|
132 |
if ppl is None or not isinstance(ppl, TextClassificationPipeline):
|
@@ -168,6 +187,7 @@ def check_model_and_show_prediction(
|
|
168 |
ds_labels,
|
169 |
ds_features,
|
170 |
model_id2label,
|
|
|
171 |
)
|
172 |
|
173 |
# when labels or features are not aligned
|
|
|
18 |
CONFIRM_MAPPING_DETAILS_FAIL_RAW,
|
19 |
MAPPING_STYLED_ERROR_WARNING)
|
20 |
|
21 |
+
MAX_LABELS = 40
|
22 |
MAX_FEATURES = 20
|
23 |
|
24 |
HF_REPO_ID = "HF_REPO_ID"
|
|
|
68 |
|
69 |
|
70 |
def write_column_mapping_to_config(
|
71 |
+
uid, *labels
|
72 |
):
|
73 |
# TODO: Substitute 'text' with more features for zero-shot
|
74 |
# we are not using ds features because we only support "text" for now
|
75 |
+
all_mappings = read_column_mapping(uid)
|
76 |
+
|
|
|
77 |
if labels is None:
|
78 |
return
|
79 |
+
all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])
|
80 |
+
all_mappings = export_mappings(all_mappings, "features", ["text"], labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)])
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
write_column_mapping(all_mappings, uid)
|
83 |
|
84 |
+
def export_mappings(all_mappings, key, subkeys, values):
|
85 |
+
if key not in all_mappings.keys():
|
86 |
+
all_mappings[key] = dict()
|
87 |
+
if subkeys is None:
|
88 |
+
subkeys = list(all_mappings[key].keys())
|
89 |
+
|
90 |
+
if not subkeys:
|
91 |
+
logging.debug(f"subkeys is empty for {key}")
|
92 |
+
return all_mappings
|
93 |
+
|
94 |
+
for i, subkey in enumerate(subkeys):
|
95 |
+
if subkey:
|
96 |
+
all_mappings[key][subkey] = values[i % len(values)]
|
97 |
+
return all_mappings
|
98 |
+
|
99 |
+
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label, uid):
|
100 |
model_labels = list(model_id2label.values())
|
101 |
+
all_mappings = read_column_mapping(uid)
|
102 |
+
# For flattened raw datasets with no labels
|
103 |
+
# check if there are shared labels between model and dataset
|
104 |
+
shared_labels = set(model_labels).intersection(set(ds_labels))
|
105 |
+
if shared_labels:
|
106 |
+
ds_labels = list(shared_labels)
|
107 |
+
if len(ds_labels) > MAX_LABELS:
|
108 |
+
ds_labels = ds_labels[:MAX_LABELS]
|
109 |
+
gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
|
110 |
+
|
111 |
+
ds_labels.sort()
|
112 |
+
model_labels.sort()
|
113 |
+
|
114 |
lables = [
|
115 |
gr.Dropdown(
|
116 |
label=f"{label}",
|
117 |
choices=model_labels,
|
118 |
+
value=model_id2label[i % len(model_labels)],
|
119 |
interactive=True,
|
120 |
visible=True,
|
121 |
)
|
122 |
+
for i, label in enumerate(ds_labels)
|
123 |
]
|
124 |
lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
|
125 |
+
all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)
|
126 |
+
|
127 |
# TODO: Substitute 'text' with more features for zero-shot
|
128 |
features = [
|
129 |
gr.Dropdown(
|
|
|
138 |
features += [
|
139 |
gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
|
140 |
]
|
141 |
+
all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
|
142 |
+
write_column_mapping(all_mappings, uid)
|
143 |
+
|
144 |
return lables + features
|
145 |
|
146 |
|
147 |
def check_model_and_show_prediction(
|
148 |
+
model_id, dataset_id, dataset_config, dataset_split, uid
|
149 |
):
|
150 |
ppl = check_model(model_id)
|
151 |
if ppl is None or not isinstance(ppl, TextClassificationPipeline):
|
|
|
187 |
ds_labels,
|
188 |
ds_features,
|
189 |
model_id2label,
|
190 |
+
uid,
|
191 |
)
|
192 |
|
193 |
# when labels or features are not aligned
|