Jonatan Asketorp commited on
Commit
b3f2672
1 Parent(s): 457d9a8

Add onnx filter

Browse files
Files changed (1) hide show
  1. background_task.py +15 -4
background_task.py CHANGED
@@ -191,7 +191,17 @@ def match(model1, model2):
191
  print(f"Match {model1_id} against {model2_id} ended.")
192
 
193
 
194
- def get_models_list(filter_bad_models) -> list:
 
 
 
 
 
 
 
 
 
 
195
  """
196
  Get the list of models from the hub and the ELO file.
197
 
@@ -200,14 +210,15 @@ def get_models_list(filter_bad_models) -> list:
200
  models = []
201
  models_ids = []
202
  data = pd.read_csv(os.path.join(DATASET_REPO_URL, "resolve", "main", ELO_FILENAME))
203
- models_on_hub = api.list_models(filter=["reinforcement-learning", "ml-agents", "ML-Agents-SoccerTwos"])
 
204
  for i, row in data.iterrows():
205
  model_id = row["author"] + "/" + row["model"]
206
- if model_id in filter_bad_models:
207
  continue
208
  models.append(Model(row["author"], row["model"], row["elo"], row["games_played"]))
209
  models_ids.append(model_id)
210
- for model in models_on_hub:
211
  if model.modelId in filter_bad_models:
212
  continue
213
  author, name = model.modelId.split("/")[0], model.modelId.split("/")[1]
 
191
  print(f"Match {model1_id} against {model2_id} ended.")
192
 
193
 
194
+ def check_for_onnx_file(model_info: ModelInfo) -> bool:
195
+ """
196
+ Checks if the model contains a `.onnx` file.
197
+ """
198
+ for repo_file in model_info.siblings:
199
+ if repo_file.rfilename.endswith(".onnx"):
200
+ return True
201
+ return False
202
+
203
+
204
+ def get_models_list(filter_bad_models):
205
  """
206
  Get the list of models from the hub and the ELO file.
207
 
 
210
  models = []
211
  models_ids = []
212
  data = pd.read_csv(os.path.join(DATASET_REPO_URL, "resolve", "main", ELO_FILENAME))
213
+ models_with_onnx_on_hub = filter(check_for_onnx_file, api.list_models(filter=["reinforcement-learning", "ml-agents", "ML-Agents-SoccerTwos"]))
214
+ model_ids_with_onnx = {model.modelId for model in models_with_onnx_on_hub}
215
  for i, row in data.iterrows():
216
  model_id = row["author"] + "/" + row["model"]
217
+ if model_id in filter_bad_models or model_id not in model_ids_with_onnx:
218
  continue
219
  models.append(Model(row["author"], row["model"], row["elo"], row["games_played"]))
220
  models_ids.append(model_id)
221
+ for model in models_with_onnx_on_hub:
222
  if model.modelId in filter_bad_models:
223
  continue
224
  author, name = model.modelId.split("/")[0], model.modelId.split("/")[1]