davanstrien HF staff commited on
Commit
a776cb5
1 Parent(s): de33a84

raise error if model id not valid

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -11,6 +11,7 @@ import os
11
  import backoff
12
  from functools import lru_cache
13
  from huggingface_hub import list_models, ModelFilter
 
14
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
15
 
16
 
@@ -64,17 +65,23 @@ def return_random_sample(k=27):
64
  images = dataset[sample]["image"]
65
  return [resize_image(image).convert("RGB") for image in images]
66
 
 
67
  @lru_cache()
68
  def get_valid_hub_image_classification_model_ids():
69
  models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
70
  return {model.id for model in models}
71
 
 
72
  def predict_subset(model_id, token):
73
- API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
74
- headers = {"Authorization": f"Bearer {token}"}
75
  valid_model_ids = get_valid_hub_image_classification_model_ids()
76
  if model_id not in valid_model_ids:
77
- gr.Error(f"model_id {model_id} is not a valid image classification model id")
 
 
 
 
 
 
78
  @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
79
  def _query(url):
80
  r = requests.post(API_URL, headers=headers, data=url)
 
11
  import backoff
12
  from functools import lru_cache
13
  from huggingface_hub import list_models, ModelFilter
14
+
15
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
16
 
17
 
 
65
  images = dataset[sample]["image"]
66
  return [resize_image(image).convert("RGB") for image in images]
67
 
68
+
69
  @lru_cache()
70
  def get_valid_hub_image_classification_model_ids():
71
  models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
72
  return {model.id for model in models}
73
 
74
+
75
  def predict_subset(model_id, token):
 
 
76
  valid_model_ids = get_valid_hub_image_classification_model_ids()
77
  if model_id not in valid_model_ids:
78
+ raise gr.Error(
79
+ f"model_id {model_id} is not a valid image classification model id"
80
+ )
81
+
82
+ API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
83
+ headers = {"Authorization": f"Bearer {token}"}
84
+
85
  @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
86
  def _query(url):
87
  r = requests.post(API_URL, headers=headers, data=url)