File size: 1,496 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import random
from typing import Callable, Union, List

from ding.data.buffer import BufferedData
from ding.utils import fastcopy


def padding(policy: str = "random") -> Callable:
    """
    Overview:
        Fill the nested buffer list to the same size as the largest list.
        The default policy `random` randomly selects data (with replacement) from the original members
        of the group being padded and appends copies of it until the group reaches the largest size.
        The policy `none` appends placeholder ``BufferedData(data=None, index=None, meta=None)`` items instead.
        Flat (ungrouped) or empty sample results are returned unchanged.
    Arguments:
        - policy (:obj:`str`): Padding policy, supports `random`, `none`.
    Returns:
        - _padding (:obj:`Callable`): Middleware with signature ``(action, chain, *args, **kwargs)``; it pads \
            the result of the ``sample`` action and passes every other action straight through to ``chain``.
    Raises:
        - ValueError: If ``policy`` is not one of the supported policies.
    """
    supported_policies = ("random", "none")
    # Fail fast: an unknown policy would otherwise silently leave groups unpadded.
    if policy not in supported_policies:
        raise ValueError(f"Unsupported padding policy {policy!r}, expected one of {supported_policies}")

    # String annotation so the project type is not evaluated each time ``padding`` is called.
    def sample(chain: Callable, *args, **kwargs) -> "Union[List[BufferedData], List[List[BufferedData]]]":
        sampled_data = chain(*args, **kwargs)
        # An empty or flat (ungrouped) result has nothing to pad.
        if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
            return sampled_data
        max_len = max(len(group) for group in sampled_data)
        for i, grouped_data in enumerate(sampled_data):
            missing = max_len - len(grouped_data)
            if missing <= 0:
                continue
            if policy == "random":
                if not grouped_data:
                    raise IndexError(f"Cannot pad group {i} with policy 'random': the group is empty")
                # Snapshot the original members so padded copies are never re-sampled themselves,
                # which would bias the padding towards items that were already duplicated.
                population = list(grouped_data)
                fill = [fastcopy.copy(random.choice(population)) for _ in range(missing)]
            else:  # policy == "none" (validated above)
                fill = [BufferedData(data=None, index=None, meta=None) for _ in range(missing)]
            # Pad the group list in place; the same outer list object is returned.
            grouped_data.extend(fill)
        return sampled_data

    def _padding(action: str, chain: Callable, *args, **kwargs):
        # Only ``sample`` results are padded; every other action passes through untouched.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _padding