from typing import Iterable, Literal import sys def flatten(iterable: Iterable, depth = sys.maxsize, return_type: Literal['list', 'generator'] = 'list') -> list | Iterable: """ Flatten a nested iterable up to a specified depth. Args: iterable (iterable): The iterable to be expanded. depth (int, optional): The depth to which the iterable should be expanded. Defaults to 1. return_type (Literal['list', 'generator'], optional): The type of the return value. Defaults to 'list'. Yields: The expanded elements. """ def expand(item, current_depth=0): if current_depth == depth: yield item elif isinstance(item, (list, tuple, set)): for sub_item in item: yield from expand(sub_item, current_depth + 1) else: yield item def generator(): for item in iterable: yield from expand(item) if return_type == 'list': return list(generator()) return generator()