qgallouedec HF staff committed on
Commit
a2f004f
1 Parent(s): c67b794
Files changed (1) hide show
  1. src/backend.py +15 -18
src/backend.py CHANGED
@@ -133,33 +133,30 @@ def pattern_match(patterns, source_list):
133
 
134
  def _backend_routine():
135
  # List only the text classification models
136
- rl_models = [(model.modelId, model.sha) for model in API.list_models(filter=["reinforcement-learning"])]
137
  logger.info(f"Found {len(rl_models)} RL models")
138
- dataset = load_dataset(
139
- RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks"
140
- )
141
- evaluated_models = [("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset]
142
- pending_models = list(set(rl_models) - set(evaluated_models))
143
- pending_and_compatible_models = []
144
- for repo_id, sha in pending_models:
145
- try:
146
- siblings = API.model_info(repo_id, revision="main").siblings
147
- except Exception:
148
- continue
149
- filenames = [sib.rfilename for sib in siblings]
150
  if "agent.pt" in filenames:
151
- pending_and_compatible_models.append((repo_id, sha))
152
 
153
- logger.info(f"Found {len(pending_and_compatible_models)} compatible pending models")
 
 
 
 
 
154
 
155
- if len(pending_and_compatible_models) == 0:
156
  return None
157
 
158
  # Shuffle the dataset
159
- random.shuffle(pending_and_compatible_models)
160
 
161
  # Select a random model
162
- repo_id, sha = pending_and_compatible_models.pop()
163
  user_id, model_id = repo_id.split("/")
164
  row = {"model_id": model_id, "user_id": user_id, "sha": sha}
165
 
 
133
 
134
  def _backend_routine():
135
  # List only the text classification models
136
+ rl_models = API.list_models(filter=["reinforcement-learning"])
137
  logger.info(f"Found {len(rl_models)} RL models")
138
+
139
+ compatible_models = []
140
+ for model in rl_models:
141
+ filenames = [sib.rfilename for sib in model.siblings]
 
 
 
 
 
 
 
 
142
  if "agent.pt" in filenames:
143
+ compatible_models.append((model.modelId, model.sha))
144
 
145
+ logger.info(f"Found {len(compatible_models)} compatible models")
146
+
147
+ dataset = load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")
148
+ evaluated_models = [("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset]
149
+ pending_models = list(set(compatible_models) - set(evaluated_models))
150
+ logger.info(f"Found {len(pending_models)} pending models")
151
 
152
+ if len(pending_models) == 0:
153
  return None
154
 
155
  # Shuffle the dataset
156
+ random.shuffle(pending_models)
157
 
158
  # Select a random model
159
+ repo_id, sha = pending_models.pop()
160
  user_id, model_id = repo_id.split("/")
161
  row = {"model_id": model_id, "user_id": user_id, "sha": sha}
162