File size: 1,122 Bytes
9d4f942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Iterable, Literal
import sys


def flatten(iterable: Iterable, depth = sys.maxsize, return_type: Literal['list', 'generator'] = 'list') -> list | Iterable:
    """
    Flatten a nested iterable up to a specified depth.

    Args:
        iterable (iterable): The iterable to be expanded.
        depth (int, optional): The depth to which the iterable should be expanded.
                               Defaults to 1.
        return_type (Literal['list', 'generator'], optional): The type of the return value.
                                                              Defaults to 'list'.
    Yields:
        The expanded elements.
    """

    def expand(item, current_depth=0):
        if current_depth == depth:
            yield item
        elif isinstance(item, (list, tuple, set)):
            for sub_item in item:
                yield from expand(sub_item, current_depth + 1)
        else:
            yield item

    def generator():
        for item in iterable:
            yield from expand(item)
    
    if return_type == 'list':
        return list(generator())
    return generator()