# pylint: skip-file
"""
Multipack Batch Sampler
"""
import logging
import math
import os
from typing import Any, Iterable, List, Union
import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler
# Module-level logger; the sampler's length estimate reports through it.
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
    """
    Report whether first-fit-decreasing places every item of ``a`` into ``n`` bins.

    Items are visited from largest to smallest and each one goes into the
    lowest-indexed bin that still has enough room for it.
    https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing

    :param a: item sizes
    :param c: capacity of every bin
    :param n: number of bins available
    :return: ``True`` when all items were placed, ``False`` as soon as one
        item fits in none of the ``n`` bins
    """
    descending = np.sort(a)[::-1]
    # free space left in each bin; uses the same dtype as the item sizes
    remaining = np.full((n,), c, dtype=a.dtype)
    for item in descending:
        placed = False
        for slot in range(n):
            if remaining[slot] >= item:
                remaining[slot] -= item
                placed = True
                break
        if not placed:
            return False
    return True
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
    """
    First-fit-decreasing bin packing that returns the contents of each bin.

    Bins are opened on demand: an item that fits in no existing bin starts a
    new one with ``c - size`` free space left.

    :param a: item sizes
    :param c: capacity of every bin
    :param start_index: offset added to each position in ``a`` before it is
        recorded in the result
    :return: list of bins, each a list of ``position in a + start_index``
    """
    order = np.argsort(a)[::-1]
    sizes_desc = a[order]
    free_space = []
    packed = []
    for pos in range(len(sizes_desc)):
        item = sizes_desc[pos]
        item_index = order[pos] + start_index
        # look for the first already-open bin with enough room
        target = -1
        for slot in range(len(free_space)):
            if free_space[slot] >= item:
                target = slot
                break
        if target < 0:
            # nothing fits: open a new bin holding just this item
            free_space.append(c - item)
            packed.append([item_index])
        else:
            free_space[target] -= item
            packed[target].append(item_index)
    return packed
@numba.njit
def allocate(
    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
    """
    Dynamic batch allocator, similar to Multifit.

    https://en.wikipedia.org/wiki/Multifit_algorithm
    ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)

    Walks ``lengths`` front to back. Each round binary-searches for the
    longest prefix of the remaining items that ``ffd_check`` accepts for
    ``n`` bins of capacity ``c``, packs that prefix with ``ffd_with_result``
    and keeps only the bin at position ``rank``.

    :param lengths: item sizes
    :param lengths_cumsum: running sum of ``lengths``
    :param rank: which of the ``n`` bins of every round to keep
    :param c: capacity of every bin
    :param n: number of bins packed per round
    :return: tuple of (kept bins, value of ``lengths_cumsum`` at the last
        consumed item or 0 if nothing was consumed, ``len(kept bins) * c * n``)
    """
    consumed = 0
    cursor = 0
    kept = []
    while True:
        # binary search over the half-open range [lo, hi) of prefix lengths;
        # hi is one past the count of items whose running sum still fits in
        # the combined capacity c * n
        lo = 1
        hi = 1 + np.searchsorted(lengths_cumsum[cursor:], consumed + c * n, "right")
        while hi - lo > 1:
            probe = (lo + hi) // 2
            if ffd_check(lengths[cursor : cursor + probe], c, n):
                lo = probe
            else:
                hi = probe

        # pack the prefix of length lo
        packed = ffd_with_result(lengths[cursor : cursor + lo], c, cursor)
        assert len(packed) <= n
        if len(packed) < n:
            # fewer than n bins were produced: stop and drop this round
            break

        cursor += lo
        consumed = lengths_cumsum[cursor - 1]

        # keep only the bin belonging to this rank
        kept.append(packed[rank])

    return kept, consumed, len(kept) * c * n
class MultipackBatchSampler(BatchSampler):
    """
    Batch Sampler class for multipack

    Draws sample indices from the wrapped sampler, looks up their lengths and
    hands them to ``allocate`` with a bin capacity of ``batch_max_len``; every
    bin that comes back becomes one batch of dataset indices.
    """

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
        batch_max_len: int,
        lengths: np.ndarray,
        packing_efficiency_estimate: float = 1.0,
    ):
        """
        :param sampler: source of dataset indices
        :param batch_size: forwarded to ``BatchSampler.__init__`` only
        :param drop_last: forwarded to ``BatchSampler.__init__`` only
        :param batch_max_len: capacity of one packed batch
        :param lengths: array of per-sample lengths, indexed by dataset index
        :param packing_efficiency_estimate: divisor used by the length
            estimate; a falsy value (``None``, ``0``) falls back to ``1.0``
        """
        super().__init__(sampler, batch_size, drop_last)
        # NOTE(review): batch_size is cleared right after the base class has
        # accepted it; presumably because batches here are sized by length,
        # not by count - confirm against the consumers of this sampler.
        self.batch_size = None
        self.batch_max_len = batch_max_len
        self.lengths: np.ndarray = lengths
        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0

        assert isinstance(self.lengths, np.ndarray), "lengths must be a np.ndarray"

        self.epoch = 0

        # statistics, accumulated by generate_batches(set_stats=True)
        self.eff_total_used = 0
        self.eff_total_slots = 0

    def set_epoch(self, epoch: int) -> None:
        """Record the current epoch number."""
        self.epoch = epoch

    def generate_batches(self, set_stats: bool = False) -> List[List[int]]:
        """
        Pack one full pass over the sampler into batches.

        :param set_stats: when true, add this pass's used/available totals to
            the running efficiency counters
        :return: list of batches, each a list of dataset indices
        """
        indices = list(self.sampler)

        lengths = self.lengths[indices]
        lengths_cumsum = np.cumsum(lengths)

        batches, total_used, total_slots = allocate(
            lengths=lengths,
            lengths_cumsum=lengths_cumsum,
            rank=0,
            c=self.batch_max_len,
            n=1,
        )

        # allocate() returns positions within this pass; map them back to the
        # dataset indices the sampler produced
        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]

        # statistics
        if set_stats:
            self.eff_total_used += total_used
            self.eff_total_slots += total_slots

        return batches

    def __iter__(self):
        """Iterate over freshly packed batches (also updates the statistics)."""
        batches = self.generate_batches(set_stats=True)
        return iter(batches)

    def num_batches(self) -> int:
        """Pack one pass and return how many batches it produced."""
        batches = self.generate_batches(set_stats=True)
        return len(batches)

    def efficiency(self) -> float:
        """
        Return used / available capacity over all recorded packing passes.

        Returns ``0.0`` while no capacity has been recorded (no pass has run
        yet, or every pass produced zero batches) instead of dividing by zero.
        """
        if not self.eff_total_slots:
            return 0.0
        return self.eff_total_used / self.eff_total_slots

    def __len__(self) -> int:
        """Return the estimated number of batches."""
        # the count itself is discarded; the call is kept because it runs a
        # packing pass that updates the efficiency statistics
        self.num_batches()
        return self._len_est()

    def _len_est(self) -> int:
        """
        Estimate the batch count from total length, world size and the
        packing efficiency estimate, never returning less than zero.
        """
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        lengths_sum = np.sum(self.lengths)
        lengths_sum_per_device = lengths_sum // world_size
        LOG.info(
            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
            f"total_num_tokens per device: {lengths_sum_per_device}"
        )

        # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
        # evaluated left to right: ((0.99 * tokens) / efficiency) // batch_max_len
        return max(
            0,
            (
                world_size
                * math.floor(
                    0.99
                    * lengths_sum_per_device
                    / self.packing_efficiency_estimate
                    // self.batch_max_len
                )
                - 1
            ),
        )
 |