# app.py β€” dataset shard discovery/download service (Hugging Face Space).
# Discovers parquet/jsonl shards from configured sources, downloads and
# converts them to JSONL under /data/raw, and tracks worker progress in
# /data/state.json.
import os
import json
import time
import socket
import threading
import requests
import pyarrow.parquet as pq
import gc
from pathlib import Path
from huggingface_hub import HfApi
# ── Config ───────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hub token; if unset this is None and AUTH_HEADERS becomes "Bearer None"
RAW_DIR = "/data/raw"                  # where converted .jsonl shards are written for workers to claim
STATE_FILE = "/data/state.json"        # persisted shard/queue state, written atomically via os.replace
WORKER_TIMEOUT = 700                   # seconds before a "claimed" shard is reclaimed back to "pending"
MAX_BUFFERED = 999999                  # throttle on buffered "pending" shards — set so high it is effectively off
os.makedirs(RAW_DIR, exist_ok=True)
api = HfApi(token=HF_TOKEN)            # hub client, used only to list dataset repo files
AUTH_HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}  # auth for direct resolve/ downloads
# ── Sources ───────────────────────────────────────────────────────────────────
SOURCES = [
    # Each entry describes one dataset to mirror.
    # type "hf_list": list the repo's files via the hub API, keep .parquet files
    #   under "prefix", then take a skip/take window of them.
    # type "url_list": explicit URLs, format given by "fmt".
    # "text_col" is the column extracted into the output JSONL.
    {
        "name"    : "fineweb",
        "type"    : "hf_list",
        "repo"    : "HuggingFaceFW/fineweb-edu",
        "prefix"  : "data/CC-MAIN-2025-26",
        "skip"    : 5,   # skip the first 5 matching files
        "take"    : 10,  # then take the next 10
        "text_col": "text",
    },
    {
        "name"    : "wikipedia",
        "type"    : "hf_list",
        "repo"    : "wikimedia/wikipedia",
        "prefix"  : "20231101.en/train-",
        "skip"    : 2,
        "take"    : 18,
        "text_col": "text",
    },
    {
        "name"    : "openwebmath",
        "type"    : "hf_list",
        "repo"    : "open-web-math/open-web-math",
        "prefix"  : "data/train-",
        "skip"    : 0,
        "take"    : 6,
        "text_col": "text",
    },
    {
        # Pre-sharded code data: 2 JSONL shards per language, direct URLs.
        "name"    : "code",
        "type"    : "url_list",
        "text_col": "text",
        "fmt"     : "jsonl",
        "urls"    : [
            f"https://huggingface.co/buckets/Neon-tech/Dataset-arranger/resolve/by-language/{lang}/shard_{str(i).zfill(6)}.jsonl?download=true"
            for lang in ["C", "C++", "Java", "Go", "Rust", "Ruby", "PHP", "SQL", "C#", "Scala", "Lua", "Perl", "CSS"]
            for i in range(2)
        ],
    },
]
# ── Keep-alive ────────────────────────────────────────────────────────────────
def serve():
    """Minimal HTTP keep-alive responder so the hosting platform sees the app as up.

    Binds 0.0.0.0:7860 and answers every connection with a fixed 200/"OK"
    response. Runs forever; intended to be started in a daemon thread.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("0.0.0.0", 7860))
    s.listen(5)
    print("βœ“ Listening on port 7860")
    while True:
        conn, _ = s.accept()
        try:
            # sendall (not send) guarantees the full response is written;
            # a client reset/broken pipe must not kill the keep-alive thread.
            conn.sendall(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
        except OSError:
            pass  # best-effort health endpoint
        finally:
            conn.close()
# ── State ─────────────────────────────────────────────────────────────────────
def load_state():
    """Load persisted state from STATE_FILE, or start with an empty state dict."""
    if not os.path.exists(STATE_FILE):
        print("Starting fresh")
        return {"shards": {}, "queue": []}
    with open(STATE_FILE) as fh:
        state = json.load(fh)
    # Count shard statuses for the resume banner.
    counts = {"done": 0, "claimed": 0, "pending": 0}
    for info in state["shards"].values():
        if info["status"] in counts:
            counts[info["status"]] += 1
    queued = len(state.get("queue", []))
    print(f"Resuming β€” {counts['done']} done / {counts['claimed']} claimed / {counts['pending']} buffered / {queued} queued")
    return state
def save_state(state):
    """Persist state atomically: write a sibling .tmp file, then rename over STATE_FILE."""
    tmp_path = STATE_FILE + ".tmp"
    with open(tmp_path, "w") as fh:
        json.dump(state, fh, indent=2)
    # os.replace is atomic on POSIX, so readers never see a half-written file.
    os.replace(tmp_path, STATE_FILE)
# ── Discover ──────────────────────────────────────────────────────────────────
def discover_all(state):
    """Enumerate every source's shard URLs and enqueue the ones not yet known.

    For "hf_list" sources the repo is listed via the hub API and a skip/take
    window of .parquet files under the prefix is selected; "url_list" sources
    supply explicit URLs. URLs already present in state["shards"] or
    state["queue"] are skipped. Saves state when done.
    """
    # Read AND append through setdefault so a state dict missing the "queue"
    # key cannot raise KeyError (the original read with .get but appended
    # with a bare subscript).
    queue = state.setdefault("queue", [])
    known_urls = {v["url"] for v in state["shards"].values()} | {e["url"] for e in queue}
    new_count = 0
    for src in SOURCES:
        name = src["name"]
        print(f"\nDiscovering: {name}")
        if src["type"] == "hf_list":
            # Sorted listing makes the skip/take window deterministic.
            all_files = sorted(
                f for f in api.list_repo_files(src["repo"], repo_type="dataset")
                if f.startswith(src["prefix"]) and f.endswith(".parquet")
            )
            selected = all_files[src["skip"]: src["skip"] + src["take"]]
            base_url = f"https://huggingface.co/datasets/{src['repo']}/resolve/main/"
            urls = [base_url + f for f in selected]
            fmt = "parquet"
        else:
            urls = src["urls"]
            fmt = src.get("fmt", "parquet")
        added = 0
        for url in urls:
            if url not in known_urls:
                queue.append({
                    "url"      : url,
                    "source"   : name,
                    "text_col" : src["text_col"],
                    "fmt"      : fmt,
                })
                known_urls.add(url)
                new_count += 1
                added += 1
        print(f" {name}: {len(urls)} files | {added} new added to queue")
    save_state(state)
    print(f"\nTotal queued: {len(state['queue'])} | In state: {len(state['shards'])}")
# ── Reclaim stale ─────────────────────────────────────────────────────────────
def reclaim_stale(state):
    """Return shards stuck in "claimed" longer than WORKER_TIMEOUT to "pending"."""
    now = time.time()
    # First pass: collect stale claims; second pass: reset them.
    stale = []
    for shard_name, info in state["shards"].items():
        if info["status"] != "claimed":
            continue
        claimed_at = info.get("claimed_at")
        if claimed_at and now - claimed_at > WORKER_TIMEOUT:
            stale.append((shard_name, info))
    for shard_name, info in stale:
        print(f" ⚠ Reclaiming: {shard_name}")
        info["status"] = "pending"
        info["worker"] = None
        info["claimed_at"] = None
    if stale:
        save_state(state)
# ── Parquet β†’ JSONL ───────────────────────────────────────────────────────────
def parquet_to_jsonl(parquet_path, jsonl_path, text_col):
    """Stream a parquet file batch by batch and write one JSON line per doc.

    Only *text_col* is read from the file. Rows whose value is None, empty,
    non-string, or whitespace-only are dropped; surviving text is stripped.
    Returns the number of lines written. Never loads the full table.
    """
    pf = pq.ParquetFile(parquet_path)
    n_written = 0
    with open(jsonl_path, "w", encoding="utf-8") as out:
        for batch in pf.iter_batches(batch_size=1_000, columns=[text_col]):
            for text in batch.column(text_col).to_pylist():
                if text and isinstance(text, str) and text.strip():
                    out.write(json.dumps({"text": text.strip()}, ensure_ascii=False) + "\n")
                    n_written += 1
    # One GC pass after the whole file instead of per 1 000-row batch: a full
    # collection every batch adds large constant overhead while the per-batch
    # list is already freed as soon as the iterator moves on.
    gc.collect()
    return n_written
# ── Download loop ─────────────────────────────────────────────────────────────
def download_loop(state):
    """Producer loop: download queue entries, convert them, publish shards.

    Each iteration:
      1. Re-reads STATE_FILE so claims made by external workers are seen.
      2. Reclaims stale claims; throttles when MAX_BUFFERED shards are pending.
      3. Downloads the queue head to a .tmp file, converts parquet β†’ jsonl
         (or renames a jsonl as-is), then records the shard as "pending"
         and saves state.
    Failures leave the entry at the queue head and retry after a 30s
    back-off. Returns only once every known shard is "done".
    """
    while True:
        # Refresh from disk: workers flip shard statuses in STATE_FILE.
        try:
            with open(STATE_FILE) as f:
                fresh = json.load(f)
            state["shards"] = fresh["shards"]
            state["queue"] = fresh.get("queue", [])
        except Exception:
            pass  # best-effort: keep the in-memory copy if the file is mid-write
        reclaim_stale(state)
        buffered = sum(1 for v in state["shards"].values() if v["status"] == "pending")
        if buffered >= MAX_BUFFERED:
            time.sleep(30)
            continue
        if not state["queue"]:
            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
            total = len(state["shards"])
            if done == total and total > 0:
                print("βœ“ All shards complete!")
                break
            print(" Queue empty β€” sleeping...")
            time.sleep(60)
            continue
        # Peek (don't pop) the head so a failure naturally retries it.
        entry = state["queue"][0]
        url = entry["url"]
        source = entry["source"]
        text_col = entry["text_col"]
        fmt = entry.get("fmt", "parquet")
        # Shard name: <source>__<file-stem>_<parent-dir>.jsonl; the parent dir
        # is the language folder for code URLs (split/data dir for HF parquet).
        lang = url.split("?")[0].split("/")[-2]
        base_name = url.split("?")[0].split("/")[-1].replace(".parquet", "").replace(".jsonl", "")
        shard_name = f"{source}__{base_name}_{lang}.jsonl"
        jsonl_path = Path(RAW_DIR) / shard_name
        tmp_path = Path(RAW_DIR) / f"{shard_name}.tmp"
        print(f" Downloading: {source} | {base_name}")
        try:
            # `with` releases the pooled connection even when raise_for_status
            # or a mid-stream error fires (the original leaked the response).
            with requests.get(url, headers=AUTH_HEADERS, timeout=300, stream=True) as resp:
                resp.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=8 * 1024 * 1024):
                        f.write(chunk)
        except Exception as e:
            print(f" βœ— Download failed: {e} β€” retrying in 30s")
            tmp_path.unlink(missing_ok=True)
            time.sleep(30)
            continue
        if fmt == "parquet":
            print(f" Converting β†’ jsonl: {shard_name}")
            try:
                n = parquet_to_jsonl(tmp_path, jsonl_path, text_col)
                tmp_path.unlink(missing_ok=True)
                print(f" βœ“ {n:,} docs")
            except Exception as e:
                print(f" βœ— Convert failed: {e}")
                tmp_path.unlink(missing_ok=True)
                jsonl_path.unlink(missing_ok=True)
                time.sleep(30)
                continue
        else:
            # Already JSONL: just move it into place.
            tmp_path.rename(jsonl_path)
        # Success: dequeue and register the shard for workers to claim.
        state["queue"].pop(0)
        state["shards"][shard_name] = {
            "status"    : "pending",
            "url"       : url,
            "source"    : source,
            "worker"    : None,
            "claimed_at": None,
            "error"     : None,
        }
        save_state(state)
        print(f" βœ“ Ready: {shard_name}")
        time.sleep(3)
# ── Monitor ───────────────────────────────────────────────────────────────────
def monitor_loop():
    """Every two minutes, re-read STATE_FILE and log overall + per-source progress."""
    while True:
        time.sleep(120)
        try:
            with open(STATE_FILE) as fh:
                snap = json.load(fh)
            shards = snap["shards"]
            queue = snap.get("queue", [])
            statuses = [v["status"] for v in shards.values()]
            done = statuses.count("done")
            claimed = statuses.count("claimed")
            pending = statuses.count("pending")
            total = len(shards) + len(queue)
            pct = done / total * 100 if total else 0
            # Per-source completed counts.
            per_source = {}
            for info in shards.values():
                if info["status"] == "done":
                    key = info.get("source", "?")
                    per_source[key] = per_source.get(key, 0) + 1
            print(f"[MONITOR] {done}/{total} ({pct:.1f}%) | {claimed} active | {pending} buffered | {len(queue)} queued")
            for key in sorted(per_source):
                print(f" {key}: {per_source[key]} done")
        except Exception:
            pass  # monitoring is best-effort; never crash the thread
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Start the health-check server first so the platform marks the app as up.
    threading.Thread(target=serve, daemon=True).start()
    state = load_state()
    # Discover runs synchronously once so the queue is populated before the
    # download thread starts.
    discover_all(state)
    threading.Thread(target=monitor_loop, daemon=True).start()
    threading.Thread(target=download_loop, args=(state,), daemon=True).start()
    # All workers are daemon threads; keep the main thread alive forever.
    while True:
        time.sleep(60)