JustinLin610
update
10b0761
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
from libcpp cimport bool as bool_t
ctypedef int64_t DTYPE_t
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
    np.ndarray[int64_t, ndim=1] indices,
    np.ndarray[int64_t, ndim=1] num_tokens_vec,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    """Split ``indices`` into consecutive batches that respect size limits.

    The cost of a batch is ``number of sentences * largest sentence size``
    in that batch.

    Args:
        indices: 1-D int64 array of sample indices; the order is preserved.
        num_tokens_vec: 1-D int64 array read position by position alongside
            ``indices`` (``num_tokens_vec[i]`` is the size used for
            ``indices[i]``).
        max_tokens: upper bound on the batch cost; values <= 0 disable it.
        max_sentences: upper bound on sentences per batch; values <= 0
            disable it.
        bsz_mult: a running batch is only extended when its new size is
            smaller than ``bsz_mult`` or a multiple of it.

    Returns:
        A list of arrays produced by ``np.split(indices, ...)``, or an empty
        list when ``indices`` is empty.

    Raises:
        AssertionError: if ``max_tokens > 0`` and some entry of
            ``num_tokens_vec`` exceeds it.
        ValueError: if ``bsz_mult`` is 0 (it is used as a modulus and C
            division semantics are enabled, so 0 would crash the process).
    """
    if indices.shape[0] == 0:
        return []

    if bsz_mult == 0:
        raise ValueError("bsz_mult must be non-zero")

    assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
        f"Sentences lengths should not exceed max_tokens={max_tokens}"
    )

    cdef int32_t indices_len = indices.shape[0]
    # batches_ends[k] is the exclusive end position of the k-th batch.
    cdef np.ndarray[int32_t, ndim=1] batches_ends = \
        np.zeros(indices_len, dtype=np.int32)
    cdef int32_t[:] batches_ends_view = batches_ends
    cdef int64_t[:] num_tokens_view = num_tokens_vec

    cdef int32_t pos = 0
    cdef int32_t new_batch_end = 0

    cdef int64_t new_batch_max_tokens = 0
    cdef int32_t new_batch_sentences = 0
    cdef int64_t new_batch_num_tokens = 0

    cdef bool_t overflow = False
    cdef bool_t size_matches_with_bsz_mult = False

    cdef int32_t batches_count = 0
    cdef int32_t batch_start = 0
    cdef int64_t tail_max_tokens = 0
    cdef int64_t batch_max_tokens = 0

    # Typed explicitly so the overflow branch stays in C as well.
    cdef int64_t tail_num_tokens = 0
    cdef bool_t tail_overflow = False

    for pos in range(indices_len):
        # At every pos we keep stats about the last complete batch
        # [batch_start:batch_end) and the tail [batch_end:pos].
        # 1) Whenever (batch + tail) forms a valid batch (according to
        #    max_tokens, max_sentences and bsz_mult) the tail is appended
        #    to the batch.
        # 2) When (batch + tail) violates max_tokens or max_sentences the
        #    running batch is finalized and the tail becomes a new batch.
        # 3) Corner case: the tail itself violates max_tokens. Then
        #    [batch_end:pos-1] (the tail without the current pos) is also
        #    finalized, and [pos:pos] becomes the new tail.
        #
        # Important: for the sake of performance avoid function calls
        # within this loop.
        tail_max_tokens = tail_max_tokens \
            if tail_max_tokens > num_tokens_view[pos] \
            else num_tokens_view[pos]
        new_batch_end = pos + 1
        new_batch_max_tokens = batch_max_tokens \
            if batch_max_tokens > tail_max_tokens \
            else tail_max_tokens
        new_batch_sentences = new_batch_end - batch_start
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens
        # Chained comparisons: a limit only applies when it is positive.
        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
            tail_num_tokens = tail_max_tokens * \
                (new_batch_end - batches_ends_view[batches_count])
            tail_overflow = tail_num_tokens > max_tokens > 0
            # In case of a tail overflow finalize two batches
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0

    # A non-empty trailing tail becomes the final batch.
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1
    # Memory and time-efficient split
    return np.split(indices, batches_ends[:batches_count])
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    """Batch ``indices`` using sizes obtained from a Python callback.

    Calls ``num_tokens_fn`` once per entry of ``indices`` to build an int64
    size vector, then delegates to ``batch_by_size_vec``.

    Args:
        indices: 1-D int64 array of sample indices.
        num_tokens_fn: callable invoked with a single index; its result is
            stored as an int64 size.
        max_tokens: forwarded unchanged to ``batch_by_size_vec``.
        max_sentences: forwarded unchanged to ``batch_by_size_vec``.
        bsz_mult: forwarded unchanged to ``batch_by_size_vec``.

    Returns:
        Whatever ``batch_by_size_vec`` returns for the computed sizes.
    """
    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(
        indices_len, dtype=np.int64
    )
    cdef DTYPE_t[:] indices_view = indices
    # Writes go through this view; it shares memory with num_tokens_vec.
    cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
    cdef int32_t pos

    for pos in range(indices_len):
        num_tokens_vec_view[pos] = num_tokens_fn(indices_view[pos])

    return batch_by_size_vec(
        indices, num_tokens_vec, max_tokens, max_sentences, bsz_mult,
    )
cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return the index of the first valid shape, or -1 if none is found.

    A row ``i`` of ``shapes_view`` is valid when
    ``num_sentences <= shapes_view[i, 0]`` and
    ``num_tokens <= shapes_view[i, 1]``. Rows are scanned in order.
    """
    cdef Py_ssize_t i
    for i in range(shapes_view.shape[0]):
        # Index both axes at once to avoid creating a 1-D slice per row.
        if num_sentences <= shapes_view[i, 0] and num_tokens <= shapes_view[i, 1]:
            return i
    return -1
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    """Greedily group ``indices`` into batches that fit one of the fixed shapes.

    Samples are consumed in order. A sample joins the running batch while
    ``_find_valid_shape`` finds a row of ``fixed_shapes_sorted`` whose first
    column is >= the batch size and whose second column is >= the longest
    sample in the batch. Otherwise the running batch is closed and the
    sample opens a new one.

    Args:
        indices: 1-D int64 array of sample indices.
        num_tokens_fn: callable invoked with a single index; its result is
            used as that sample's length.
        fixed_shapes_sorted: 2-D int64 array of candidate shapes; rows are
            tried in order and the first fitting row wins.

    Returns:
        A list of non-empty batches, each a list of indices. A sample that
        fits no shape on its own ends up alone in its batch.
    """
    cdef int64_t sample_len = 0
    cdef list batch = []
    cdef list batches = []
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef int64_t shape_idx
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(indices_view.shape[0]):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # The current sample does not fit: close the running batch
            # (never emit an empty one) and start over with this sample.
            if len(batch) > 0:
                batches.append(batch)
            batch = []
            # The new batch's longest sample is the current one; resetting
            # to 0 here would hide its length from later shape checks.
            sample_len = num_tokens
            shapes_view = fixed_shapes_sorted
            shape_idx = _find_valid_shape(shapes_view, 1, sample_len)

        if shape_idx > 0:
            # Small optimization for the next call to _find_valid_shape:
            # batch size and max length only grow within a batch, so rows
            # that were rejected stay rejected.
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches