inoki-giskard commited on
Commit
fc361e6
1 Parent(s): e71d7e4

Format text classification and use warning

Browse files
run_jobs.py CHANGED
@@ -69,7 +69,7 @@ def prepare_env_and_get_command(
69
  )
70
  logger.info(f"Using {executable} as executable")
71
  except Exception as e:
72
- logger.warn(f"Create env failed due to {e}, using the current env as fallback.")
73
  executable = "giskard_scanner"
74
 
75
  command = [
 
69
  )
70
  logger.info(f"Using {executable} as executable")
71
  except Exception as e:
72
+ logger.warning(f"Create env failed due to {e}, using the current env as fallback.")
73
  executable = "giskard_scanner"
74
 
75
  command = [
text_classification.py CHANGED
@@ -14,6 +14,7 @@ AUTH_CHECK_URL = "https://huggingface.co/api/whoami-v2"
14
 
15
  logger = logging.getLogger(__file__)
16
 
 
17
  class HuggingFaceInferenceAPIResponse:
18
  def __init__(self, message):
19
  self.message = message
@@ -25,7 +26,7 @@ def get_labels_and_features_from_dataset(ds):
25
  label_keys = [i for i in dataset_features.keys() if i.startswith("label")]
26
  features = [f for f in dataset_features.keys() if not f.startswith("label")]
27
 
28
- if len(label_keys) == 0: # no labels found
29
  # return everything for post processing
30
  return list(dataset_features.keys()), list(dataset_features.keys()), None
31
 
@@ -40,11 +41,10 @@ def get_labels_and_features_from_dataset(ds):
40
  labels = dataset_features[label_keys[0]].names
41
  return labels, features, label_keys
42
  except Exception as e:
43
- logging.warning(
44
- f"Get Labels/Features Failed for dataset: {e}"
45
- )
46
  return None, None, None
47
 
 
48
  def check_model_task(model_id):
49
  # check if model is valid on huggingface
50
  try:
@@ -55,6 +55,7 @@ def check_model_task(model_id):
55
  except Exception:
56
  return None
57
 
 
58
  def get_model_labels(model_id, example_input):
59
  hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
60
  payload = {"inputs": example_input, "options": {"use_cache": True}}
@@ -63,6 +64,7 @@ def get_model_labels(model_id, example_input):
63
  return None
64
  return extract_from_response(response, "label")
65
 
 
66
  def extract_from_response(data, key):
67
  results = []
68
 
@@ -80,6 +82,7 @@ def extract_from_response(data, key):
80
 
81
  return results
82
 
 
83
  def hf_inference_api(model_id, hf_token, payload):
84
  hf_inference_api_endpoint = os.environ.get(
85
  "HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
@@ -94,19 +97,26 @@ def hf_inference_api(model_id, hf_token, payload):
94
  try:
95
  output = response.json()
96
  if "error" in output and "Input is too long" in output["error"]:
97
- payload.update({"parameters": {"truncation": True, "max_length": 512}})
98
- response = requests.post(url, headers=headers, json=payload)
99
- if not hasattr(response, "status_code") or response.status_code != 200:
100
- logger.warning(f"Request to inference API returns {response}")
101
  return response.json()
102
  except Exception:
103
  return {"error": response.content}
104
-
 
105
  def preload_hf_inference_api(model_id):
106
- payload = {"inputs": "This is a test", "options": {"use_cache": True, }}
 
 
 
 
 
107
  hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
108
  hf_inference_api(model_id, hf_token, payload)
109
 
 
110
  def check_model_pipeline(model_id):
111
  try:
112
  task = huggingface_hub.model_info(model_id).pipeline_tag
@@ -279,6 +289,7 @@ def check_dataset_features_validity(d_id, config, split):
279
 
280
  return df, dataset_features
281
 
 
282
  def select_the_first_string_column(ds):
283
  for feature in ds.features.keys():
284
  if isinstance(ds[0][feature], str):
@@ -286,13 +297,17 @@ def select_the_first_string_column(ds):
286
  return None
287
 
288
 
289
- def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split, hf_token):
 
 
290
  # get a sample prediction from the model on the dataset
291
  prediction_input = None
292
  prediction_result = None
293
  try:
294
  # Use the first item to test prediction
295
- ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
 
 
296
  if "text" not in ds.features.keys():
297
  # Dataset does not have text column
298
  prediction_input = ds[0][select_the_first_string_column(ds)]
@@ -305,10 +320,12 @@ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split,
305
  if isinstance(results, dict) and "error" in results.keys():
306
  if "estimated_time" in results.keys():
307
  return prediction_input, HuggingFaceInferenceAPIResponse(
308
- f"Estimated time: {int(results['estimated_time'])}s. Please try again later.")
 
309
  return prediction_input, HuggingFaceInferenceAPIResponse(
310
- f"Inference Error: {results['error']}.")
311
-
 
312
  while isinstance(results, list):
313
  if isinstance(results[0], dict):
314
  break
@@ -402,4 +419,4 @@ def check_hf_token_validity(hf_token):
402
  response = requests.get(AUTH_CHECK_URL, headers=headers)
403
  if response.status_code != 200:
404
  return False
405
- return True
 
14
 
15
  logger = logging.getLogger(__file__)
16
 
17
+
18
  class HuggingFaceInferenceAPIResponse:
19
  def __init__(self, message):
20
  self.message = message
 
26
  label_keys = [i for i in dataset_features.keys() if i.startswith("label")]
27
  features = [f for f in dataset_features.keys() if not f.startswith("label")]
28
 
29
+ if len(label_keys) == 0: # no labels found
30
  # return everything for post processing
31
  return list(dataset_features.keys()), list(dataset_features.keys()), None
32
 
 
41
  labels = dataset_features[label_keys[0]].names
42
  return labels, features, label_keys
43
  except Exception as e:
44
+ logging.warning(f"Get Labels/Features Failed for dataset: {e}")
 
 
45
  return None, None, None
46
 
47
+
48
  def check_model_task(model_id):
49
  # check if model is valid on huggingface
50
  try:
 
55
  except Exception:
56
  return None
57
 
58
+
59
  def get_model_labels(model_id, example_input):
60
  hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
61
  payload = {"inputs": example_input, "options": {"use_cache": True}}
 
64
  return None
65
  return extract_from_response(response, "label")
66
 
67
+
68
  def extract_from_response(data, key):
69
  results = []
70
 
 
82
 
83
  return results
84
 
85
+
86
  def hf_inference_api(model_id, hf_token, payload):
87
  hf_inference_api_endpoint = os.environ.get(
88
  "HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
 
97
  try:
98
  output = response.json()
99
  if "error" in output and "Input is too long" in output["error"]:
100
+ payload.update({"parameters": {"truncation": True, "max_length": 512}})
101
+ response = requests.post(url, headers=headers, json=payload)
102
+ if not hasattr(response, "status_code") or response.status_code != 200:
103
+ logger.warning(f"Request to inference API returns {response}")
104
  return response.json()
105
  except Exception:
106
  return {"error": response.content}
107
+
108
+
109
  def preload_hf_inference_api(model_id):
110
+ payload = {
111
+ "inputs": "This is a test",
112
+ "options": {
113
+ "use_cache": True,
114
+ },
115
+ }
116
  hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
117
  hf_inference_api(model_id, hf_token, payload)
118
 
119
+
120
  def check_model_pipeline(model_id):
121
  try:
122
  task = huggingface_hub.model_info(model_id).pipeline_tag
 
289
 
290
  return df, dataset_features
291
 
292
+
293
  def select_the_first_string_column(ds):
294
  for feature in ds.features.keys():
295
  if isinstance(ds[0][feature], str):
 
297
  return None
298
 
299
 
300
+ def get_example_prediction(
301
+ model_id, dataset_id, dataset_config, dataset_split, hf_token
302
+ ):
303
  # get a sample prediction from the model on the dataset
304
  prediction_input = None
305
  prediction_result = None
306
  try:
307
  # Use the first item to test prediction
308
+ ds = datasets.load_dataset(
309
+ dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
310
+ )
311
  if "text" not in ds.features.keys():
312
  # Dataset does not have text column
313
  prediction_input = ds[0][select_the_first_string_column(ds)]
 
320
  if isinstance(results, dict) and "error" in results.keys():
321
  if "estimated_time" in results.keys():
322
  return prediction_input, HuggingFaceInferenceAPIResponse(
323
+ f"Estimated time: {int(results['estimated_time'])}s. Please try again later."
324
+ )
325
  return prediction_input, HuggingFaceInferenceAPIResponse(
326
+ f"Inference Error: {results['error']}."
327
+ )
328
+
329
  while isinstance(results, list):
330
  if isinstance(results[0], dict):
331
  break
 
419
  response = requests.get(AUTH_CHECK_URL, headers=headers)
420
  if response.status_code != 200:
421
  return False
422
+ return True
text_classification_ui_helpers.py CHANGED
@@ -63,7 +63,7 @@ def get_dataset_splits(dataset_id, dataset_config):
63
  splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
64
  return gr.update(choices=splits, value=splits[0], visible=True)
65
  except Exception as e:
66
- logger.warn(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
67
  return gr.update(visible=False)
68
 
69
  def check_dataset(dataset_id):
@@ -83,7 +83,7 @@ def check_dataset(dataset_id):
83
  ""
84
  )
85
  except Exception as e:
86
- logger.warn(f"Check your dataset {dataset_id}: {e}")
87
  if "doesn't exist" in str(e):
88
  gr.Warning(get_dataset_fetch_error_raw(e))
89
  if "forbidden" in str(e).lower(): # GSK-2770
@@ -232,7 +232,7 @@ def precheck_model_ds_enable_example_btn(
232
  )
233
  except Exception as e:
234
  # Config or split wrong
235
- logger.warn(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
236
  return (
237
  gr.update(interactive=False),
238
  gr.update(visible=False),
@@ -372,30 +372,30 @@ def check_column_mapping_keys_validity(all_mappings):
372
 
373
  def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, dataset_split):
374
  if inference_token == "":
375
- logger.warn("Inference API is not enabled")
376
  return gr.update(interactive=False)
377
  if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
378
- logger.warn("Model id or dataset id is not selected")
379
  return gr.update(interactive=False)
380
 
381
  all_mappings = read_column_mapping(uid)
382
  if not check_column_mapping_keys_validity(all_mappings):
383
- logger.warn("Column mapping is not valid")
384
  return gr.update(interactive=False)
385
 
386
  if not check_hf_token_validity(inference_token):
387
- logger.warn("HF token is not valid")
388
  return gr.update(interactive=False)
389
  return gr.update(interactive=True)
390
 
391
  def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
392
  label_mapping = {}
393
  if len(all_mappings["labels"].keys()) != len(ds_labels):
394
- logger.warn(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
395
  \nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
396
 
397
  if len(all_mappings["features"].keys()) != len(ds_features):
398
- logger.warn(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
399
  \nall_mappings: {all_mappings}\nds_features: {ds_features}""")
400
 
401
  for i, label in zip(range(len(ds_labels)), ds_labels):
 
63
  splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
64
  return gr.update(choices=splits, value=splits[0], visible=True)
65
  except Exception as e:
66
+ logger.warning(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
67
  return gr.update(visible=False)
68
 
69
  def check_dataset(dataset_id):
 
83
  ""
84
  )
85
  except Exception as e:
86
+ logger.warning(f"Check your dataset {dataset_id}: {e}")
87
  if "doesn't exist" in str(e):
88
  gr.Warning(get_dataset_fetch_error_raw(e))
89
  if "forbidden" in str(e).lower(): # GSK-2770
 
232
  )
233
  except Exception as e:
234
  # Config or split wrong
235
+ logger.warning(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
236
  return (
237
  gr.update(interactive=False),
238
  gr.update(visible=False),
 
372
 
373
  def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, dataset_split):
374
  if inference_token == "":
375
+ logger.warning("Inference API is not enabled")
376
  return gr.update(interactive=False)
377
  if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
378
+ logger.warning("Model id or dataset id is not selected")
379
  return gr.update(interactive=False)
380
 
381
  all_mappings = read_column_mapping(uid)
382
  if not check_column_mapping_keys_validity(all_mappings):
383
+ logger.warning("Column mapping is not valid")
384
  return gr.update(interactive=False)
385
 
386
  if not check_hf_token_validity(inference_token):
387
+ logger.warning("HF token is not valid")
388
  return gr.update(interactive=False)
389
  return gr.update(interactive=True)
390
 
391
  def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
392
  label_mapping = {}
393
  if len(all_mappings["labels"].keys()) != len(ds_labels):
394
+ logger.warning(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
395
  \nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
396
 
397
  if len(all_mappings["features"].keys()) != len(ds_features):
398
+ logger.warning(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
399
  \nall_mappings: {all_mappings}\nds_features: {ds_features}""")
400
 
401
  for i, label in zip(range(len(ds_labels)), ds_labels):