File size: 873 Bytes
a80d6bb
 
 
 
 
c74a070
a80d6bb
c74a070
a80d6bb
 
 
 
 
 
c74a070
 
a80d6bb
c74a070
 
 
a80d6bb
c74a070
a80d6bb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np


# --- PL-DATAMODULE ---


def get_local_split(items: list, world_size: int, rank: int, seed: int) -> np.ndarray:
    """Return the equally sized share of ``items`` assigned to ``rank``.

    The items are shuffled with a ``RandomState`` seeded by ``seed``. When the
    item count is not a multiple of ``world_size``, extra items are sampled
    (with replacement) from ``items`` and appended so the total divides evenly.
    The result is cut into ``world_size`` contiguous chunks and the chunk at
    index ``rank`` is returned. Everything is derived from ``seed`` alone, so
    calls that share ``items``, ``world_size`` and ``seed`` see the same padded
    ordering and differ only in which chunk they take.

    Args:
        items: Items to distribute.
        world_size: Number of chunks to split into; must be positive.
        rank: Index of the chunk to return; must satisfy
            ``0 <= rank < world_size``.
        seed: Seed for the shuffle and for the padding draw.

    Returns:
        A NumPy array holding this rank's items. Because of the padding, some
        items may appear more than once across (or within) chunks.

    Raises:
        ValueError: If ``world_size`` is not positive or ``rank`` is out of
            range.
    """
    # Fail loudly on a bad configuration: world_size == 0 would otherwise hit
    # a bare ZeroDivisionError, and an out-of-range rank would silently yield
    # an empty slice.
    if world_size <= 0:
        raise ValueError(f"world_size must be positive, got {world_size}")
    if not 0 <= rank < world_size:
        raise ValueError(
            f"rank must satisfy 0 <= rank < world_size, got rank={rank}, world_size={world_size}"
        )

    n_items = len(items)
    padded_items = np.random.RandomState(seed).permutation(items)

    remainder = n_items % world_size
    if remainder:
        # Top up with randomly re-drawn items so every chunk has the same
        # length. A fresh RandomState with the same seed keeps the draw
        # reproducible.
        padding = np.random.RandomState(seed).choice(
            items, world_size - remainder, replace=True
        )
        padded_items = np.concatenate([padded_items, padding])

    # With world_size > 0 the padded length is a multiple of world_size by
    # construction, so the integer division below is exact.
    n_per_rank = len(padded_items) // world_size
    start = n_per_rank * rank
    return padded_items[start : start + n_per_rank]