import numpy as np
from torch import stack


def rearranging_splits(datasets, num_training_samples):
    """Re-split every non-test split into a `split` part holding
    `num_training_samples` randomly chosen samples and an `extra_<split>`
    part holding the remainder, preserving the bucketed tensor layout."""
    new_datasets = {}
    data_splits = list(datasets.keys())

    # First pass: copy the test split untouched and allocate empty
    # bucket/tensor containers for every other split and its 'extra_' twin.
    for split in data_splits:
        if split == 'test':
            new_datasets['test'] = datasets['test']
        else:
            num_buckets = len(datasets[split][1])
            num_tensors = len(datasets[split][0][0])
            num_samples = sum(datasets[split][1])
            if num_samples < num_training_samples:
                print("num_training_samples (%d) should be smaller than the actual %s size (%d)"
                      % (num_training_samples, split, num_samples))
            new_datasets[split] = [[[[] for _ in range(num_tensors)] for _ in range(num_buckets)], []]
            new_datasets['extra_' + split] = [[[[] for _ in range(num_tensors)] for _ in range(num_buckets)], []]

    # Second pass: randomly assign every sample to either the trimmed split
    # or its 'extra_' counterpart.
    for split in data_splits:
        if split == 'test':
            continue
        curr_bucket_sizes = datasets[split][1]
        curr_samples = datasets[split][0]
        num_tensors = len(datasets[split][0][0])
        curr_num_samples = sum(curr_bucket_sizes)

        # Map a flat sample index to its (bucket, position-in-bucket) pair.
        sample_indices_in_buckets = {}
        i = 0
        for bucket_idx, bucket_size in enumerate(curr_bucket_sizes):
            for sample_idx in range(bucket_size):
                sample_indices_in_buckets[i] = (bucket_idx, sample_idx)
                i += 1

        # Shuffle the flat indices; the first num_training_samples stay in
        # the split, the rest go to 'extra_<split>'.
        rng = np.random.permutation(np.arange(curr_num_samples))
        sample_indices = {
            split: [sample_indices_in_buckets[key] for key in rng[:num_training_samples]],
            'extra_' + split: [sample_indices_in_buckets[key] for key in rng[num_training_samples:]],
        }

        # Make sure 'extra_<split>' is never empty: move the last selected sample
        # over, or duplicate the only sample if there is just one.
        if len(sample_indices['extra_' + split]) == 0:
            if len(sample_indices[split]) > 1:
                sample_indices['extra_' + split].append(sample_indices[split].pop(-1))
            else:
                sample_indices['extra_' + split].append(sample_indices[split][0])

        # Copy the selected per-sample tensor slices into the new containers.
        for key, indices in sample_indices.items():
            for bucket_idx, sample_idx in indices:
                curr_bucket = curr_samples[bucket_idx]
                for tensor_idx, tensor in enumerate(curr_bucket):
                    new_datasets[key][0][bucket_idx][tensor_idx].append(tensor[sample_idx])

    del datasets

    new_splits = [split for split in data_splits if split != 'test']
    new_splits += ['extra_' + split for split in data_splits if split != 'test']

    # Final pass: stack the collected slices back into tensors and record the
    # resulting bucket sizes; empty buckets are marked with the (1, 1) sentinel.
    for split in new_splits:
        for bucket_idx in range(num_buckets):
            for tensor_idx in range(num_tensors):
                if len(new_datasets[split][0][bucket_idx][tensor_idx]) > 0:
                    new_datasets[split][0][bucket_idx][tensor_idx] = stack(new_datasets[split][0][bucket_idx][tensor_idx])
                else:
                    new_datasets[split][0][bucket_idx] = (1, 1)
                    break
            # set lengths of buckets
            if new_datasets[split][0][bucket_idx] == (1, 1):
                new_datasets[split][1].append(0)
            else:
                new_datasets[split][1].append(len(new_datasets[split][0][bucket_idx][tensor_idx]))
    return new_datasets
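

# --- Hypothetical usage sketch (not part of the original module) ---
# Illustrates the input layout the function assumes: each split maps to
# [buckets, bucket_sizes], where every bucket is a list of tensors (one per
# tensor type) whose first dimension equals that bucket's size. The shapes
# and split names below are made up for illustration only.
import torch

if __name__ == '__main__':
    train_buckets = [
        [torch.zeros(4, 10), torch.zeros(4, dtype=torch.long)],  # bucket 0: 4 samples
        [torch.zeros(6, 10), torch.zeros(6, dtype=torch.long)],  # bucket 1: 6 samples
    ]
    datasets = {
        'train': [train_buckets, [4, 6]],
        'test': [[[torch.zeros(2, 10), torch.zeros(2, dtype=torch.long)]], [2]],
    }
    new_datasets = rearranging_splits(datasets, num_training_samples=7)
    # 'train' now keeps 7 samples; 'extra_train' holds the remaining 3.
    print(new_datasets['train'][1], new_datasets['extra_train'][1])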