nbroad (HF staff) committed
Commit: 9f6b9a6
Parent: d4479f1

store locally

Files changed (1): utils.py (+49 -44)
utils.py CHANGED
@@ -19,7 +19,6 @@ from optimum.onnxruntime import (
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-
 opt_configs = {
     "O2": AutoOptimizationConfig.O2(),
     "O3": AutoOptimizationConfig.O3(),
@@ -108,7 +107,8 @@ def load_hf_dataset(ds_name: str, ds_config: str = None, ds_split: str = "train"
     if ds_config == "":
         ds_config = None
 
-    ds = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
+    # streaming disabled so the dataset is downloaded and cached locally
+    ds = load_dataset(ds_name, ds_config, split=ds_split)
 
     return ds
 
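
Note: dropping streaming=True makes load_dataset download and cache the split as a regular map-style Dataset, which has a length and supports random access, whereas the streaming variant returns an IterableDataset that is harder to use with a multi-worker DataLoader. A minimal sketch of the difference (the dataset name is illustrative):

from datasets import load_dataset

ds_local = load_dataset("imdb", split="train")                   # map-style: cached on disk, len() and ds[i] work
ds_stream = load_dataset("imdb", split="train", streaming=True)  # IterableDataset: lazy, no __len__

print(len(ds_local))  # fine
# len(ds_stream)      # TypeError: IterableDataset has no length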
 
@@ -212,22 +212,34 @@ def tokenize(
     )
 
 
-def collate_fn(examples, tokenizer=None, padding=None, device=None):
+def collate_fn(examples, tokenizer=None, padding=None, column_name="text"):
     try:
         keys = examples[0].keys()
     except KeyError:
         print(examples)
     else:
         batch = {k: [] for k in examples[0].keys()}
 
-    for example in examples:
-        for k, v in example.items():
-            batch[k].append(v)
-
-    return {
-        k: torch.tensor(v, dtype=torch.long) if k in {"attention_mask", "input_ids"} else v for k, v in batch.items()
-    }
+    # Tokenize on the fly inside the DataLoader instead of a prior ds.map pass
+    tokenized = tokenizer(
+        [x[column_name] for x in examples],
+        truncation=True,
+        padding=padding,
+        max_length=512,
+        return_tensors="pt",
+    )
+
+    # Keep the raw text so it can be uploaded alongside the embeddings
+    tokenized[column_name] = [x[column_name] for x in examples]
+
+    return tokenized
 
 
 @torch.inference_mode()
 def batch_embed(
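
The rewritten collate_fn tokenizes each batch as the DataLoader fetches it, rather than relying on a one-off ds.map pass, and carries the raw strings through under column_name. A quick standalone check, assuming it is run next to utils.py (the model name is illustrative):

from transformers import AutoTokenizer
from utils import collate_fn

tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
examples = [{"text": "first example"}, {"text": "a second, slightly longer example"}]

batch = collate_fn(examples, tokenizer=tok, padding=True, column_name="text")
print(batch["input_ids"].shape)  # padded to the longest item in the batch
print(batch["text"])             # original strings, kept so they can be saved with the embeddings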
@@ -293,16 +305,16 @@ def batch_embed(
 
     repo = init_git_repo(new_dataset_id)
 
-    ds = ds.map(
-        tokenize,
-        batched=True,
-        batch_size=map_batch_size,
-        fn_kwargs={
-            "tokenizer": tokenizer,
-            "column_name": column_name,
-            "padding": "max_length" if opt_level == "O4" else True,
-        },
-    )
+    # Pre-tokenization with ds.map removed; collate_fn now tokenizes per batch
 
     embeds = []
     texts = []
@@ -327,10 +339,15 @@ def batch_embed(
         ds,
         batch_size=inference_bs,
         shuffle=False,
-        num_workers=1,
+        num_workers=2,
         pin_memory=True,
         drop_last=False,
-        collate_fn=partial(collate_fn, device=device)
+        collate_fn=partial(
+            collate_fn,
+            column_name=column_name,
+            tokenizer=tokenizer,
+            padding="max_length" if opt_level == "O4" else True
+        )
     ):
         ids = batch["input_ids"].to(device)
         mask = batch["attention_mask"].to(device)
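
With num_workers=2, each worker process runs collate_fn (and therefore the tokenizer) on its own batches; the TOKENIZERS_PARALLELISM=false set at the top of the file keeps the Rust tokenizers from warning after the fork. A toy end-to-end run of this wiring, under the same illustrative assumptions as above:

from functools import partial
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from utils import collate_fn

if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    rows = [{"text": f"example number {i}"} for i in range(8)]  # stand-in for the HF dataset

    loader = DataLoader(
        rows,
        batch_size=4,
        num_workers=2,
        collate_fn=partial(collate_fn, tokenizer=tok, padding=True, column_name="text"),
    )
    for batch in loader:
        print(batch["input_ids"].shape, len(batch["text"]))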
@@ -354,7 +371,7 @@ def batch_embed(
 
         # Periodically upload to the hub
         if len(embeds) > upload_batch_size:
-            push_to_repo(repo, last_count, current_count, embeds, texts, api)
+            push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)
             embeds = []
             texts = []
             last_count = current_count
@@ -372,7 +389,7 @@ def batch_embed(
 
     # If there are any remaining embeddings, upload them
     if len(embeds) > 0:
-        push_to_repo(repo, last_count, current_count, embeds, texts, api)
+        push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)
 
     return current_count - num2skip, time_taken
 
 
@@ -472,27 +489,15 @@ def push_to_repo(
     files = sorted(list(data_dir.glob("*.parquet")))
 
 
-    if len(files) == 1:
-        api.upload_folder(
-            folder_path=str(data_dir),
-            repo_id=repo_id,
-            repo_type="dataset",
-            run_as_future=True,
-            token=os.environ["HF_TOKEN"],
-            commit_message=f"Embedded examples {last_count} thru {current_count} with folder",
-        )
-
-    else:
-
-        api.upload_file(
-            path_or_fileobj=filepath,
-            path_in_repo=f"data/{filename}",
-            repo_id=repo_id,
-            repo_type="dataset",
-            run_as_future=True,
-            token=os.environ["HF_TOKEN"],
-            commit_message=f"Embedded examples {last_count} thru {current_count}",
-        )
+    api.upload_file(
+        path_or_fileobj=filepath,
+        path_in_repo=f"data/{filename}",
+        repo_id=repo_id,
+        repo_type="dataset",
+        run_as_future=True,
+        token=os.environ["HF_TOKEN"],
+        commit_message=f"Embedded examples {last_count} thru {current_count}",
+    )
 
 
     # Delete old files
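
Every chunk now goes through upload_file; the single-file upload_folder branch is gone. Because run_as_future=True, huggingface_hub queues the commit on a background thread and returns a concurrent.futures.Future, so the embedding loop is not blocked on network I/O. Roughly, using the HfApi instance the file already passes around (repo id and filenames are illustrative):

future = api.upload_file(
    path_or_fileobj="data/embeddings_0_1000.parquet",
    path_in_repo="data/embeddings_0_1000.parquet",
    repo_id="user/my-embedded-dataset",
    repo_type="dataset",
    run_as_future=True,
    commit_message="Embedded examples 0 thru 1000",
)
# keep embedding; block only when the commit result is actually needed:
# commit_info = future.result()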
 