File size: 1,496 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import random
from typing import Callable, Union, List
from ding.data.buffer import BufferedData
from ding.utils import fastcopy
def padding(policy="random"):
    """
    Overview:
        Build a buffer middleware that pads grouped sample results so that every
        group has as many entries as the largest group.
        With the default policy `random`, each missing entry is a copy (made with
        ``fastcopy.copy``) of an entry picked at random from the same group; entries
        appended earlier during padding can be picked again.
        With the policy `none`, each missing entry is
        ``BufferedData(data=None, index=None, meta=None)``.
        Results that are empty, or whose first element is already a ``BufferedData``
        (i.e. not grouped), are returned unchanged. Actions other than ``"sample"``
        are forwarded to ``chain`` untouched.
    Arguments:
        - policy (:obj:`str`): Padding policy, supports `random`, `none`.
    Returns:
        - middleware (:obj:`Callable`): A callable with the signature
          ``(action, chain, *args, **kwargs)`` that returns whatever ``chain`` returns,
          padded when ``action == "sample"``.
    Raises:
        - ValueError: If ``policy`` is neither `random` nor `none`.
    """
    # Fail fast: an unknown policy previously fell through both branches and
    # returned unpadded groups without any error.
    if policy not in ("random", "none"):
        raise ValueError(
            "Unsupported padding policy {!r}, expected 'random' or 'none'.".format(policy)
        )

    def sample(chain: Callable, *args, **kwargs) -> "Union[List[BufferedData], List[List[BufferedData]]]":
        sampled_data = chain(*args, **kwargs)
        # Nothing to equalize for an empty result or a flat list of BufferedData.
        if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
            return sampled_data

        max_len = max(len(group) for group in sampled_data)
        for group in sampled_data:
            # The group lists returned by ``chain`` are extended in place.
            for _ in range(max_len - len(group)):
                if policy == "random":
                    # NOTE: random.choice raises IndexError if the group is empty.
                    filler = fastcopy.copy(random.choice(group))
                else:  # policy == "none" (validated above)
                    filler = BufferedData(data=None, index=None, meta=None)
                group.append(filler)
        return sampled_data

    def _padding(action: str, chain: Callable, *args, **kwargs):
        # Only sampling is intercepted; every other action passes straight through.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _padding
|