|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
cimport cython |
|
cimport numpy as np |
|
|
|
from libc.stdint cimport int32_t, int64_t |
|
from libcpp cimport bool as bool_t |
|
|
|
ctypedef int64_t DTYPE_t |
|
|
|
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
    np.ndarray[int64_t, ndim=1] indices,
    np.ndarray[int64_t, ndim=1] num_tokens_vec,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    """Split ``indices`` into consecutive batches using precomputed lengths.

    A candidate batch is costed as ``number of samples * longest sample``.
    The array is scanned once, left to right, and cut whenever adding the
    next sample would break a limit.

    Args:
        indices: 1-D int64 array of sample indices, in the order to batch.
        num_tokens_vec: 1-D int64 array; entry ``i`` is the length of the
            sample at position ``i`` of ``indices``.
        max_tokens: upper bound on ``samples * longest sample`` per batch.
            Values <= 0 disable the limit.
        max_sentences: upper bound on samples per batch. Values <= 0
            disable the limit.
        bsz_mult: a batch boundary is only moved forward when the batch
            size is below ``bsz_mult`` or a multiple of it. Must be
            non-zero: the modulo runs under ``cdivision(True)``, so a zero
            divisor is not checked.

    Returns:
        The list produced by ``np.split(indices, boundaries)``; an empty
        list when ``indices`` is empty.

    Raises:
        AssertionError: if ``max_tokens`` is positive and some entry of
            ``num_tokens_vec`` exceeds it.
    """
    if indices.shape[0] == 0:
        return []

    assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
        f"Sentences lengths should not exceed max_tokens={max_tokens}"
    )

    cdef int32_t indices_len = indices.shape[0]
    # batches_ends[k] holds the exclusive end position of batch k.
    cdef np.ndarray[int32_t, ndim=1] batches_ends = \
        np.zeros(indices_len, dtype=np.int32)
    cdef int32_t[:] batches_ends_view = batches_ends
    cdef int64_t[:] num_tokens_view = num_tokens_vec

    cdef int32_t pos = 0
    cdef int32_t new_batch_end = 0

    cdef int64_t new_batch_max_tokens = 0
    cdef int32_t new_batch_sentences = 0
    cdef int64_t new_batch_num_tokens = 0

    cdef bool_t overflow = False
    cdef bool_t size_matches_with_bsz_mult = False

    cdef int32_t batches_count = 0
    cdef int32_t batch_start = 0
    cdef int64_t tail_max_tokens = 0
    cdef int64_t batch_max_tokens = 0

    # Declared with C types so the overflow branch of the hot loop does not
    # fall back to Python objects.
    cdef int64_t tail_num_tokens = 0
    cdef bool_t tail_overflow = False

    for pos in range(indices_len):
        # Loop state: the committed part of the running batch is
        # [batch_start, batches_ends[batches_count]); the "tail" is every
        # position after that boundary up to and including `pos`.
        # 1) When batch + tail satisfies the limits and the bsz_mult rule,
        #    the boundary is advanced so the tail joins the batch.
        # 2) When batch + tail breaks max_tokens or max_sentences, the
        #    running batch is finalized and the tail starts a new batch.
        # 3) When the tail alone also breaks max_tokens, the tail without
        #    `pos` is finalized as its own batch and `pos` starts a new one.
        #
        # Everything is kept inline (no function calls) for speed.

        # Longest sample seen in the tail so far.
        tail_max_tokens = (
            tail_max_tokens
            if tail_max_tokens > num_tokens_view[pos]
            else num_tokens_view[pos]
        )
        new_batch_end = pos + 1
        # Longest sample of batch + tail.
        new_batch_max_tokens = (
            batch_max_tokens
            if batch_max_tokens > tail_max_tokens
            else tail_max_tokens
        )
        new_batch_sentences = new_batch_end - batch_start
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens

        # The chained comparisons make a non-positive limit never trigger.
        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
            tail_num_tokens = tail_max_tokens * \
                (new_batch_end - batches_ends_view[batches_count])
            tail_overflow = tail_num_tokens > max_tokens > 0

            # Tail overflow: finalize two batches (the running batch and
            # the tail without `pos`).
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0

    # A trailing tail that was never committed still needs its own split
    # point at the last recorded boundary.
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1

    # Split once at the recorded boundaries; the final end is implicit.
    return np.split(indices, batches_ends[:batches_count])
|
|
|
|
|
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    """Batch ``indices`` by size, obtaining each length from a callback.

    Calls ``num_tokens_fn`` once per entry of ``indices`` to fill an int64
    length vector, then delegates to ``batch_by_size_vec``.

    Args:
        indices: 1-D array of sample indices, in the order to batch.
        num_tokens_fn: callable taking one index and returning a value
            convertible to int64 (the sample's length).
        max_tokens: forwarded unchanged to ``batch_by_size_vec``.
        max_sentences: forwarded unchanged to ``batch_by_size_vec``.
        bsz_mult: forwarded unchanged to ``batch_by_size_vec``.

    Returns:
        Whatever ``batch_by_size_vec`` returns for the computed lengths.
    """
    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(
        indices_len, dtype=np.int64
    )
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
    cdef int64_t pos

    for pos in range(indices_len):
        # The view shares memory with num_tokens_vec, so this fills the
        # array that is handed to batch_by_size_vec below.
        num_tokens_vec_view[pos] = num_tokens_fn(indices_view[pos])

    return batch_by_size_vec(
        indices, num_tokens_vec, max_tokens, max_sentences, bsz_mult,
    )
|
|
|
|
|
cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return index of first valid shape or -1 if none is found.

    A row of ``shapes_view`` is valid when its column 0 is at least
    ``num_sentences`` and its column 1 is at least ``num_tokens``. Rows are
    tested in order and the first match wins.
    """
    cdef Py_ssize_t i
    for i in range(shapes_view.shape[0]):
        # Direct 2-D indexing avoids building an intermediate row slice.
        if num_sentences <= shapes_view[i, 0] and num_tokens <= shapes_view[i, 1]:
            return i
    return -1
|
|
|
|
|
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    """Greedily group ``indices`` into batches that fit one of the shapes.

    Samples are consumed in order. A sample joins the running batch while
    ``_find_valid_shape`` finds a row of ``fixed_shapes_sorted`` whose
    column 0 is >= the resulting batch size and whose column 1 is >= the
    longest sample in the batch. Otherwise the running batch is closed and
    the sample starts a new one.

    Args:
        indices: 1-D array of sample indices, in the order to batch.
        num_tokens_fn: callable taking one index and returning a value
            convertible to int64 (the sample's length).
        fixed_shapes_sorted: 2-D array of candidate shapes, one per row,
            read as ``(sentences, tokens)``. NOTE(review): the name
            suggests the caller sorts the rows; the ordering is not
            checked here.

    Returns:
        A list of batches, each a non-empty list of Python ints. A sample
        that fits no shape on its own still ends up in a batch by itself.
    """
    cdef int64_t sample_len = 0
    cdef list batch = []
    cdef list batches = []
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef int64_t shape_idx
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        # Longest sample the batch would contain if `idx` were added.
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # No shape fits batch + idx: close the running batch. The
            # guard avoids emitting an empty batch when the very first
            # sample fits no shape.
            if batch:
                batches.append(batch)
            batch = []
            # `idx` becomes the first member of the new batch, so its
            # length must seed the running maximum (resetting to 0 would
            # hide it from later shape checks).
            sample_len = num_tokens
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # Rows before shape_idx failed for the current size and the
            # batch only grows, so skip them on the next lookup.
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches
|
|