ZeroCommand commited on
Commit
6eb5802
·
1 Parent(s): d0be156

handle hf api response with custom class

Browse files
text_classification.py CHANGED
@@ -7,13 +7,16 @@ import pandas as pd
7
  from transformers import pipeline
8
  import requests
9
  import os
10
- import time
11
 
12
  logger = logging.getLogger(__name__)
13
  HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
14
 
15
  logger = logging.getLogger(__file__)
16
 
 
 
 
 
17
 
18
  def get_labels_and_features_from_dataset(ds):
19
  try:
@@ -286,13 +289,12 @@ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split):
286
  payload = {"inputs": prediction_input, "options": {"use_cache": True}}
287
  results = hf_inference_api(model_id, hf_token, payload)
288
 
289
- if isinstance(results, dict) and "estimated_time" in results.keys():
290
- # return the estimated time for the inference api to load
291
- # cast the float to int to be concise
292
- return prediction_input, str(f"{int(results['estimated_time'])}s")
293
-
294
  if isinstance(results, dict) and "error" in results.keys():
295
- raise ValueError(results["error"])
 
 
 
 
296
 
297
  while isinstance(results, list):
298
  if isinstance(results[0], dict):
@@ -303,7 +305,8 @@ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split):
303
  }
304
  except Exception as e:
305
  # inference api prediction failed, show the error message
306
- return prediction_input, e
 
307
 
308
  return prediction_input, prediction_result
309
 
 
7
  from transformers import pipeline
8
  import requests
9
  import os
 
10
 
11
  logger = logging.getLogger(__name__)
12
  HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
13
 
14
  logger = logging.getLogger(__file__)
15
 
16
+ class HuggingFaceInferenceAPIResponse:
17
+ def __init__(self, message):
18
+ self.message = message
19
+
20
 
21
  def get_labels_and_features_from_dataset(ds):
22
  try:
 
289
  payload = {"inputs": prediction_input, "options": {"use_cache": True}}
290
  results = hf_inference_api(model_id, hf_token, payload)
291
 
 
 
 
 
 
292
  if isinstance(results, dict) and "error" in results.keys():
293
+ if "estimated_time" in results.keys():
294
+ return prediction_input, HuggingFaceInferenceAPIResponse(
295
+ f"Estimated time: {int(results['estimated_time'])}s. Please try again later.")
296
+ return prediction_input, HuggingFaceInferenceAPIResponse(
297
+ f"Inference Error: {results['error']}.")
298
 
299
  while isinstance(results, list):
300
  if isinstance(results[0], dict):
 
305
  }
306
  except Exception as e:
307
  # inference api prediction failed, show the error message
308
+ logger.error(f"Get example prediction failed {e}")
309
+ return prediction_input, None
310
 
311
  return prediction_input, prediction_result
312
 
text_classification_ui_helpers.py CHANGED
@@ -15,6 +15,7 @@ from text_classification import (
15
  preload_hf_inference_api,
16
  get_example_prediction,
17
  get_labels_and_features_from_dataset,
 
18
  )
19
  from wordings import (
20
  CHECK_CONFIG_OR_SPLIT_RAW,
@@ -213,24 +214,23 @@ def align_columns_and_show_prediction(
213
  model_id, dataset_id, dataset_config, dataset_split
214
  )
215
 
216
- if isinstance(prediction_response, str):
217
  return (
218
  gr.update(visible=False),
219
  gr.update(visible=False),
220
  gr.update(visible=False, open=False),
221
  gr.update(interactive=False),
222
- f"Hugging Face Inference API is loading your model, estimation time {prediction_response}. Please validate again later.",
223
  *dropdown_placement,
224
  )
225
 
226
- if isinstance(prediction_response, Exception):
227
- gr.Warning(f"Inference API loading error: {prediction_response}. Please check your model or Hugging Face token.")
228
  return (
229
  gr.update(visible=False),
230
  gr.update(visible=False),
231
  gr.update(visible=False, open=False),
232
  gr.update(interactive=False),
233
- "",
234
  *dropdown_placement,
235
  )
236
 
 
15
  preload_hf_inference_api,
16
  get_example_prediction,
17
  get_labels_and_features_from_dataset,
18
+ HuggingFaceInferenceAPIResponse,
19
  )
20
  from wordings import (
21
  CHECK_CONFIG_OR_SPLIT_RAW,
 
214
  model_id, dataset_id, dataset_config, dataset_split
215
  )
216
 
217
+ if prediction_input is None or prediction_response is None:
218
  return (
219
  gr.update(visible=False),
220
  gr.update(visible=False),
221
  gr.update(visible=False, open=False),
222
  gr.update(interactive=False),
223
+ "",
224
  *dropdown_placement,
225
  )
226
 
227
+ if isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
 
228
  return (
229
  gr.update(visible=False),
230
  gr.update(visible=False),
231
  gr.update(visible=False, open=False),
232
  gr.update(interactive=False),
233
+ f"Hugging Face Inference API is loading your model. {prediction_response.message}",
234
  *dropdown_placement,
235
  )
236