Neon-tech committed on
Commit 1807b95 · verified · 1 Parent(s): 558cde6

Update app.py

Files changed (1)
  1. app.py +82 -68
app.py CHANGED
@@ -57,28 +57,31 @@ def claim_shard(state):
     return None, None
 
 # ── Tokenize chunk (subprocess) ───────────────────────────────────────────────
-def tokenize_chunk(args):
-    texts, tok_path = args
-    tokenizer = Tokenizer.from_file(tok_path)
+_worker_tokenizer = None
+
+def init_worker(tok_path):
+    global _worker_tokenizer
+    _worker_tokenizer = Tokenizer.from_file(tok_path)
+
+def tokenize_chunk(texts):
     results = []
     for text in texts:
         if not text or not text.strip():
             continue
-        enc = tokenizer.encode(text)
+        enc = _worker_tokenizer.encode(text)
         ids = enc.ids
         if len(ids) >= 2:
             results.append(ids)
     return results
 
 # ── Process shard using both cores ───────────────────────────────────────────
-def process_shard(name, raw_path):
+def process_shard(name, raw_path, pool):
     print(f" [{WORKER_ID}] Processing: {name}")
 
     try:
         df = pd.read_parquet(raw_path, columns=["text"])
     except Exception as e:
-        print(f" ✗ Read failed: {e}")
-        return False
+        return False, f"read_failed: {e}"
 
     total = len(df)
     print(f" [{WORKER_ID}] {total:,} rows — splitting across 2 cores")
@@ -89,34 +92,35 @@ def process_shard(name, raw_path):
     del df
     gc.collect()
 
-    with mp.Pool(processes=2) as pool:
-        results = pool.map(tokenize_chunk, [
-            (texts1, TOK_PATH),
-            (texts2, TOK_PATH),
-        ])
+    try:
+        results = pool.map(tokenize_chunk, [texts1, texts2])
+    except Exception as e:
+        return False, f"tokenize_failed: {e}"
 
     all_ids = results[0] + results[1]
     del results, texts1, texts2
     gc.collect()
 
     if not all_ids:
-        print(f" ✗ No tokens produced")
-        return False
+        return False, "no_tokens_produced"
 
     out_name = name.replace(".parquet", ".jsonl")
     out_path = Path(OUT_DIR) / out_name
     total_tokens = 0
 
-    with open(out_path, "w", encoding="utf-8") as f:
-        for ids in all_ids:
-            f.write(json.dumps({"input_ids": ids}) + "\n")
-            total_tokens += len(ids)
+    try:
+        with open(out_path, "w", encoding="utf-8") as f:
+            for ids in all_ids:
+                f.write(json.dumps({"input_ids": ids}) + "\n")
+                total_tokens += len(ids)
+    except Exception as e:
+        return False, f"write_failed: {e}"
 
     del all_ids
     gc.collect()
 
     print(f" ✓ [{WORKER_ID}] {out_name} | {total_tokens:,} tokens")
-    return True
+    return True, None
 
 # ── Worker loop ───────────────────────────────────────────────────────────────
 def worker_loop():
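process_shard now reports why it failed rather than just that it failed: every exit path returns a (success, error) pair, with the error string naming the failing stage (read, tokenize, write). A small sketch of the same convention, using a hypothetical export_rows helper:

def export_rows(rows, path):
    # Mirrors the (success, error) convention process_shard uses above:
    # every exit returns a pair the caller can branch on and persist.
    if not rows:
        return False, "no_rows"
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in rows)
    except OSError as e:
        return False, f"write_failed: {e}"
    return True, None

ok, err = export_rows(["a", "b"], "out.txt")  # illustrative file name
print("done" if ok else f"failed: {err}")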
@@ -126,58 +130,68 @@ def worker_loop():
     del tok
     gc.collect()
 
-    while True:
-        if not os.path.exists(STATE_FILE):
-            print(f" [{WORKER_ID}] Waiting for state.json...")
-            time.sleep(POLL_INTERVAL)
-            continue
+    pool = mp.Pool(processes=2, initializer=init_worker, initargs=(TOK_PATH,))
+    print(f"✓ [{WORKER_ID}] Worker pool ready")
 
-        try:
-            state = load_state()
-        except Exception as e:
-            print(f" [{WORKER_ID}] State read error: {e}")
-            time.sleep(POLL_INTERVAL)
-            continue
-
-        total = len(state["shards"]) + len(state.get("queue", []))
-        done = sum(1 for v in state["shards"].values() if v["status"] == "done")
-        if done == len(state["shards"]) and not state.get("queue") and total > 0:
-            print(f" [{WORKER_ID}] All done. Sleeping.")
-            time.sleep(300)
-            continue
-
-        name, raw_path = claim_shard(state)
-
-        if not name:
-            print(f" [{WORKER_ID}] Nothing ready — polling in {POLL_INTERVAL}s")
-            time.sleep(POLL_INTERVAL)
-            continue
-
-        print(f" [{WORKER_ID}] Claimed: {name}")
-        success = process_shard(name, raw_path)
-
-        try:
-            state = load_state()
-        except Exception:
-            pass
-
-        if success:
-            state["shards"][name]["status"] = "done"
-        else:
-            state["shards"][name]["status"] = "pending"
-            state["shards"][name]["worker"] = None
-            state["shards"][name]["claimed_at"] = None
-
-        save_state(state)
+    try:
+        while True:
+            if not os.path.exists(STATE_FILE):
+                print(f" [{WORKER_ID}] Waiting for state.json...")
+                time.sleep(POLL_INTERVAL)
+                continue
+
+            try:
+                state = load_state()
+            except Exception as e:
+                print(f" [{WORKER_ID}] State read error: {e}")
+                time.sleep(POLL_INTERVAL)
+                continue
+
+            total = len(state["shards"]) + len(state.get("queue", []))
+            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
+            if done == len(state["shards"]) and not state.get("queue") and total > 0:
+                print(f" [{WORKER_ID}] All done. Sleeping.")
+                time.sleep(300)
+                continue
+
+            name, raw_path = claim_shard(state)
+
+            if not name:
+                print(f" [{WORKER_ID}] Nothing ready — polling in {POLL_INTERVAL}s")
+                time.sleep(POLL_INTERVAL)
+                continue
+
+            print(f" [{WORKER_ID}] Claimed: {name}")
+            success, error = process_shard(name, raw_path, pool)
+
+            try:
+                state = load_state()
+            except Exception:
+                pass
+
+            if success:
+                state["shards"][name]["status"] = "done"
+                state["shards"][name]["error"] = None
+                save_state(state)
+                try:
+                    raw_path.unlink()
+                    print(f" [{WORKER_ID}] Deleted: {raw_path.name}")
+                except Exception as e:
+                    print(f" [{WORKER_ID}] Delete failed: {e}")
+            else:
+                state["shards"][name]["status"] = "pending"
+                state["shards"][name]["worker"] = None
+                state["shards"][name]["claimed_at"] = None
+                state["shards"][name]["error"] = error
+                save_state(state)
+                print(f" [{WORKER_ID}] Shard failed ({error}), left on disk for retry: {name}")
 
-        try:
-            raw_path.unlink()
-            print(f" [{WORKER_ID}] Deleted: {raw_path.name}")
-        except Exception as e:
-            print(f" [{WORKER_ID}] Delete failed: {e}")
+            gc.collect()
+            time.sleep(5)
 
-        gc.collect()
-        time.sleep(5)
+    finally:
+        pool.terminate()
+        pool.join()
 
 # ── Entry point ───────────────────────────────────────────────────────────────
 if __name__ == "__main__":
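The other structural change in this hunk: the pool is created once, up front, in worker_loop (with the initializer from the first hunk) and torn down in a finally block, instead of a new pool being spawned inside process_shard for every shard. The lifecycle, reduced to a skeleton with the polling body elided:

import multiprocessing as mp

def worker_loop_skeleton():
    # Hypothetical reduction of worker_loop above: one long-lived pool,
    # shut down even if the loop raises or is interrupted.
    pool = mp.Pool(processes=2)
    try:
        for task in range(3):  # stands in for the `while True:` polling loop
            print(pool.map(abs, [-task, task]))
    finally:
        pool.terminate()  # stop worker processes immediately
        pool.join()       # then wait for them to exit

if __name__ == "__main__":
    worker_loop_skeleton()

terminate() rather than close() is consistent with an intentionally infinite loop: there is no natural drain point, so on exit the workers are killed rather than allowed to finish queued tasks.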
 