Spaces:
Runtime error
Runtime error
fix batching v2
Browse files
utils.py
CHANGED
@@ -211,6 +211,17 @@ def tokenize(
|
|
211 |
)
|
212 |
|
213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
@torch.inference_mode()
|
215 |
def batch_embed(
|
216 |
ds: datasets.IterableDataset,
|
@@ -308,18 +319,20 @@ def batch_embed(
|
|
308 |
ds,
|
309 |
batch_size=inference_bs,
|
310 |
shuffle=False,
|
311 |
-
num_workers=
|
312 |
pin_memory=True,
|
313 |
drop_last=False,
|
314 |
):
|
315 |
-
|
316 |
-
|
|
|
|
|
317 |
t_ids = torch.zeros_like(ids)
|
318 |
|
319 |
outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
|
320 |
|
321 |
embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
|
322 |
-
texts.extend(
|
323 |
|
324 |
current_count += ids.shape[0]
|
325 |
|
|
|
211 |
)
|
212 |
|
213 |
|
214 |
+
def collate_fn(examples, tokenizer=None, padding=None, device=None):
    """Collate a list of example dicts into a single batch dict.

    Transposes a list-of-dicts into a dict-of-lists, then converts the
    model-input fields ("input_ids", "attention_mask") into long tensors,
    optionally placed on *device*. All other fields (e.g. raw text) are
    returned as plain Python lists.

    Args:
        examples: list of dicts that all share the same keys.
        tokenizer: unused; kept for interface compatibility with callers.
        padding: unused; kept for interface compatibility with callers.
        device: target device for the tensor fields, or None for CPU.

    Returns:
        dict mapping each key to a ``torch.LongTensor`` (tokenized fields)
        or a list (everything else). An empty input yields an empty dict.

    NOTE(review): assumes the tokenized sequences are already padded to a
    uniform length, otherwise ``torch.tensor`` on ragged lists fails —
    confirm against the tokenize step upstream.
    """
    # Guard: examples[0] below would raise IndexError on an empty batch.
    if not examples:
        return {}

    # Transpose list-of-dicts -> dict-of-lists.
    batch = {k: [] for k in examples[0]}
    for example in examples:
        for k, v in example.items():
            batch[k].append(v)

    # Only the model inputs become tensors; text/metadata stay as lists.
    tensor_keys = {"attention_mask", "input_ids"}
    return {
        k: torch.tensor(v, dtype=torch.long, device=device) if k in tensor_keys else v
        for k, v in batch.items()
    }
|
224 |
+
|
225 |
@torch.inference_mode()
|
226 |
def batch_embed(
|
227 |
ds: datasets.IterableDataset,
|
|
|
319 |
ds,
|
320 |
batch_size=inference_bs,
|
321 |
shuffle=False,
|
322 |
+
num_workers=1,
|
323 |
pin_memory=True,
|
324 |
drop_last=False,
|
325 |
):
|
326 |
+
batch = collate_fn(batch, device=device)
|
327 |
+
ids = batch["input_ids"]
|
328 |
+
mask = batch["attention_mask"]
|
329 |
+
|
330 |
t_ids = torch.zeros_like(ids)
|
331 |
|
332 |
outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
|
333 |
|
334 |
embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
|
335 |
+
texts.extend([b[column_name] for b in batch])
|
336 |
|
337 |
current_count += ids.shape[0]
|
338 |
|