#!/usr/bin/env python3
# sync_library_and_hf.py
'''
RUN BELOW FOR NEW HTML FILES TO UPDATE OLD ONES ON THE DATASET REPO:
python sync_library_and_hf.py --db-path library.csv --repo-id akazemian/audio-html --model-name wavcoch_audio-preds-sr=16000 --index-filename index.csv --wipe-remote --wipe-local
'''
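# Illustrative: append --dry-run to the command above to preview the wipe and
# uploads without pushing commits or rewriting library.csv (see flags below).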
import argparse, datetime, uuid, posixpath, sys, traceback, os, hashlib
from pathlib import Path
from typing import List, Tuple, Set
from urllib.parse import unquote

import pandas as pd
import numpy as np
from huggingface_hub import (
    HfApi,
    hf_hub_download,
    CommitOperationAdd,
    CommitOperationDelete,
)
from huggingface_hub.utils import HfHubHTTPError
REQUIRED_DB_COLS = [
    "id","filename","path","tags","keywords","notes","uploaded_at","category","dataset","hf_path"
]
INDEX_COLS = ["id","filename","relpath","category","dataset","tags","keywords","notes","uploaded_at"]

# --- manifest helpers ---
AUDIO_EXTS = {".wav", ".mp3"}  # extend if needed: ".flac", ".ogg", etc.

def _strip_ext(name: str, exts: set[str]) -> str:
    n = name
    for ext in exts:
        if n.lower().endswith(ext):
            return n[: -len(ext)]
    return n

def key_from_html_filename(fname: str) -> str:
    # e.g. "foo_bar.html" -> "foo_bar"
    base = Path(fname).name
    if base.lower().endswith(".html"):
        base = base[:-5]
    return base

def key_from_manifest_filename(fname: str) -> str:
    # e.g. "foo_bar.wav" or "foo_bar.mp3" -> "foo_bar"
    base = Path(fname).name
    return _strip_ext(base, AUDIO_EXTS)
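# Illustrative (not executed): both helpers reduce to the same join key, so an
# HTML report can be matched back to the audio file it was rendered from:
#   key_from_html_filename("a1_scene.html")    -> "a1_scene"
#   key_from_manifest_filename("a1_scene.wav") -> "a1_scene"
# ("a1_scene" is a placeholder name, not a real file.)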
def create_file_specific_manifest(csv_path: Path) -> pd.DataFrame:
    audio_dir = "/data/atlask/BAU-Quant/val"
    manifest = pd.read_csv(csv_path)
    # For TUT_urban_acoustic_scenes rows, use the dataset name as the audio category.
    mask = manifest['dataset'].eq('TUT_urban_acoustic_scenes')
    manifest['audio_category'] = np.where(mask, manifest['dataset'], manifest['audio_category'])
    # 1) Build a files dataframe
    files = pd.DataFrame({"file_name": os.listdir(audio_dir)})
    # keep only audio files if needed
    files = files[files["file_name"].str.lower().str.endswith((".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus"))].copy()
    files["file_path"] = files["file_name"].apply(lambda f: os.path.join(audio_dir, f))
    # Normalize to a join key: drop extension, then strip `_chunk...`
    files["key"] = (
        files["file_name"]
        .str.replace(r"\.[^.]+$", "", regex=True)   # remove extension
        .str.replace(r"_chunk.*$", "", regex=True)  # remove _chunk suffix if present
    )
    # 2) Prepare manifest with the same key
    man = manifest.copy()
    # If manifest['file_name'] includes extensions / chunk suffixes, normalize the same way:
    man["key"] = (
        man["file_name"]
        .str.replace(r"\.[^.]+$", "", regex=True)
        .str.replace(r"_chunk.*$", "", regex=True)
    )
    # If duplicates exist in manifest for the same key, decide how to resolve:
    # e.g., keep first occurrence
    man = man.drop_duplicates(subset="key", keep="first")
    # 3) Merge once (vectorized)
    cols_to_take = ["sr", "dataset", "audio_category", "split", "duration_s"]
    out = files.merge(man[["key"] + cols_to_take], on="key", how="left")
    # 4) Final column order
    return out[["sr", "file_name", "file_path", "dataset", "audio_category", "split"]]
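# Illustrative (not executed): a local file such as "park_street_10_chunk0.wav"
# would get the join key "park_street_10" and inherit sr/dataset/audio_category/
# split from the manifest row whose file_name normalizes to the same key.
# (The filename here is a made-up example, not a real entry.)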
def load_manifest_map(csv_path: Path) -> dict[str, tuple[str, str]]:
    """
    Returns {basename_key: (dataset, category)} from the manifest.
    Manifest must have columns: file_name, dataset, category
    """
    # if not csv_path.exists():
    #     print(f"[manifest] WARNING: not found: {csv_path}")
    #     return {}
    # dfm = pd.read_csv(csv_path)
    dfm = create_file_specific_manifest(csv_path)
    dfm = dfm.rename(columns={'audio_category': 'category'})
    required = {"file_name", "dataset", "category"}
    missing = required - set(dfm.columns)
    if missing:
        raise ValueError(f"manifest missing columns: {sorted(missing)}")
    m = {}
    for _, r in dfm.iterrows():
        k = key_from_manifest_filename(str(r["file_name"]))
        ds = str(r["dataset"]) if pd.notna(r["dataset"]) else ""
        cat = str(r["category"]) if pd.notna(r["category"]) else ""
        if k and (ds or cat):
            m[k] = (ds, cat)
    print(f"[manifest] loaded {len(m)} keys from {csv_path}")
    return m

def now_iso() -> str:
    return datetime.datetime.now().isoformat(timespec="seconds")

def ensure_cols(df: pd.DataFrame, cols: list) -> pd.DataFrame:
    for c in cols:
        if c not in df.columns:
            df[c] = ""
    for c in cols:
        df[c] = df[c].fillna("").astype(str)
    return df[cols]

def load_db(db_path: Path) -> pd.DataFrame:
    if db_path.exists():
        df = pd.read_csv(db_path)
    else:
        df = pd.DataFrame(columns=REQUIRED_DB_COLS)
    return ensure_cols(df, REQUIRED_DB_COLS)

def save_db(df: pd.DataFrame, db_path: Path):
    db_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(db_path, index=False)

def load_hf_index(repo_id: str, index_filename: str) -> Tuple[pd.DataFrame, bool]:
    try:
        p = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=index_filename)
        df = pd.read_csv(p)
        return ensure_cols(df, INDEX_COLS), True
    except HfHubHTTPError as e:
        if e.response is not None and e.response.status_code == 404:
            return ensure_cols(pd.DataFrame(columns=INDEX_COLS), INDEX_COLS), False
        raise

def relpath_posix(local_path: Path, root: Path) -> str:
    rel = local_path.resolve().relative_to(root.resolve())
    parts = [unquote(p) for p in rel.as_posix().split("/")]
    return posixpath.join(*parts)
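# Illustrative (not executed), with a placeholder filename:
#   relpath_posix(reports_root / "a1/foo%20bar.html", reports_root)
#     -> "a1/foo bar.html"   (URL-escaped characters are decoded via unquote)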
# --- model prefix + sharding helpers ---
def ensure_model_prefix(relpath: str, model_name: str | None) -> str:
    """
    If model_name is provided and relpath doesn't start with "<model_name>/",
    prepend it. Otherwise return relpath unchanged.
    """
    if not model_name:
        return relpath
    model = model_name.strip()
    if not model:
        return relpath
    if relpath.startswith(model + "/"):
        return relpath
    return f"{model}/{relpath}"

def shard_relpath_under_model(relpath: str, hexdigits: int = 2) -> str:
    """
    Insert shard bucket immediately after the *model* segment (first path part).
    If there is only 1 segment, just return relpath.
    """
    parts = relpath.split("/")
    if len(parts) < 2:
        return relpath
    filename = parts[-1]
    bucket = hashlib.sha1(filename.encode("utf-8")).hexdigest()[:hexdigits]
    # parts[0] = model, parts[1:] = rest of path
    return "/".join([parts[0], bucket] + parts[1:])
def discover_new_local_htmls(reports_root: Path, df_db: pd.DataFrame) -> List[Path]:
    all_htmls = list(reports_root.rglob("*.html"))
    existing_paths = set(df_db["path"].astype(str))
    return sorted([p for p in all_htmls if str(p) not in existing_paths])

def rows_from_files(
    files: List[Path],
    reports_root: Path,
    manifest_map: dict[str, tuple[str, str]],
) -> pd.DataFrame:
    ts = now_iso()
    rows = []
    for p in files:
        k = key_from_html_filename(p.name)
        ds, cat = manifest_map.get(k, ("", ""))
        rows.append({
            "id": uuid.uuid4().hex[:8],
            "filename": p.name,
            "path": str(p),
            "tags": "",
            "keywords": "",
            "notes": "",
            "uploaded_at": ts,
            "category": cat,
            "dataset": ds,
            "hf_path": "",
        })
    return pd.DataFrame(rows, columns=REQUIRED_DB_COLS) if rows else pd.DataFrame(columns=REQUIRED_DB_COLS)

def backfill_hf_paths_by_relpath(
    df_db: pd.DataFrame,
    reports_root: Path,
    hf_repo: str,
    idx: pd.DataFrame,
    model_name: str | None,
    do_shard: bool,
    shard_digits: int,
) -> int:
    """
    For each local file path, compute the *target* repo relpath exactly as we upload it
    (model prefix + optional shard). If that relpath appears in index.csv, backfill hf_path.
    """
    rel_set = set(idx["relpath"].astype(str))
    updated = 0
    for i, p in enumerate(df_db["path"].astype(str).tolist()):
        if not p:
            continue
        lp = Path(p)
        if not lp.exists():
            continue
        try:
            base_rp = relpath_posix(lp, reports_root)  # e.g. "file.html" or "model/.../file.html"
        except Exception:
            continue
        base_rp = ensure_model_prefix(base_rp, model_name)  # ensure "<model>/..."
        rp_target = shard_relpath_under_model(base_rp, shard_digits) if do_shard else base_rp
        if rp_target in rel_set and not df_db.at[i, "hf_path"]:
            df_db.at[i, "hf_path"] = f"hf://{hf_repo}/{rp_target}"
            updated += 1
    return updated
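# Illustrative (not executed): with the repo id from the docstring at the top
# and a sharded target relpath, a backfilled value looks like
#   hf://akazemian/audio-html/wavcoch_audio-preds-sr=16000/<xx>/foo.html
# ("<xx>" is the shard bucket; "foo.html" is a placeholder filename).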
def backfill_hf_paths_by_filename(df_db: pd.DataFrame, hf_repo: str, idx: pd.DataFrame) -> int:
    updated = 0
    rel_by_fname = dict(zip(idx["filename"].astype(str), idx["relpath"].astype(str)))
    mask = df_db["hf_path"].astype(str) == ""
    for i in df_db.index[mask]:
        fn = str(df_db.at[i, "filename"])
        rp = rel_by_fname.get(fn)
        if rp:
            df_db.at[i, "hf_path"] = f"hf://{hf_repo}/{rp}"
            updated += 1
    return updated

def append_to_remote_index(remote_index: pd.DataFrame, new_rows: List[dict]) -> pd.DataFrame:
    if not new_rows:
        return remote_index
    add_df = pd.DataFrame(new_rows, columns=INDEX_COLS)
    merged = pd.concat([remote_index, add_df], ignore_index=True)
    merged = merged.drop_duplicates(subset=["relpath"], keep="first")
    return merged[INDEX_COLS]

def list_remote_relpaths(api: HfApi, repo_id: str) -> Set[str]:
    files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    out = set()
    for f in files:
        parts = [unquote(s) for s in f.split("/")]
        out.add("/".join(parts))
    return out

def commit_ops_in_batches(api: HfApi, repo_id: str, ops: List, batch_size: int, msg_prefix: str):
    if not ops:
        return
    for start in range(0, len(ops), batch_size):
        batch = ops[start:start + batch_size]
        api.create_commit(
            repo_id=repo_id,
            repo_type="dataset",
            operations=batch,
            commit_message=f"{msg_prefix} (n={len(batch)})"
        )
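# Illustrative: with batch_size=1000, 2500 pending add/delete operations are
# pushed as three commits of 1000, 1000, and 500 operations respectively.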
# ---------- Wipe helpers ----------
def wipe_remote_dataset(api: HfApi, repo_id: str, keep: Set[str], batch_size: int, dry: bool):
    files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    to_delete = []
    for f in files:
        f_norm = "/".join([unquote(s) for s in f.split("/")])
        if f_norm in keep:
            continue
        to_delete.append(CommitOperationDelete(path_in_repo=f_norm))
    if not to_delete:
        print("[wipe] nothing to delete")
        return
    if dry:
        print(f"[dry-run] would delete {len(to_delete)} files from {repo_id}")
        return
    print(f"[wipe] deleting {len(to_delete)} files from {repo_id} ...")
    commit_ops_in_batches(api, repo_id, to_delete, batch_size, "Wipe dataset")

def main():
    ap = argparse.ArgumentParser(description="Reset and sync HF dataset from local HTMLs (optionally wipe repo), shard to avoid 10k/dir limit, update index.csv, backfill hf_path.")
    ap.add_argument("--reports-root", default='/data/atlask/Model-Preds-Html/AudioSet-Audio', type=Path, help="Root containing {model}/.../*.html (or just the model dir)")
    ap.add_argument("--db-path", required=True, type=Path, help="Path to local library.csv")
    ap.add_argument("--manifest-csv", default="/data/atlask/BAU-Quant/manifest_val.csv", type=Path, help="CSV with columns file_name,dataset,category; matched by basename without extension")
    ap.add_argument("--repo-id", required=True, help="HF dataset repo id, e.g. USER/audio-html")
    ap.add_argument("--index-filename", default="index.csv", help="Index filename in the HF dataset (default: index.csv)")
    ap.add_argument("--batch-size", type=int, default=1000, help="Files per commit when uploading to HF")
    ap.add_argument("--dry-run", action="store_true", help="Print actions; do not write or push")
    ap.add_argument("--commit-message", default="Sync: add new HTMLs + update index.csv", help="Commit message prefix")
    # Reset/Wipe options
    ap.add_argument("--wipe-remote", action="store_true", help="Delete ALL files in the HF dataset before uploading")
    ap.add_argument("--keep", action="append", default=[], help="Paths to keep during wipe (can be passed multiple times)")
    ap.add_argument("--wipe-local", action="store_true", help="Delete local library.csv before scanning")
    # SHARD controls
    ap.add_argument("--no-shard", action="store_true", help="Disable sharding (NOT recommended; risk 10k/dir limit)")
    ap.add_argument("--shard-hexdigits", type=int, default=2, help="Digits of SHA1 prefix for bucket (default: 2 -> 256 buckets)")
    # Model prefix
    ap.add_argument("--model-name", type=str, default=None,
                    help="Force all uploaded relpaths to be prefixed with this model folder (use if reports-root is already inside the model).")
    args = ap.parse_args()

    reports_root: Path = args.reports_root
    db_path: Path = args.db_path
    hf_repo: str = args.repo_id
    index_filename: str = args.index_filename
    bs: int = args.batch_size
    dry: bool = args.dry_run
    do_shard: bool = not args.no_shard
    shard_digits: int = max(1, args.shard_hexdigits)
    keep_set: Set[str] = set(args.keep)

    print(f"[config] reports_root={reports_root}")
    print(f"[config] db_path={db_path}")
    print(f"[config] repo_id={hf_repo}, index={index_filename}")
    print(f"[config] batch_size={bs}, dry_run={dry}, shard={'on' if do_shard else 'off'}:{shard_digits}")
    if args.model_name:
        print(f"[config] model_name={args.model_name}")
    if keep_set:
        print(f"[config] wipe keep-list: {sorted(keep_set)}")
    if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1":
        print("[tip] For faster uploads, install `hf-transfer` and set HF_HUB_ENABLE_HF_TRANSFER=1")
    api = HfApi()
    manifest_map = load_manifest_map(args.manifest_csv)

    # 0) Optional wipes
    if args.wipe_remote:
        wipe_remote_dataset(api, hf_repo, keep_set, bs, dry)
    if args.wipe_local and db_path.exists():
        if dry:
            print(f"[dry-run] would remove local DB: {db_path}")
        else:
            print(f"[wipe] removing local DB: {db_path}")
            try:
                db_path.unlink()
            except FileNotFoundError:
                pass

    # 1) Load DB (fresh if wiped)
    df_db = load_db(db_path)

    # 2) Append new local *.html files to DB
    new_local_files = discover_new_local_htmls(reports_root, df_db)
    print(f"[scan] new local HTML files: {len(new_local_files)}")
    if new_local_files:
        df_new = rows_from_files(new_local_files, reports_root, manifest_map)
        df_db = pd.concat([df_db, df_new], ignore_index=True)

    # 3) Load remote index + list files (will be empty after wipe)
    remote_index, existed = load_hf_index(hf_repo, index_filename)
    print(f"[index] remote exists={existed}, rows={len(remote_index)}")
    remote_files_set = list_remote_relpaths(api, hf_repo)
    print(f"[remote] files in repo: {len(remote_files_set)}")

    # 4) Backfill hf_path (now uses model prefix + shard)
    n1 = backfill_hf_paths_by_relpath(
        df_db, reports_root, hf_repo, remote_index,
        model_name=args.model_name,
        do_shard=do_shard,
        shard_digits=shard_digits,
    )
    n2 = backfill_hf_paths_by_filename(df_db, hf_repo, remote_index)
    print(f"[hf] backfilled hf_path: by_relpath={n1}, by_filename={n2}")
    # 5) Decide which rows to upload (and target relpaths, sharded under model)
    need_upload = []
    for i, r in df_db.iterrows():
        # If you've wiped, hf_path will be empty; we only upload files that exist locally
        local_str = str(r["path"])
        local = Path(local_str)
        # `not local_str` catches empty paths (a Path object itself is always truthy)
        if (not local_str) or (not local.exists()):
            continue
        try:
            base_rp = relpath_posix(local, reports_root)  # "file.html" or "model/.../file.html"
        except Exception:
            continue
        base_rp = ensure_model_prefix(base_rp, args.model_name)  # ensure "<model>/..."
        rp = shard_relpath_under_model(base_rp, shard_digits) if do_shard else base_rp
        if rp not in remote_files_set:
            need_upload.append((i, r.to_dict(), rp))
    print(f"[hf] rows needing upload (not present in repo): {len(need_upload)}")
    ops: List[CommitOperationAdd] = []
    new_index_rows: List[dict] = []
    for i, rdict, rp in need_upload:
        local = Path(rdict["path"])
        if not local.exists():
            continue
        ops.append(CommitOperationAdd(path_in_repo=rp, path_or_fileobj=str(local)))
        # derive from HTML filename using manifest map
        k = key_from_html_filename(rdict["filename"])
        ds, cat = manifest_map.get(k, (str(rdict["dataset"]), str(rdict["category"])))
        new_index_rows.append({
            "id": rdict["id"] or uuid.uuid4().hex[:8],
            "filename": rdict["filename"],
            "relpath": rp,
            "category": cat,
            "dataset": ds,
            "tags": rdict["tags"],
            "keywords": rdict["keywords"],
            "notes": rdict["notes"],
            "uploaded_at": rdict["uploaded_at"] or now_iso(),
        })

    # 6) Upload in batches
    if ops and not dry:
        print(f"[hf] uploading {len(ops)} files in batches of {bs}...")
        commit_ops_in_batches(api, hf_repo, ops, bs, args.commit_message)
        remote_files_set = list_remote_relpaths(api, hf_repo)  # refresh
    elif ops and dry:
        print(f"[dry-run] would upload {len(ops)} files")

    # 7) Compose index.csv (fresh if wiped)
    current_index_rel = set(remote_index["relpath"].astype(str))
    current_index_rel.update([row["relpath"] for row in new_index_rows])
    missing_in_index = [rp for rp in remote_files_set if rp.endswith(".html") and rp not in current_index_rel]
    if missing_in_index:
        print(f"[index] adding {len(missing_in_index)} repo files that were missing from index.csv")
        for rp in missing_in_index:
            fname = Path(rp).name
            k = key_from_html_filename(fname)
            ds, cat = manifest_map.get(k, ("", ""))
            new_index_rows.append({
                "id": uuid.uuid4().hex[:8],
                "filename": fname,
                "relpath": rp,
                "category": cat,
                "dataset": ds,
                "tags": "",
                "keywords": "",
                "notes": "",
                "uploaded_at": now_iso(),
            })

    if new_index_rows or args.wipe_remote:
        # If wiped, overwrite index.csv with just merged content
        base_index = remote_index if not args.wipe_remote else pd.DataFrame(columns=INDEX_COLS)
        merged_index = append_to_remote_index(base_index, new_index_rows)
        merged_index = ensure_cols(merged_index, INDEX_COLS)
        if not dry:
            tmp = Path("index.updated.csv")
            merged_index.to_csv(tmp, index=False)
            api.create_commit(
                repo_id=hf_repo,
                repo_type="dataset",
                operations=[CommitOperationAdd(path_in_repo=index_filename, path_or_fileobj=str(tmp))],
                commit_message=f"{args.commit_message} (update {index_filename}, rows={len(merged_index)})"
            )
            tmp.unlink(missing_ok=True)
        else:
            print(f"[dry-run] would write fresh {index_filename} with {len(merged_index)} rows")
    # 8) Update local hf_path for rows now on HF (sharded + model-prefixed)
    for i, r in df_db.iterrows():
        if str(r.get("hf_path", "")):
            continue
        local = str(r["path"])
        if not local:
            continue
        p = Path(local)
        if not p.exists():
            continue
        try:
            base_rp = relpath_posix(p, reports_root)
        except Exception:
            continue
        base_rp = ensure_model_prefix(base_rp, args.model_name)
        rp = shard_relpath_under_model(base_rp, shard_digits) if do_shard else base_rp
        if rp in remote_files_set:
            df_db.at[i, "hf_path"] = f"hf://{hf_repo}/{rp}"

    # 8.5) Backfill dataset/category in DB from manifest if missing
    mask_missing = (df_db["dataset"].astype(str) == "") | (df_db["category"].astype(str) == "")
    for i, r in df_db[mask_missing].iterrows():
        k = key_from_html_filename(str(r["filename"]))
        if k in manifest_map:  # manifest_map was loaded earlier: load_manifest_map(args.manifest_csv)
            ds, cat = manifest_map[k]
            if not str(r["dataset"]):
                df_db.at[i, "dataset"] = ds
            if not str(r["category"]):
                df_db.at[i, "category"] = cat

    # 9) Save DB
    if dry:
        print("[dry-run] not writing library.csv")
    else:
        save_db(df_db, db_path)
        print(f"[done] wrote {len(df_db)} rows to {db_path}")

if __name__ == "__main__":
    try:
        main()
    except Exception:
        traceback.print_exc()
        sys.exit(1)