File size: 1,122 Bytes
9d4f942 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from typing import Iterable, Literal
import sys
def flatten(iterable: Iterable, depth = sys.maxsize, return_type: Literal['list', 'generator'] = 'list') -> list | Iterable:
"""
Flatten a nested iterable up to a specified depth.
Args:
iterable (iterable): The iterable to be expanded.
depth (int, optional): The depth to which the iterable should be expanded.
Defaults to 1.
return_type (Literal['list', 'generator'], optional): The type of the return value.
Defaults to 'list'.
Yields:
The expanded elements.
"""
def expand(item, current_depth=0):
if current_depth == depth:
yield item
elif isinstance(item, (list, tuple, set)):
for sub_item in item:
yield from expand(sub_item, current_depth + 1)
else:
yield item
def generator():
for item in iterable:
yield from expand(item)
if return_type == 'list':
return list(generator())
return generator() |