nbroad (HF staff) committed
Commit 6c6d3ac
1 parent: af7a07a

fix batching v2

Files changed (1): utils.py (+17 -4)
utils.py CHANGED
@@ -211,6 +211,17 @@ def tokenize(
     )
 
 
+def collate_fn(examples, tokenizer=None, padding=None, device=None):
+    batch = {k: [] for k in examples[0].keys()}
+
+    for example in examples:
+        for k, v in example.items():
+            batch[k].append(v)
+
+    return {
+        k: torch.tensor(v, dtype=torch.long, device=device) if k in {"attention_mask", "input_ids"} else v for k, v in batch.items()
+    }
+
+
 @torch.inference_mode()
 def batch_embed(
     ds: datasets.IterableDataset,
@@ -308,18 +319,20 @@ def batch_embed(
         ds,
         batch_size=inference_bs,
         shuffle=False,
-        num_workers=2,
+        num_workers=1,
         pin_memory=True,
         drop_last=False,
     ):
-        ids = torch.tensor(batch["input_ids"], device=device)
-        mask = torch.tensor(batch["attention_mask"], device=device)
+        batch = collate_fn(batch, device=device)
+        ids = batch["input_ids"]
+        mask = batch["attention_mask"]
+
         t_ids = torch.zeros_like(ids)
 
         outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
 
         embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
-        texts.extend(batch[column_name])
+        texts.extend([b[column_name] for b in batch])
 
         current_count += ids.shape[0]
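
In short, v2 replaces the two per-batch torch.tensor calls with a single collate_fn that stacks each column across examples and tensorizes only input_ids and attention_mask, and it drops the DataLoader workers from 2 to 1. Below is a minimal, self-contained sketch of what the new collate step produces; the two-example toy batch and its column names are illustrative assumptions, not part of the commit, and how the surrounding DataLoader hands raw example lists to this function lies outside this hunk.

import torch

# collate_fn as added in this commit (tokenizer/padding are accepted but
# unused here): stack each column across examples, then tensorize only
# the two model-input columns.
def collate_fn(examples, tokenizer=None, padding=None, device=None):
    batch = {k: [] for k in examples[0].keys()}
    for example in examples:
        for k, v in example.items():
            batch[k].append(v)
    return {
        k: torch.tensor(v, dtype=torch.long, device=device)
        if k in {"attention_mask", "input_ids"}
        else v
        for k, v in batch.items()
    }

# Hypothetical batch of two already-padded examples plus a raw text column.
# Equal sequence lengths are required, since torch.tensor rejects ragged
# lists; presumably tokenize() pads to a fixed length upstream.
examples = [
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "text": "hello"},
    {"input_ids": [101, 2088, 102], "attention_mask": [1, 1, 1], "text": "world"},
]

batch = collate_fn(examples, device="cpu")
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["text"])             # ['hello', 'world'] -- still a plain list

Note that non-tensor columns come back as plain lists, so in this sketch the text column is read directly as batch[column_name]. Concentrating tensor construction in one collate step keeps the loop body to model inputs only and makes the host-to-device copy happen in exactly one place.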