from dotenv import load_dotenv
import os
import pandas as pd
from httpx import AsyncClient, Client
from huggingface_hub import dataset_info
from huggingface_hub.utils import logging
from functools import lru_cache
from tqdm.contrib.concurrent import thread_map
from huggingface_hub import HfApi
from rich import print
import gradio as gr
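# Gradio Space for searching Hugging Face datasets by their column names,
# using the datasets-server API to fetch per-dataset column information.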
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"
USER_AGENT = os.getenv("USER_AGENT")
assert (
USER_AGENT is not None
), "You need to set USER_AGENT in your environment variables"
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
logger = logging.get_logger(__name__)
headers = {
"authorization": f"Bearer ${HF_TOKEN}",
"user-agent": USER_AGENT,
}
client = Client(headers=headers)
async_client = AsyncClient(headers=headers)
api = HfApi(token=HF_TOKEN)
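# Return the (config, split) pair for the first split reported by the
# datasets-server /splits endpoint, or None if the request fails.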
def get_first_config_and_split_name(hub_id: str):
try:
resp = client.get(
f"https://datasets-server.huggingface.co/splits?dataset={hub_id}"
)
data = resp.json()
return data["splits"][0]["config"], data["splits"][0]["split"]
except Exception as e:
logger.error(f"Failed to get splits for {hub_id}: {e}")
return None
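# Return True if the repo contains files other than README.md and dotfiles
# such as .gitattributes, i.e. files beyond the default template.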
def check_dataset_has_non_default_file(hub_id):
try:
info = dataset_info(hub_id)
if files := info.siblings:
file_names = [f.rfilename for f in files]
files = [f for f in file_names if not f.startswith(".") or f == "README.md"]
return len(files) >= 1
return False
except Exception as e:
logger.error(f"Failed to get siblings for {hub_id}: {e}")
return False
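# Query the datasets-server /is-valid endpoint and return the "viewer" flag,
# or None if the request fails.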
def datasets_server_valid_rows(hub_id: str):
try:
resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
return resp.json()["viewer"]
except Exception:
# logger.error(f"Failed to get is-valid for {hub_id}: {e}")
return None
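# Keep the dataset object only if the datasets-server reports a working viewer.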
def dataset_is_valid(dataset):
return dataset if datasets_server_valid_rows(dataset.id) else None
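# Fetch dataset info from the datasets-server /info endpoint, resolving the
# first config automatically when none is supplied.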
def get_dataset_info(hub_id: str, config: str | None = None):
if config is None:
config = get_first_config_and_split_name(hub_id)
if config is None:
return None
else:
config = config[0]
resp = client.get(
f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
)
resp.raise_for_status()
return resp.json()
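# Flatten the /info response into a small record: hub id, column names and
# types, plus a few hub metadata fields (likes, downloads, created_at, tags).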
def dataset_with_info(dataset):
try:
if info := get_dataset_info(dataset.id):
columns = info.get("dataset_info", {}).get("features", {})
if columns is not None:
return {
"hub_id": dataset.id,
"column_names": list(columns.keys()),
"columns": columns,
# "dataset": dataset,
# "full_info": info,
"likes": dataset.likes,
"downloads": dataset.downloads,
"created_at": dataset.created_at,
"tags": dataset.tags,
}
except Exception as e:
logger.error(f"Failed to get info for {dataset.id}: {e}")
return None
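# Keep the dataset object only if it has files beyond the default template.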
def return_dataset_with_non_default_files(dataset):
return dataset if check_dataset_has_non_default_file(dataset.id) else None
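# Build the searchable index of datasets once; lru_cache keeps the result in
# memory so subsequent searches reuse it.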
@lru_cache(maxsize=100)
def prep_data():
datasets = list(api.list_datasets(limit=None, sort="createdAt", direction=-1))
print(f"Found {len(datasets)} datasets.")
# datasets = thread_map(
# return_dataset_with_non_default_files,
# datasets,
# )
# datasets = [x for x in datasets if x is not None]
# print(f"Found {len(datasets)} datasets with non-default files.")
has_server = thread_map(
dataset_is_valid,
datasets,
)
datasets_with_server = [x for x in has_server if x is not None]
print(f"Found {len(datasets_with_server)} datasets with server.")
datasets_server_data = thread_map(dataset_with_info, datasets_with_server)
print(f"Found {len(datasets_server_data)} datasets with server data.")
print(datasets_server_data[0])
return datasets_server_data
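# Keep only datasets whose columns contain every requested column name.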
def filter_columns(datasets_server_data, columns=None):
if columns is not None:
clean = []
# check for presence of columns
for dataset in datasets_server_data:
if dataset is not None:
target_column = dataset.get("columns", [])
if target_column is not None and set(columns).issubset(
set(target_column)
):
clean.append(dataset)
return clean
return datasets_server_data
# warm up the cache
prep_data()
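# Render a dataset hub id as an HTML link to its page on the Hub.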
def render_model_hub_link(hub_id):
link = f"https://huggingface.co/datasets/{hub_id}"
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{hub_id}</a>'
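# Gradio callback: parse the comma-separated column names, filter the cached
# dataset index, and return the matches as a DataFrame with linked hub ids.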
def predict(columns_to_filter):
datasets_server_data = prep_data()
columns_to_filter = columns_to_filter.split(",")
columns_to_filter = [x.strip() for x in columns_to_filter]
filtered = filter_columns(
datasets_server_data,
columns=columns_to_filter,
)
    df = pd.DataFrame(filtered)
    if df.empty:
        return df
    df["hub_id"] = df["hub_id"].apply(render_model_hub_link)
    return df
with gr.Blocks() as demo:
gr.Markdown("# Search Hugging Face datasets by column names (POC)")
gr.Markdown(
""""This Space allows you to search Hugging Face datasets by their column names. It's a POC, but the idea is that you can often know a lot about a dataset by it's column names and types.
"""
)
columns = gr.Textbox(
"chosen,rejected", label="Columns to filter on separate with `,`"
)
btn = gr.Button("Show datasets with columns")
df = gr.DataFrame(datatype="markdown")
btn.click(predict, columns, df)
demo.launch()