Spaces:
Sleeping
Sleeping
inoki-giskard
committed on
Commit
·
fc361e6
1
Parent(s):
e71d7e4
Format text classification and use warning
Browse files
- run_jobs.py +1 -1
- text_classification.py +33 -16
- text_classification_ui_helpers.py +9 -9
run_jobs.py
CHANGED
@@ -69,7 +69,7 @@ def prepare_env_and_get_command(
|
|
69 |
)
|
70 |
logger.info(f"Using {executable} as executable")
|
71 |
except Exception as e:
|
72 |
-
logger.
|
73 |
executable = "giskard_scanner"
|
74 |
|
75 |
command = [
|
|
|
69 |
)
|
70 |
logger.info(f"Using {executable} as executable")
|
71 |
except Exception as e:
|
72 |
+
logger.warning(f"Create env failed due to {e}, using the current env as fallback.")
|
73 |
executable = "giskard_scanner"
|
74 |
|
75 |
command = [
|
text_classification.py
CHANGED
@@ -14,6 +14,7 @@ AUTH_CHECK_URL = "https://huggingface.co/api/whoami-v2"
|
|
14 |
|
15 |
logger = logging.getLogger(__file__)
|
16 |
|
|
|
17 |
class HuggingFaceInferenceAPIResponse:
|
18 |
def __init__(self, message):
|
19 |
self.message = message
|
@@ -25,7 +26,7 @@ def get_labels_and_features_from_dataset(ds):
|
|
25 |
label_keys = [i for i in dataset_features.keys() if i.startswith("label")]
|
26 |
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
27 |
|
28 |
-
if len(label_keys) == 0:
|
29 |
# return everything for post processing
|
30 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
31 |
|
@@ -40,11 +41,10 @@ def get_labels_and_features_from_dataset(ds):
|
|
40 |
labels = dataset_features[label_keys[0]].names
|
41 |
return labels, features, label_keys
|
42 |
except Exception as e:
|
43 |
-
logging.warning(
|
44 |
-
f"Get Labels/Features Failed for dataset: {e}"
|
45 |
-
)
|
46 |
return None, None, None
|
47 |
|
|
|
48 |
def check_model_task(model_id):
|
49 |
# check if model is valid on huggingface
|
50 |
try:
|
@@ -55,6 +55,7 @@ def check_model_task(model_id):
|
|
55 |
except Exception:
|
56 |
return None
|
57 |
|
|
|
58 |
def get_model_labels(model_id, example_input):
|
59 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
60 |
payload = {"inputs": example_input, "options": {"use_cache": True}}
|
@@ -63,6 +64,7 @@ def get_model_labels(model_id, example_input):
|
|
63 |
return None
|
64 |
return extract_from_response(response, "label")
|
65 |
|
|
|
66 |
def extract_from_response(data, key):
|
67 |
results = []
|
68 |
|
@@ -80,6 +82,7 @@ def extract_from_response(data, key):
|
|
80 |
|
81 |
return results
|
82 |
|
|
|
83 |
def hf_inference_api(model_id, hf_token, payload):
|
84 |
hf_inference_api_endpoint = os.environ.get(
|
85 |
"HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
|
@@ -94,19 +97,26 @@ def hf_inference_api(model_id, hf_token, payload):
|
|
94 |
try:
|
95 |
output = response.json()
|
96 |
if "error" in output and "Input is too long" in output["error"]:
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
return response.json()
|
102 |
except Exception:
|
103 |
return {"error": response.content}
|
104 |
-
|
|
|
105 |
def preload_hf_inference_api(model_id):
|
106 |
-
payload = {
|
|
|
|
|
|
|
|
|
|
|
107 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
108 |
hf_inference_api(model_id, hf_token, payload)
|
109 |
|
|
|
110 |
def check_model_pipeline(model_id):
|
111 |
try:
|
112 |
task = huggingface_hub.model_info(model_id).pipeline_tag
|
@@ -279,6 +289,7 @@ def check_dataset_features_validity(d_id, config, split):
|
|
279 |
|
280 |
return df, dataset_features
|
281 |
|
|
|
282 |
def select_the_first_string_column(ds):
|
283 |
for feature in ds.features.keys():
|
284 |
if isinstance(ds[0][feature], str):
|
@@ -286,13 +297,17 @@ def select_the_first_string_column(ds):
|
|
286 |
return None
|
287 |
|
288 |
|
289 |
-
def get_example_prediction(
|
|
|
|
|
290 |
# get a sample prediction from the model on the dataset
|
291 |
prediction_input = None
|
292 |
prediction_result = None
|
293 |
try:
|
294 |
# Use the first item to test prediction
|
295 |
-
ds = datasets.load_dataset(
|
|
|
|
|
296 |
if "text" not in ds.features.keys():
|
297 |
# Dataset does not have text column
|
298 |
prediction_input = ds[0][select_the_first_string_column(ds)]
|
@@ -305,10 +320,12 @@ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split,
|
|
305 |
if isinstance(results, dict) and "error" in results.keys():
|
306 |
if "estimated_time" in results.keys():
|
307 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
308 |
-
f"Estimated time: {int(results['estimated_time'])}s. Please try again later."
|
|
|
309 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
310 |
-
f"Inference Error: {results['error']}."
|
311 |
-
|
|
|
312 |
while isinstance(results, list):
|
313 |
if isinstance(results[0], dict):
|
314 |
break
|
@@ -402,4 +419,4 @@ def check_hf_token_validity(hf_token):
|
|
402 |
response = requests.get(AUTH_CHECK_URL, headers=headers)
|
403 |
if response.status_code != 200:
|
404 |
return False
|
405 |
-
return True
|
|
|
14 |
|
15 |
logger = logging.getLogger(__file__)
|
16 |
|
17 |
+
|
18 |
class HuggingFaceInferenceAPIResponse:
|
19 |
def __init__(self, message):
|
20 |
self.message = message
|
|
|
26 |
label_keys = [i for i in dataset_features.keys() if i.startswith("label")]
|
27 |
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
28 |
|
29 |
+
if len(label_keys) == 0: # no labels found
|
30 |
# return everything for post processing
|
31 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
32 |
|
|
|
41 |
labels = dataset_features[label_keys[0]].names
|
42 |
return labels, features, label_keys
|
43 |
except Exception as e:
|
44 |
+
logging.warning(f"Get Labels/Features Failed for dataset: {e}")
|
|
|
|
|
45 |
return None, None, None
|
46 |
|
47 |
+
|
48 |
def check_model_task(model_id):
|
49 |
# check if model is valid on huggingface
|
50 |
try:
|
|
|
55 |
except Exception:
|
56 |
return None
|
57 |
|
58 |
+
|
59 |
def get_model_labels(model_id, example_input):
|
60 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
61 |
payload = {"inputs": example_input, "options": {"use_cache": True}}
|
|
|
64 |
return None
|
65 |
return extract_from_response(response, "label")
|
66 |
|
67 |
+
|
68 |
def extract_from_response(data, key):
|
69 |
results = []
|
70 |
|
|
|
82 |
|
83 |
return results
|
84 |
|
85 |
+
|
86 |
def hf_inference_api(model_id, hf_token, payload):
|
87 |
hf_inference_api_endpoint = os.environ.get(
|
88 |
"HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
|
|
|
97 |
try:
|
98 |
output = response.json()
|
99 |
if "error" in output and "Input is too long" in output["error"]:
|
100 |
+
payload.update({"parameters": {"truncation": True, "max_length": 512}})
|
101 |
+
response = requests.post(url, headers=headers, json=payload)
|
102 |
+
if not hasattr(response, "status_code") or response.status_code != 200:
|
103 |
+
logger.warning(f"Request to inference API returns {response}")
|
104 |
return response.json()
|
105 |
except Exception:
|
106 |
return {"error": response.content}
|
107 |
+
|
108 |
+
|
109 |
def preload_hf_inference_api(model_id):
|
110 |
+
payload = {
|
111 |
+
"inputs": "This is a test",
|
112 |
+
"options": {
|
113 |
+
"use_cache": True,
|
114 |
+
},
|
115 |
+
}
|
116 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
117 |
hf_inference_api(model_id, hf_token, payload)
|
118 |
|
119 |
+
|
120 |
def check_model_pipeline(model_id):
|
121 |
try:
|
122 |
task = huggingface_hub.model_info(model_id).pipeline_tag
|
|
|
289 |
|
290 |
return df, dataset_features
|
291 |
|
292 |
+
|
293 |
def select_the_first_string_column(ds):
|
294 |
for feature in ds.features.keys():
|
295 |
if isinstance(ds[0][feature], str):
|
|
|
297 |
return None
|
298 |
|
299 |
|
300 |
+
def get_example_prediction(
|
301 |
+
model_id, dataset_id, dataset_config, dataset_split, hf_token
|
302 |
+
):
|
303 |
# get a sample prediction from the model on the dataset
|
304 |
prediction_input = None
|
305 |
prediction_result = None
|
306 |
try:
|
307 |
# Use the first item to test prediction
|
308 |
+
ds = datasets.load_dataset(
|
309 |
+
dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
|
310 |
+
)
|
311 |
if "text" not in ds.features.keys():
|
312 |
# Dataset does not have text column
|
313 |
prediction_input = ds[0][select_the_first_string_column(ds)]
|
|
|
320 |
if isinstance(results, dict) and "error" in results.keys():
|
321 |
if "estimated_time" in results.keys():
|
322 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
323 |
+
f"Estimated time: {int(results['estimated_time'])}s. Please try again later."
|
324 |
+
)
|
325 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
326 |
+
f"Inference Error: {results['error']}."
|
327 |
+
)
|
328 |
+
|
329 |
while isinstance(results, list):
|
330 |
if isinstance(results[0], dict):
|
331 |
break
|
|
|
419 |
response = requests.get(AUTH_CHECK_URL, headers=headers)
|
420 |
if response.status_code != 200:
|
421 |
return False
|
422 |
+
return True
|
text_classification_ui_helpers.py
CHANGED
@@ -63,7 +63,7 @@ def get_dataset_splits(dataset_id, dataset_config):
|
|
63 |
splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
|
64 |
return gr.update(choices=splits, value=splits[0], visible=True)
|
65 |
except Exception as e:
|
66 |
-
logger.
|
67 |
return gr.update(visible=False)
|
68 |
|
69 |
def check_dataset(dataset_id):
|
@@ -83,7 +83,7 @@ def check_dataset(dataset_id):
|
|
83 |
""
|
84 |
)
|
85 |
except Exception as e:
|
86 |
-
logger.
|
87 |
if "doesn't exist" in str(e):
|
88 |
gr.Warning(get_dataset_fetch_error_raw(e))
|
89 |
if "forbidden" in str(e).lower(): # GSK-2770
|
@@ -232,7 +232,7 @@ def precheck_model_ds_enable_example_btn(
|
|
232 |
)
|
233 |
except Exception as e:
|
234 |
# Config or split wrong
|
235 |
-
logger.
|
236 |
return (
|
237 |
gr.update(interactive=False),
|
238 |
gr.update(visible=False),
|
@@ -372,30 +372,30 @@ def check_column_mapping_keys_validity(all_mappings):
|
|
372 |
|
373 |
def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, dataset_split):
|
374 |
if inference_token == "":
|
375 |
-
logger.
|
376 |
return gr.update(interactive=False)
|
377 |
if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
|
378 |
-
logger.
|
379 |
return gr.update(interactive=False)
|
380 |
|
381 |
all_mappings = read_column_mapping(uid)
|
382 |
if not check_column_mapping_keys_validity(all_mappings):
|
383 |
-
logger.
|
384 |
return gr.update(interactive=False)
|
385 |
|
386 |
if not check_hf_token_validity(inference_token):
|
387 |
-
logger.
|
388 |
return gr.update(interactive=False)
|
389 |
return gr.update(interactive=True)
|
390 |
|
391 |
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
|
392 |
label_mapping = {}
|
393 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
394 |
-
logger.
|
395 |
\nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
|
396 |
|
397 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
398 |
-
logger.
|
399 |
\nall_mappings: {all_mappings}\nds_features: {ds_features}""")
|
400 |
|
401 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|
|
|
63 |
splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
|
64 |
return gr.update(choices=splits, value=splits[0], visible=True)
|
65 |
except Exception as e:
|
66 |
+
logger.warning(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
|
67 |
return gr.update(visible=False)
|
68 |
|
69 |
def check_dataset(dataset_id):
|
|
|
83 |
""
|
84 |
)
|
85 |
except Exception as e:
|
86 |
+
logger.warning(f"Check your dataset {dataset_id}: {e}")
|
87 |
if "doesn't exist" in str(e):
|
88 |
gr.Warning(get_dataset_fetch_error_raw(e))
|
89 |
if "forbidden" in str(e).lower(): # GSK-2770
|
|
|
232 |
)
|
233 |
except Exception as e:
|
234 |
# Config or split wrong
|
235 |
+
logger.warning(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
|
236 |
return (
|
237 |
gr.update(interactive=False),
|
238 |
gr.update(visible=False),
|
|
|
372 |
|
373 |
def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, dataset_split):
|
374 |
if inference_token == "":
|
375 |
+
logger.warning("Inference API is not enabled")
|
376 |
return gr.update(interactive=False)
|
377 |
if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
|
378 |
+
logger.warning("Model id or dataset id is not selected")
|
379 |
return gr.update(interactive=False)
|
380 |
|
381 |
all_mappings = read_column_mapping(uid)
|
382 |
if not check_column_mapping_keys_validity(all_mappings):
|
383 |
+
logger.warning("Column mapping is not valid")
|
384 |
return gr.update(interactive=False)
|
385 |
|
386 |
if not check_hf_token_validity(inference_token):
|
387 |
+
logger.warning("HF token is not valid")
|
388 |
return gr.update(interactive=False)
|
389 |
return gr.update(interactive=True)
|
390 |
|
391 |
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
|
392 |
label_mapping = {}
|
393 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
394 |
+
logger.warning(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
395 |
\nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
|
396 |
|
397 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
398 |
+
logger.warning(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
399 |
\nall_mappings: {all_mappings}\nds_features: {ds_features}""")
|
400 |
|
401 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|