# dataset-column-search-api / data_loader.py
# Author: davanstrien (HF staff)
# Commit 33c1203: "improve db"
# (Captured from the Hugging Face file viewer — raw / history / blame views,
# "No virus" scan badge, file size 4.18 kB.)
import os
from datetime import datetime
from typing import Any, Dict, List
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import HfApi
from huggingface_hub.utils import logging
from tqdm.auto import tqdm
# Populate os.environ from a local .env file (if present) before reading config.
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Explicit checks rather than ``assert`` so validation still runs under
# ``python -O`` (which strips assert statements).
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

USER_AGENT = os.getenv("USER_AGENT")
# NOTE(review): USER_AGENT is required here but not referenced elsewhere in
# this module's visible code — confirm it is consumed by another module.
if USER_AGENT is None:
    raise RuntimeError("You need to set USER_AGENT in your environment variables")

logger = logging.get_logger(__name__)
api = HfApi(token=HF_TOKEN)

# Passed as ``limit`` to ``api.list_datasets``; ``None`` requests every dataset.
MAX_DATASETS = None
def has_card_data(dataset):
    """Return True if *dataset* carries card metadata.

    ``huggingface_hub``'s ``DatasetInfo`` always defines ``card_data`` but sets
    it to ``None`` when the repository has no card metadata. A bare
    ``hasattr`` check would therefore report True for those datasets too, so
    the attribute must also be non-None.
    """
    return getattr(dataset, "card_data", None) is not None
def check_dataset_has_dataset_info(dataset):
    """Tell whether *dataset*'s card metadata declares a ``dataset_info`` section.

    True only when the dataset has card data, that card data has a
    ``dataset_info`` attribute, and the attribute's value is not ``None``.
    """
    if not has_card_data(dataset):
        return False
    card = dataset.card_data
    if not hasattr(card, "dataset_info"):
        return False
    return card.dataset_info is not None
def parse_single_config_dataset(data):
    """Extract the config name and column information from one config entry.

    Args:
        data: One mapping taken from a card's ``dataset_info`` metadata. The
            keys read are ``config_name`` and ``features``. ``features`` is
            expected to be a list of mappings, each optionally holding ``name``.

    Returns:
        A dict with:
        - ``config_name``: ``"default"`` when the key is missing or null.
        - ``column_names``: each feature's ``name`` in order, ``None`` for a
          feature without one.
        - ``features``: the feature list, ``[]`` when missing or null.
    """
    # ``or`` also covers keys present with a null value, which a ``.get``
    # default does not; a null ``features`` previously failed on iteration.
    config_name = data.get("config_name") or "default"
    features = data.get("features") or []
    column_names = [feature.get("name") for feature in features]
    return {
        "config_name": config_name,
        "column_names": column_names,
        "features": features,
    }
def parse_multiple_config_dataset(data: List[Dict[str, Any]]):
    """Parse every config entry of a multi-config ``dataset_info`` list.

    Each element goes through ``parse_single_config_dataset``. The results are
    returned as a new list in the input order.
    """
    return list(map(parse_single_config_dataset, data))
def parse_dataset(dataset):
    """Collect the Hub-level metadata of one listed dataset into a flat dict.

    Reads these attributes directly from *dataset*: ``id`` (stored as
    ``hub_id``), ``likes``, ``downloads``, ``tags``, ``created_at`` and
    ``last_modified``. ``license`` and ``language`` are read from
    ``dataset.card_data``. A missing attribute, or a ``card_data`` of ``None``,
    raises ``AttributeError`` for the caller to handle.
    """
    record = {
        "hub_id": dataset.id,
        "likes": dataset.likes,
        "downloads": dataset.downloads,
        "tags": dataset.tags,
        "created_at": dataset.created_at,
        "last_modified": dataset.last_modified,
    }
    card = dataset.card_data
    record["license"] = card.license
    record["language"] = card.language
    return record
def parsed_column_info(dataset_info):
    """Normalise a card's ``dataset_info`` into a list of per-config dicts.

    A list (several configs) is parsed element by element. A single mapping
    (one config) is parsed and wrapped in a one-element list. Any other value
    yields ``None``.
    """
    if isinstance(dataset_info, list):
        return parse_multiple_config_dataset(dataset_info)
    if isinstance(dataset_info, dict):
        return [parse_single_config_dataset(dataset_info)]
    return None
def ensure_list_of_strings(value):
    """Coerce *value* into a list of strings.

    ``None`` becomes ``[]``. A list has each item converted with ``str``. Any
    other value, including tuples, is stringified as a single element.
    """
    if isinstance(value, list):
        return list(map(str, value))
    return [] if value is None else [str(value)]
def refresh_data() -> List[Dict[str, Any]]:
    """Return one record per (dataset, config) pair describing Hub datasets.

    If a parquet snapshot named ``datasets_<YYYY-MM-DD>.parquet`` exists in
    the working directory, its rows are returned as-is. The date is today's
    local date.

    Otherwise, every dataset from ``api.list_datasets`` (capped by
    ``MAX_DATASETS``) whose card declares ``dataset_info`` is parsed. Each
    record combines the Hub metadata from ``parse_dataset`` with one config's
    column information from ``parsed_column_info``. A dataset whose metadata
    cannot be parsed is skipped with a warning instead of aborting the whole
    refresh.

    Returns:
        A list of record dicts. The list is empty when no dataset could be
        parsed.
    """
    snapshot_path = f"datasets_{datetime.now().strftime('%Y-%m-%d')}.parquet"
    if os.path.exists(snapshot_path):
        return pd.read_parquet(snapshot_path).to_dict(orient="records")

    # Filter while streaming the listing so the progress bar reflects the
    # (slow) network fetch instead of appearing only after it completes.
    datasets = [
        dataset
        for dataset in tqdm(api.list_datasets(limit=MAX_DATASETS, full=True))
        if check_dataset_has_dataset_info(dataset)
    ]

    parsed_datasets = []
    for dataset in tqdm(datasets):
        try:
            dataset_record = parse_dataset(dataset)
            # parsed_column_info returns None for an unexpected dataset_info
            # shape; treat that as "no configs" rather than a TypeError.
            column_info = parsed_column_info(dataset.card_data.dataset_info) or []
            parsed_datasets.extend(
                {**dataset_record, **info} for info in column_info
            )
        except Exception as e:  # best-effort: one malformed card must not abort the refresh
            logger.warning("Error processing dataset %s: %s", dataset.id, e)

    # Without rows the DataFrame has no columns and the normalisation below
    # would raise KeyError.
    if not parsed_datasets:
        return []

    df = pd.DataFrame(parsed_datasets)
    # Card metadata may be a scalar, a list, or None; normalise to list[str].
    for column in ("license", "tags", "language"):
        df[column] = df[column].apply(ensure_list_of_strings)
    df["column_names"] = df["column_names"].apply(
        lambda names: names if isinstance(names, list) else []
    )
    df = df.astype({"hub_id": "string", "config_name": "string"})

    # NOTE: writing the snapshot is disabled (formerly ``df.to_parquet(...)``
    # plus a JSON-lines dump under the same date-stamped name). Nothing here
    # populates the cache read at the top of this function.
    return df.to_dict(orient="records")