Spaces:
Runtime error
Runtime error
fix collator
Browse files
utils.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import time
|
3 |
import shutil
|
4 |
from pathlib import Path
|
|
|
5 |
from typing import Union, Dict, List
|
6 |
|
7 |
import torch
|
@@ -321,6 +322,7 @@ def batch_embed(
|
|
321 |
|
322 |
start_time = time.time()
|
323 |
|
|
|
324 |
for batch in DataLoader(
|
325 |
ds,
|
326 |
batch_size=inference_bs,
|
@@ -328,8 +330,8 @@ def batch_embed(
|
|
328 |
num_workers=1,
|
329 |
pin_memory=True,
|
330 |
drop_last=False,
|
|
|
331 |
):
|
332 |
-
batch = collate_fn(batch, device=device)
|
333 |
ids = batch["input_ids"]
|
334 |
mask = batch["attention_mask"]
|
335 |
|
|
|
2 |
import time
|
3 |
import shutil
|
4 |
from pathlib import Path
|
5 |
+
from functools import partial
|
6 |
from typing import Union, Dict, List
|
7 |
|
8 |
import torch
|
|
|
322 |
|
323 |
start_time = time.time()
|
324 |
|
325 |
+
|
326 |
for batch in DataLoader(
|
327 |
ds,
|
328 |
batch_size=inference_bs,
|
|
|
330 |
num_workers=1,
|
331 |
pin_memory=True,
|
332 |
drop_last=False,
|
333 |
+
collate_fn=partial(collate_fn, device=device)
|
334 |
):
|
|
|
335 |
ids = batch["input_ids"]
|
336 |
mask = batch["attention_mask"]
|
337 |
|