nbroad HF staff commited on
Commit
df70302
1 Parent(s): fb266f1

fix collator

Browse files
Files changed (1) hide show
  1. utils.py +3 -1
utils.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import time
3
  import shutil
4
  from pathlib import Path
 
5
  from typing import Union, Dict, List
6
 
7
  import torch
@@ -321,6 +322,7 @@ def batch_embed(
321
 
322
  start_time = time.time()
323
 
 
324
  for batch in DataLoader(
325
  ds,
326
  batch_size=inference_bs,
@@ -328,8 +330,8 @@ def batch_embed(
328
  num_workers=1,
329
  pin_memory=True,
330
  drop_last=False,
 
331
  ):
332
- batch = collate_fn(batch, device=device)
333
  ids = batch["input_ids"]
334
  mask = batch["attention_mask"]
335
 
 
2
  import time
3
  import shutil
4
  from pathlib import Path
5
+ from functools import partial
6
  from typing import Union, Dict, List
7
 
8
  import torch
 
322
 
323
  start_time = time.time()
324
 
325
+
326
  for batch in DataLoader(
327
  ds,
328
  batch_size=inference_bs,
 
330
  num_workers=1,
331
  pin_memory=True,
332
  drop_last=False,
333
+ collate_fn=partial(collate_fn, device=device)
334
  ):
 
335
  ids = batch["input_ids"]
336
  mask = batch["attention_mask"]
337