davanstrien committed
Commit ef19caa
Parent: 9915c6f

Refactor app.py: Import modules, update function parameters, and improve logging

Files changed (1):
1. app.py (+24, -21)
app.py CHANGED
@@ -1,16 +1,15 @@
-import gradio as gr
-from httpx import Client
-import random
 import os
+import random
+from statistics import mean
+from typing import Iterator, Union
+
 import fasttext
-from huggingface_hub import hf_hub_download
-from typing import Union
-from typing import Iterator
+import gradio as gr
 from dotenv import load_dotenv
-from toolz import groupby, valmap, concat
-from statistics import mean
-from httpx import Timeout
+from httpx import Client, Timeout
+from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import logging
+from toolz import concat, groupby, valmap
 
 logger = logging.get_logger(__name__)
 load_dotenv()
@@ -24,6 +23,7 @@ headers = {
 }
 timeout = Timeout(60, read=120)
 client = Client(headers=headers, timeout=timeout)
+# async_client = AsyncClient(headers=headers, timeout=timeout)
 # non exhaustive list of columns that might contain text which can be used for language detection
 # we prefer to use columns in this order i.e. if there is a column named "text" we will use it first
 TARGET_COLUMN_NAMES = {
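The commented-out `async_client` line anticipates an async variant of the same setup; a minimal sketch of what that could look like with `httpx.AsyncClient` (an assumption, not code from this commit):

```python
import asyncio

from httpx import AsyncClient, Timeout


async def fetch_json(url: str) -> dict:
    timeout = Timeout(60, read=120)  # 60s default, but allow 120s to read the body
    async with AsyncClient(timeout=timeout) as async_client:
        response = await async_client.get(url)
        response.raise_for_status()
        return response.json()


# asyncio.run(fetch_json("https://datasets-server.huggingface.co/rows?dataset=..."))
```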
@@ -73,10 +73,10 @@ def get_dataset_info(hub_id: str, config: str | None = None):
 
 
 def get_random_rows(
-    hub_id,
-    total_length,
-    number_of_rows,
-    max_request_calls,
+    hub_id: str,
+    total_length: int,
+    number_of_rows: int,
+    max_request_calls: int,
     config="default",
     split="train",
 ):
@@ -88,8 +88,9 @@ def get_random_rows(
     for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
         offset = random.randint(0, total_length - rows_per_call)
         url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
+        logger.info(f"Fetching {url}")
+        print(url)
         response = client.get(url)
-
         if response.status_code == 200:
             data = response.json()
             batch_rows = data.get("rows")
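For context, the loop above samples the datasets-server `/rows` endpoint at random offsets. A minimal standalone sketch of that pattern (the 100-row page size and the response shape are assumptions, not taken from this diff):

```python
import random

from httpx import Client

client = Client()


def sample_batch(hub_id: str, total_length: int, rows_per_call: int = 100) -> list[dict]:
    # pick a random window of rows_per_call rows somewhere in the split
    offset = random.randint(0, total_length - rows_per_call)
    url = (
        "https://datasets-server.huggingface.co/rows"
        f"?dataset={hub_id}&config=default&split=train"
        f"&offset={offset}&length={rows_per_call}"
    )
    response = client.get(url)
    response.raise_for_status()
    # assumed response shape: {"rows": [{"row_idx": ..., "row": {...}}, ...]}
    return [item["row"] for item in response.json().get("rows", [])]
```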
@@ -107,10 +108,6 @@ def load_model(repo_id: str) -> fasttext.FastText._FastText:
     return fasttext.load_model(model_path)
 
 
-# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str):
-#     pass
-
-
 def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
     for row in rows:
         if isinstance(row, str):
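`load_model` above pairs `hf_hub_download` with `fasttext.load_model`. A hedged end-to-end sketch of that flow (the repo id and filename are illustrative placeholders, not necessarily what the app loads):

```python
import fasttext
from huggingface_hub import hf_hub_download

# placeholder model repo; the repo id the app actually uses is not shown in this diff
model_path = hf_hub_download("facebook/fasttext-language-identification", "model.bin")
model = fasttext.load_model(model_path)

# fastText returns parallel sequences of labels and probabilities
labels, scores = model.predict("this is clearly an english sentence", k=1)
print(labels[0], float(scores[0]))  # e.g. __label__eng_Latn 0.98
```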
@@ -186,7 +183,8 @@ def predict_language(
     config: str | None = None,
     split: str | None = None,
     max_request_calls: int = 10,
-):
+    number_of_rows: int = 1000,
+) -> dict[str, float | str]:
     is_valid = datasets_server_valid_rows(hub_id)
     if not is_valid:
         gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
@@ -202,7 +200,7 @@ def predict_language(
     logger.info(f"Column names: {column_names}")
     if not set(column_names).intersection(TARGET_COLUMN_NAMES):
         raise gr.Error(
-            f"Dataset {hub_id} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
+            f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
         )
     for column in TARGET_COLUMN_NAMES:
         if column in column_names:
@@ -210,7 +208,12 @@ def predict_language(
             logger.info(f"Using column {target_column} for language detection")
             break
     random_rows = get_random_rows(
-        hub_id, total_rows_for_split, 1000, max_request_calls, config, split
+        hub_id,
+        total_rows_for_split,
+        number_of_rows,
+        max_request_calls,
+        config,
+        split,
     )
     logger.info(f"Predicting language for {len(random_rows)} rows")
     predictions = predict_rows(random_rows, target_column)
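With the refactor, the sample size is now a caller-controlled parameter. A hypothetical invocation of the new signature (the dataset id is a placeholder; the positional `hub_id` argument and keyword names follow the diff above):

```python
from app import predict_language  # the module this commit modifies

result = predict_language(
    "some-user/some-dataset",  # placeholder dataset id
    config="default",
    split="train",
    max_request_calls=10,
    number_of_rows=1000,  # new parameter introduced in this commit
)
# per the new return annotation, result is a dict[str, float | str]
print(result)
```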