Fucius's picture
Upload 422 files
2eafbc4 verified
from typing import Generator, Iterable, List, TypeVar, Union
# Generic element type shared by the batching helpers below: ties the type of
# the input elements to the type of the elements in the returned/yielded lists.
B = TypeVar("B")
def calculate_input_elements(input_value: Union[B, List[B]]) -> int:
    """Return how many elements *input_value* represents.

    A list (including instances of list subclasses) counts as
    ``len(input_value)`` elements. Any other value, including tuples,
    strings and ``None``, counts as exactly one element.

    Args:
        input_value: Either a single element or a list of elements.

    Returns:
        The list length when *input_value* is a list, otherwise ``1``.
    """
    # isinstance is the idiomatic type check and, like the previous
    # issubclass(type(...)) form, also accepts list subclasses.
    return len(input_value) if isinstance(input_value, list) else 1
def create_batches(
    sequence: Iterable[B], batch_size: int
) -> Generator[List[B], None, None]:
    """Split *sequence* into consecutive lists of at most *batch_size* items.

    Every batch except possibly the last holds exactly ``batch_size`` items.
    The final batch holds whatever remains. An empty *sequence* yields no
    batches. A *batch_size* below 1 is treated as 1.

    A batch is yielded as soon as it is full, so the source is never read
    further than the consumer has asked for. For lazy or slow iterables this
    means a full batch is not held back waiting for the next element, and it
    is not lost if fetching the next element raises.

    Args:
        sequence: Any iterable; it is iterated once, lazily.
        batch_size: Maximum number of items per batch (clamped to >= 1).

    Yields:
        A new list for each batch, in source order.
    """
    batch_size = max(batch_size, 1)
    current_batch: List[B] = []
    for element in sequence:
        current_batch.append(element)
        # Emit as soon as the batch is full rather than on the next element.
        if len(current_batch) >= batch_size:
            yield current_batch
            current_batch = []
    # Flush the trailing partial batch, if any.
    if current_batch:
        yield current_batch