Neon-tech committed on
Commit 69cda21 · verified · 1 Parent(s): 541a74a

Update app.py

Files changed (1):
  1. app.py +10 -4
app.py CHANGED
@@ -7,6 +7,7 @@ import gc
 import ctypes
 import multiprocessing as mp
 from pathlib import Path
+import numpy as np
 import pyarrow.parquet as pq
 from tokenizers import Tokenizer
 
@@ -72,8 +73,9 @@ def tokenize_chunk(texts):
 def process_shard(name, raw_path, pool):
     print(f" [{WORKER_ID}] Processing: {name}")
 
-    out_name = name.replace(".parquet", ".jsonl")
+    out_name = name.replace(".parquet", ".bin")
     out_path = Path(OUT_DIR) / out_name
+    tmp_path = Path(OUT_DIR) / f"{out_name}.tmp"
     total_tokens = 0
 
     try:
@@ -83,7 +85,7 @@ def process_shard(name, raw_path, pool):
         return False, f"read_failed: {e}"
 
     try:
-        with open(out_path, "w", encoding="utf-8") as f:
+        with open(tmp_path, "wb") as f:
             for batch in pf.iter_batches(batch_size=5_000, columns=["text"]):
                 texts = batch.column("text").to_pylist()
                 mid = len(texts) // 2
@@ -91,18 +93,22 @@ def process_shard(name, raw_path, pool):
                 try:
                     results = pool.map(tokenize_chunk, [texts[:mid], texts[mid:]])
                 except Exception as e:
+                    tmp_path.unlink(missing_ok=True)
                     return False, f"tokenize_failed: {e}"
 
                 for ids in results[0] + results[1]:
-                    f.write(json.dumps({"input_ids": ids}) + "\n")
+                    arr = np.array(ids, dtype=np.uint16)
+                    arr.tofile(f)
                     total_tokens += len(ids)
 
                 del texts, results
                 gc.collect()
 
     except Exception as e:
+        tmp_path.unlink(missing_ok=True)
         return False, f"write_failed: {e}"
 
+    tmp_path.rename(out_path)  # ← atomic, only visible when complete
     print(f" ✓ [{WORKER_ID}] {out_name} | {total_tokens:,} tokens")
     return True, None
 
@@ -141,7 +147,7 @@ def worker_loop():
 
         total = len(state["shards"]) + len(state.get("queue", []))
         done = sum(1 for v in state["shards"].values() if v["status"] == "done")
-        if done == len(state["shards"]) and not state.get("queue") and total > 0:
+        if total > 0 and done == total:
             print(f" [{WORKER_ID}] All done. Sleeping.")
             time.sleep(300)
             continue
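
Note on the new output format: the switch from JSONL to raw uint16 means each shard is now a flat binary stream of token ids with no per-record framing. A minimal sketch of how a downstream loader could read one of the new shards back, assuming all token ids fit in uint16 (vocabulary smaller than 65,536) and using a hypothetical shard name:

import numpy as np

# Load an entire shard as a flat array of token ids.
# "shard-00000.bin" is a hypothetical example; any .bin in OUT_DIR works.
ids = np.fromfile("shard-00000.bin", dtype=np.uint16)

# For large shards, memory-mapping avoids reading everything up front.
ids_mm = np.memmap("shard-00000.bin", dtype=np.uint16, mode="r")

print(len(ids), ids[:10])

Because the one-object-per-line JSONL framing is gone, per-document boundaries are no longer recorded; a consumer must treat each shard as one contiguous token stream. The .tmp-then-rename pattern in the diff keeps a crashed worker from leaving a partial .bin visible to readers; rename is atomic only when source and destination sit on the same filesystem, which holds here since both paths live in OUT_DIR.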