davanstrien HF staff commited on
Commit
a2d40e3
1 Parent(s): 2995c02

Refactor predict_rows function to include raw predictions

Browse files

This commit refactors the predict_rows function in main.py to include an optional parameter, return_raw_predictions, which when set to True, returns the raw predictions along with the mean scores. This change improves the flexibility and usefulness of the function.

Files changed (1) hide show
  1. main.py +15 -2
main.py CHANGED
@@ -181,7 +181,9 @@ def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
181
  return {k for k, v in counts_dict.items() if v >= threshold}
182
 
183
 
184
- def predict_rows(rows, target_column, language_threshold_percent=0.2):
 
 
185
  rows = (row.get(target_column) for row in rows)
186
  rows = (row for row in rows if row is not None)
187
  rows = list(yield_clean_rows(rows))
@@ -194,9 +196,20 @@ def predict_rows(rows, target_column, language_threshold_percent=0.2):
194
  langues_counts, threshold_percent=language_threshold_percent
195
  )
196
  filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
197
- return {
198
  "predictions": dict(valmap(get_mean_score, filtered_dict)),
 
 
199
  }
 
 
 
 
 
 
 
 
 
200
 
201
 
202
  @app.get("/", include_in_schema=False)
 
181
  return {k for k, v in counts_dict.items() if v >= threshold}
182
 
183
 
184
+ def predict_rows(
185
+ rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
186
+ ):
187
  rows = (row.get(target_column) for row in rows)
188
  rows = (row for row in rows if row is not None)
189
  rows = list(yield_clean_rows(rows))
 
196
  langues_counts, threshold_percent=language_threshold_percent
197
  )
198
  filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
199
+ default_data = {
200
  "predictions": dict(valmap(get_mean_score, filtered_dict)),
201
+ "hub_id": "hub_id",
202
+ "config": "config",
203
  }
204
+ if return_raw_predictions:
205
+ default_data["raw_predictions"] = predictions
206
+ return default_data
207
+
208
+
209
+ # @app.get("/", response_class=HTMLResponse)
210
+ # async def read_index():
211
+ # html_content = Path("index.html").read_text()
212
+ # return HTMLResponse(content=html_content)
213
 
214
 
215
  @app.get("/", include_in_schema=False)