from dotenv import load_dotenv
import os
import pandas as pd
from httpx import AsyncClient, Client
from huggingface_hub import HfApi, dataset_info
from huggingface_hub.utils import logging
from functools import lru_cache
from tqdm.contrib.concurrent import thread_map
from rich import print
import gradio as gr

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"
USER_AGENT = os.getenv("USER_AGENT")
assert (
    USER_AGENT is not None
), "You need to set USER_AGENT in your environment variables"
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
logger = logging.get_logger(__name__)
headers = {
    "authorization": f"Bearer {HF_TOKEN}",
    "user-agent": USER_AGENT,
}
client = Client(headers=headers)
async_client = AsyncClient(headers=headers)
api = HfApi(token=HF_TOKEN)
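
# Illustrative only: every request made through the shared client carries the
# auth and user-agent headers, e.g.
# client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset=squad")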

def get_first_config_and_split_name(hub_id: str):
    try:
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
        data = resp.json()
        return data["splits"][0]["config"], data["splits"][0]["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None
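
# Illustrative only: the /splits endpoint typically returns JSON shaped like
# {"splits": [{"dataset": "<hub_id>", "config": "default", "split": "train"}, ...]},
# so for many datasets this function returns something like ("default", "train").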

def check_dataset_has_non_default_file(hub_id):
    try:
        info = dataset_info(hub_id)
        if files := info.siblings:
            file_names = [f.rfilename for f in files]
            # keep only files that are neither hidden (e.g. .gitattributes) nor the README
            files = [f for f in file_names if not f.startswith(".") and f != "README.md"]
            return len(files) >= 1
        return False
    except Exception as e:
        logger.error(f"Failed to get siblings for {hub_id}: {e}")
        return False
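
# Illustrative only: a repo containing just [".gitattributes", "README.md"] should
# return False here, while one that also ships e.g. "train.parquet" returns True.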

def datasets_server_valid_rows(hub_id: str):
    try:
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
        return resp.json()["viewer"]
    except Exception:
        # logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return None
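
# Illustrative only: an /is-valid response looks roughly like
# {"viewer": true, "preview": true}; this returns the "viewer" flag,
# or None if the request or JSON parsing fails.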

def dataset_is_valid(dataset):
    return dataset if datasets_server_valid_rows(dataset.id) else None

def get_dataset_info(hub_id: str, config: str | None = None):
    if config is None:
        config = get_first_config_and_split_name(hub_id)
        if config is None:
            return None
        else:
            config = config[0]
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()
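
# Illustrative only: a successful /info response is shaped roughly like
# {"dataset_info": {"features": {"text": {...}, "label": {...}}, "splits": {...}}};
# the "features" mapping is what the code below treats as the column list.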

def dataset_with_info(dataset):
    try:
        if info := get_dataset_info(dataset.id):
            columns = info.get("dataset_info", {}).get("features", {})
            if columns is not None:
                return {
                    "hub_id": dataset.id,
                    "column_names": list(columns.keys()),
                    "columns": columns,
                    # "dataset": dataset,
                    # "full_info": info,
                    "likes": dataset.likes,
                    "downloads": dataset.downloads,
                    "created_at": dataset.created_at,
                    "tags": dataset.tags,
                }
    except Exception as e:
        logger.error(f"Failed to get info for {dataset.id}: {e}")
    return None
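
# Illustrative only: each record produced above looks roughly like
# {"hub_id": "user/dataset", "column_names": ["chosen", "rejected"],
#  "columns": {...}, "likes": 12, "downloads": 3456, ...}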

def return_dataset_with_non_default_files(dataset):
    return dataset if check_dataset_has_non_default_file(dataset.id) else None

@lru_cache(maxsize=None)  # memoize the result so the warm-up call below primes data reused by predict()
def prep_data():
    datasets = list(api.list_datasets(limit=None, sort="createdAt", direction=-1))
    print(f"Found {len(datasets)} datasets.")
    # datasets = thread_map(
    #     return_dataset_with_non_default_files,
    #     datasets,
    # )
    # datasets = [x for x in datasets if x is not None]
    # print(f"Found {len(datasets)} datasets with non-default files.")
    has_server = thread_map(
        dataset_is_valid,
        datasets,
    )
    datasets_with_server = [x for x in has_server if x is not None]
    print(f"Found {len(datasets_with_server)} datasets with server.")
    datasets_server_data = thread_map(dataset_with_info, datasets_with_server)
    print(f"Found {len(datasets_server_data)} datasets with server data.")
    print(datasets_server_data[0])
    return datasets_server_data

def filter_columns(datasets_server_data, columns=None):
    if columns is not None:
        clean = []
        # check for presence of columns
        for dataset in datasets_server_data:
            if dataset is not None:
                target_column = dataset.get("columns", [])
                if target_column is not None and set(columns).issubset(
                    set(target_column)
                ):
                    clean.append(dataset)
        return clean
    return datasets_server_data
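
# Illustrative only: filter_columns(data, columns=["chosen", "rejected"]) keeps just
# the records whose column names are a superset of {"chosen", "rejected"}.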

# warm up the cache
prep_data()

def render_dataset_hub_link(hub_id):
    link = f"https://huggingface.co/datasets/{hub_id}"
    return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{hub_id}</a>'
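
# Illustrative only: render_dataset_hub_link("user/my-dataset") yields an HTML anchor
# pointing at https://huggingface.co/datasets/user/my-dataset, which the Gradio
# DataFrame below renders because its datatype is "markdown".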

def predict(columns_to_filter):
    datasets_server_data = prep_data()
    columns_to_filter = [x.strip() for x in columns_to_filter.split(",")]
    filtered = filter_columns(
        datasets_server_data,
        columns=columns_to_filter,
    )
    df = pd.DataFrame(filtered)
    if df.empty:
        # no matching datasets: return the empty frame before touching "hub_id"
        return df
    df["hub_id"] = df["hub_id"].apply(render_dataset_hub_link)
    return df

with gr.Blocks() as demo:
    gr.Markdown("# Search Hugging Face datasets by column names (POC)")
    gr.Markdown(
        """This Space allows you to search Hugging Face datasets by their column names. It's a POC, but the idea is that you can often learn a lot about a dataset from its column names and types."""
    )
    columns = gr.Textbox(
        "chosen,rejected", label="Columns to filter on (separate with `,`)"
    )
    btn = gr.Button("Show datasets with columns")
    df = gr.DataFrame(datatype="markdown")
    btn.click(predict, columns, df)

demo.launch()