davanstrien (HF staff) committed
Commit bc828c5
1 Parent(s): a2d40e3

Add language parsing functionality and update dependencies

Files changed (3)
  1. main.py +33 -10
  2. requirements.in +2 -1
  3. requirements.txt +2 -0
main.py CHANGED
@@ -15,6 +15,7 @@ from starlette.responses import RedirectResponse
 from cashews import cache
 from datetime import timedelta
 import logging
+from iso639 import Lang

 cache.setup("mem://")

@@ -93,6 +94,19 @@ async def get_dataset_info(hub_id: str, config: str | None = None):
     return resp.json()


+@cache(ttl=timedelta(minutes=5))
+async def fetch_rows(url: str) -> list[dict]:
+    response = await async_client.get(url)
+    if response.status_code == 200:
+        data = response.json()
+        return data.get("rows")
+    else:
+        print(f"Failed to fetch data: {response.status_code}")
+        print(url)
+        return []
+
+
+# Function to get random rows from the dataset
 async def get_random_rows(
     hub_id: str,
     total_length: int,
@@ -110,15 +124,8 @@ async def get_random_rows(
         offset = random.randint(0, total_length - rows_per_call)
         url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
         logger.info(f"Fetching {url}")
-        print(url)
-        response = await async_client.get(url)
-        if response.status_code == 200:
-            data = response.json()
-            batch_rows = data.get("rows")
-            rows.extend(batch_rows)
-        else:
-            print(f"Failed to fetch data: {response.status_code}")
-            print(url)
+        batch_rows = await fetch_rows(url)
+        rows.extend(batch_rows)
         if len(rows) >= number_of_rows:
             break
     return [row.get("row") for row in rows]
@@ -181,6 +188,17 @@ def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
     return {k for k, v in counts_dict.items() if v >= threshold}


+def try_parse_language(lang: str) -> str | None:
+    try:
+        split = lang.split("_")
+        lang = split[0]
+        lang = Lang(lang)
+        return lang.pt1
+    except Exception as e:
+        logger.error(f"Failed to parse language {lang}: {e}")
+        return None
+
+
 def predict_rows(
     rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
 ):
@@ -196,8 +214,13 @@ def predict_rows(
         langues_counts, threshold_percent=language_threshold_percent
     )
     filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
+    raw_model_prediction_summary = dict(valmap(get_mean_score, filtered_dict))
+    parsed_langs = {
+        try_parse_language(k): v for k, v in raw_model_prediction_summary.items()
+    }
     default_data = {
-        "predictions": dict(valmap(get_mean_score, filtered_dict)),
+        "language_prediction_summary": parsed_langs,
+        "raw_model_prediction_summary": raw_model_prediction_summary,
         "hub_id": "hub_id",
         "config": "config",
     }
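For reference, a minimal standalone sketch of what the new language-parsing path produces. The label strings and scores below are hypothetical stand-ins for the model's output keys; the Lang class comes from the iso639-lang package added in this commit, and the dict comprehension mirrors what try_parse_language does per key.

# Hypothetical example, not part of the commit: maps labels such as
# "eng_Latn" to ISO 639-1 codes the way try_parse_language does.
from iso639 import Lang

raw_model_prediction_summary = {"eng_Latn": 0.93, "fra_Latn": 0.88}  # hypothetical mean scores
parsed_langs = {
    Lang(label.split("_")[0]).pt1: score  # "eng" -> "en", "fra" -> "fr"
    for label, score in raw_model_prediction_summary.items()
}
print(parsed_langs)  # {'en': 0.93, 'fr': 0.88}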
requirements.in CHANGED
@@ -8,4 +8,5 @@ huggingface_hub
 python-dotenv
 rich
 toolz
-uvicorn[standard]
+uvicorn[standard]
+iso639-lang
requirements.txt CHANGED
@@ -51,6 +51,8 @@ idna==3.6
     # anyio
     # httpx
     # requests
+iso639-lang==2.2.2
+    # via -r requirements.in
 markdown-it-py==3.0.0
     # via rich
 mdurl==0.1.2