File size: 400 Bytes
81d3845 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
"""
helper util to calculate dataset lengths
"""
import numpy as np
def get_dataset_lengths(dataset):
if "length" in dataset.data.column_names:
lengths = np.array(dataset.data.column("length"))
else:
lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
return lengths
|