# Provenance (from the file viewer this was copied from): commit 32580c1,
# "Vram fix attempt (#1164) [skip ci]", shown under the user name "winglian".
"""
helper util to calculate dataset lengths
"""
import numpy as np
def get_dataset_lengths(dataset):
    """Return the per-example sequence lengths of ``dataset`` as a NumPy array.

    The lengths are taken from the first of these columns that is listed in
    ``dataset.data.column_names``:

    1. ``"length"`` -- the column is converted to an array as-is.
    2. ``"position_ids"`` -- the length of a row is its last position id
       plus one.
    3. ``"input_ids"`` -- the length of a row is its number of elements.

    Args:
        dataset: An object whose ``data`` attribute exposes ``column_names``
            and ``column(name)``.
            NOTE(review): presumably a Hugging Face ``Dataset`` backed by an
            Arrow table -- confirm against the callers.

    Returns:
        numpy.ndarray: One entry per row of the selected column.
    """
    table = dataset.data
    available = table.column_names

    if "length" in available:
        return np.array(table.column("length"))

    if "position_ids" in available:
        # The last position id is a zero-based index, hence the +1.
        return np.array([ids[-1] + 1 for ids in table.column("position_ids")])

    # Fallback: measure every row of input_ids.
    rows = np.array(table.column("input_ids"), dtype=object)
    # Iterating `rows` yields one row per step whether the conversion produced
    # a 1-D array of sequences (ragged rows) or a 2-D array (equal-length
    # rows).  np.vectorize(len) raised a TypeError on the 2-D form and a
    # ValueError on an empty column; np.fromiter handles both.
    return np.fromiter(
        (len(row) for row in rows),
        dtype=np.int64,
        count=len(rows),
    )