davanstrien committed
Commit 4a3355b
1 Parent(s): 88bf26f

Update app.py

Files changed (1)
  1. app.py +79 -48
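
The diff below replaces httpx's AsyncClient and asyncio.gather with a shared synchronous httpx.Client fanned out through tqdm.contrib.concurrent.thread_map. As a minimal standalone sketch of that pattern (the headers and dataset IDs here are placeholders for illustration, not values from app.py):

from httpx import Client
from tqdm.contrib.concurrent import thread_map

# Shared synchronous client, mirroring the module-level client in the new app.py.
client = Client(headers={"user-agent": "column-search-sketch"})

def is_valid(hub_id: str):
    # Query the datasets-server is-valid endpoint; return None on any failure.
    try:
        resp = client.get(f"https://datasets-server.huggingface.co/is-valid?dataset={hub_id}")
        return resp.json().get("viewer")
    except Exception:
        return None

if __name__ == "__main__":
    hub_ids = ["mnist", "squad"]  # placeholder dataset IDs
    # thread_map runs is_valid across a thread pool and shows a progress bar.
    results = thread_map(is_valid, hub_ids)
    print(dict(zip(hub_ids, results)))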
app.py CHANGED
@@ -2,7 +2,7 @@ from dotenv import load_dotenv
import os

import pandas as pd
- from httpx import AsyncClient
+ from httpx import Client
from huggingface_hub import dataset_info
from huggingface_hub.utils import logging
from functools import lru_cache
@@ -10,28 +10,23 @@ from tqdm.contrib.concurrent import thread_map
from huggingface_hub import HfApi
from rich import print
import gradio as gr
- import asyncio


- async def check_dataset_has_non_default_file(hub_id):
-     try:
-         info = await dataset_info(hub_id)
-         if files := info.siblings:
-             file_names = [f.rfilename for f in files]
-             files = [f for f in file_names if not f.startswith(".") or f == "README.md"]
-             return len(files) >= 1
-         return False
-     except Exception as e:
-         logger.error(f"Failed to get siblings for {hub_id}: {e}")
-         return False
+ def check_dataset_has_non_default_file(hub_id):
+     info = dataset_info(hub_id)
+     if files := info.siblings:
+         file_names = [f.rfilename for f in files]
+         files = [f for f in file_names if not f.startswith(".") or f == "README.md"]
+         return len(files) >= 1
+     return False


- async def datasets_server_valid_rows(hub_id: str, async_client: AsyncClient):
+ def datasets_server_valid_rows(hub_id: str):
    try:
-         resp = await async_client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
+         resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
        return resp.json()["viewer"]
-     except Exception as e:
-         logger.error(f"Failed to get is-valid for {hub_id}: {e}")
+     except Exception:
+         # logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return None


@@ -48,18 +43,20 @@ BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"

logger = logging.get_logger(__name__)
headers = {
-     "authorization": f"Bearer {HF_TOKEN}",
+     "authorization": f"Bearer ${HF_TOKEN}",
    "user-agent": USER_AGENT,
}
- async_client = AsyncClient(headers=headers)
+ client = Client(headers=headers)
+ async_client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)


- async def get_first_config_and_split_name(hub_id: str, async_client: AsyncClient):
+ def get_first_config_and_split_name(hub_id: str):
    try:
-         resp = await async_client.get(
+         resp = client.get(
            f"https://datasets-server.huggingface.co/splits?dataset={hub_id}"
        )
+
        data = resp.json()
        return data["splits"][0]["config"], data["splits"][0]["split"]
    except Exception as e:
@@ -67,29 +64,57 @@ async def get_first_config_and_split_name(hub_id: str, async_client: AsyncClient
        return None


- async def get_dataset_info(hub_id: str, config: str | None = None):
+ def check_dataset_has_non_default_file(hub_id):
+     try:
+         info = dataset_info(hub_id)
+         if files := info.siblings:
+             file_names = [f.rfilename for f in files]
+             files = [f for f in file_names if not f.startswith(".") or f == "README.md"]
+             return len(files) >= 1
+         return False
+     except Exception as e:
+         logger.error(f"Failed to get siblings for {hub_id}: {e}")
+         return False
+
+
+ def datasets_server_valid_rows(hub_id: str):
+     try:
+         resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
+         return resp.json()["viewer"]
+     except Exception:
+         # logger.error(f"Failed to get is-valid for {hub_id}: {e}")
+         return None
+
+
+ def dataset_is_valid(dataset):
+     return dataset if datasets_server_valid_rows(dataset.id) else None
+
+
+ def get_dataset_info(hub_id: str, config: str | None = None):
    if config is None:
-         config = await get_first_config_and_split_name(hub_id, async_client)
+         config = get_first_config_and_split_name(hub_id)
        if config is None:
            return None
        else:
            config = config[0]
-     resp = await async_client.get(
+     resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()


- async def dataset_with_info(dataset):
+ def dataset_with_info(dataset):
    try:
-         if info := await get_dataset_info(dataset.id):
+         if info := get_dataset_info(dataset.id):
            columns = info.get("dataset_info", {}).get("features", {})
            if columns is not None:
                return {
                    "hub_id": dataset.id,
                    "column_names": list(columns.keys()),
                    "columns": columns,
+                     # "dataset": dataset,
+                     # "full_info": info,
                    "likes": dataset.likes,
                    "downloads": dataset.downloads,
                    "created_at": dataset.created_at,
@@ -100,34 +125,36 @@ async def dataset_with_info(dataset):
        return None


- async def return_dataset_with_non_default_files(dataset):
-     return dataset if await check_dataset_has_non_default_file(dataset.id) else None
+ def return_dataset_with_non_default_files(dataset):
+     return dataset if check_dataset_has_non_default_file(dataset.id) else None


@lru_cache(maxsize=100)
- async def prep_data():
-     datasets = list(api.list_datasets(limit=200, sort="createdAt", direction=-1))
+ def prep_data():
+     datasets = list(api.list_datasets(limit=None, sort="createdAt", direction=-1))
    print(f"Found {len(datasets)} datasets.")
-
-     valid_datasets = await asyncio.gather(*[return_dataset_with_non_default_files(dataset) for dataset in datasets])
-     valid_datasets = [x for x in valid_datasets if x is not None]
-     print(f"Found {len(valid_datasets)} datasets with non-default files.")
-
-     has_server = await asyncio.gather(*[datasets_server_valid_rows(dataset.id, async_client) for dataset in valid_datasets])
-     datasets_with_server = [dataset for dataset, server_valid in zip(valid_datasets, has_server) if server_valid]
+     # datasets = thread_map(
+     #     return_dataset_with_non_default_files,
+     #     datasets,
+     # )
+     # datasets = [x for x in datasets if x is not None]
+     # print(f"Found {len(datasets)} datasets with non-default files.")
+     has_server = thread_map(
+         dataset_is_valid,
+         datasets,
+     )
+     datasets_with_server = [x for x in has_server if x is not None]
    print(f"Found {len(datasets_with_server)} datasets with server.")
-
-     datasets_server_data = await asyncio.gather(*[dataset_with_info(dataset) for dataset in datasets_with_server])
-     datasets_server_data = [data for data in datasets_server_data if data is not None]
+     datasets_server_data = thread_map(dataset_with_info, datasets_with_server)
    print(f"Found {len(datasets_server_data)} datasets with server data.")
    print(datasets_server_data[0])
-
    return datasets_server_data


def filter_columns(datasets_server_data, columns=None):
    if columns is not None:
        clean = []
+         # check for presence of columns
        for dataset in datasets_server_data:
            if dataset is not None:
                target_column = dataset.get("columns", [])
@@ -139,8 +166,17 @@ def filter_columns(datasets_server_data, columns=None):
    return datasets_server_data


- async def predict(columns_to_filter):
-     datasets_server_data = await prep_data()
+ # warm up the cache
+ prep_data()
+
+
+ def render_model_hub_link(hub_id):
+     link = f"https://huggingface.co/datasets/{hub_id}"
+     return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{hub_id}</a>'
+
+
+ def predict(columns_to_filter):
+     datasets_server_data = prep_data()
    columns_to_filter = columns_to_filter.split(",")
    columns_to_filter = [x.strip() for x in columns_to_filter]
    filtered = filter_columns(
@@ -152,11 +188,6 @@ async def predict(columns_to_filter):
    return df


- def render_model_hub_link(hub_id):
-     link = f"https://huggingface.co/datasets/{hub_id}"
-     return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{hub_id}</a>'
-
-
with gr.Blocks() as demo:
    gr.Markdown("# Search Hugging Face datasets by column names (POC)")
    gr.Markdown(