import sys
from typing import Iterable, Iterator, Literal
def flatten(iterable: Iterable, depth = sys.maxsize, return_type: Literal['list', 'generator'] = 'list') -> list | Iterable: | |
""" | |
Flatten a nested iterable up to a specified depth. | |
Args: | |
iterable (iterable): The iterable to be expanded. | |
depth (int, optional): The depth to which the iterable should be expanded. | |
Defaults to 1. | |
return_type (Literal['list', 'generator'], optional): The type of the return value. | |
Defaults to 'list'. | |
Yields: | |
The expanded elements. | |
""" | |
def expand(item, current_depth=0): | |
if current_depth == depth: | |
yield item | |
elif isinstance(item, (list, tuple, set)): | |
for sub_item in item: | |
yield from expand(sub_item, current_depth + 1) | |
else: | |
yield item | |
def generator(): | |
for item in iterable: | |
yield from expand(item) | |
if return_type == 'list': | |
return list(generator()) | |
return generator() |