| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from .types import ModelInput |
|
|
|
|
class StatefulBuffer:
    """A FIFO buffer of model inputs, bounded by total token count.

    The buffer's ``size`` is the total number of tokens it holds, measured as
    ``len(sample["input_ids"])`` summed over every stored sample. ``put``
    refuses to grow the buffer beyond ``max_buffer_size`` tokens. The buffer
    contents can be snapshotted with ``state_dict`` and restored with
    ``load_state_dict``.
    """

    def __init__(self, max_buffer_size: int = 1_000_000_000) -> None:
        """Create an empty buffer.

        Args:
            max_buffer_size: Maximum total number of tokens the buffer may hold.
        """
        self._buffer: list[ModelInput] = []
        # Running total of len(sample["input_ids"]) over self._buffer.
        self._buffer_size: int = 0
        self._max_buffer_size: int = max_buffer_size

    def __len__(self) -> int:
        """Return the number of samples currently stored."""
        return len(self._buffer)

    @property
    def size(self) -> int:
        """Total number of tokens (``input_ids`` entries) currently stored."""
        return self._buffer_size

    def put(self, samples: list[ModelInput]) -> None:
        """Append samples to the end of the buffer.

        Args:
            samples: Samples to add. Each must have an ``"input_ids"`` entry
                that supports ``len()``.

        Raises:
            ValueError: If adding the samples would push the total token count
                above ``max_buffer_size``. The buffer is left unchanged.
        """
        # Materialize once, so that a one-shot iterable is not exhausted by the
        # token count before it is stored (which would desync _buffer_size).
        samples = list(samples)
        num_tokens = sum(len(sample["input_ids"]) for sample in samples)
        if self._buffer_size + num_tokens > self._max_buffer_size:
            raise ValueError(f"Buffer size exceeds max buffer size {self._max_buffer_size}.")

        self._buffer.extend(samples)
        self._buffer_size += num_tokens

    def get(self, value: int) -> list[ModelInput]:
        """Remove and return up to ``value`` samples from the front of the buffer.

        Args:
            value: Number of samples to take. If it exceeds the number stored,
                every stored sample is returned.

        Returns:
            The removed samples, in insertion order.

        Raises:
            ValueError: If ``value`` is negative. Slice semantics would
                otherwise remove all but the last ``-value`` samples.
        """
        if value < 0:
            raise ValueError(f"Number of samples to get must be non-negative, got {value}.")
        samples = self._buffer[:value]
        self._buffer_size -= sum(len(sample["input_ids"]) for sample in samples)
        del self._buffer[:value]
        return samples

    def clear(self) -> None:
        """Remove all samples and reset the token count to zero."""
        self._buffer = []
        self._buffer_size = 0

    def state_dict(self) -> dict:
        """Return a snapshot of the buffer contents and token count.

        The sample list is a shallow copy. Later ``put``/``get`` calls therefore
        do not alter a snapshot that has already been taken.
        """
        return {
            "buffer": list(self._buffer),
            "buffer_size": self._buffer_size,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore contents from a dict produced by ``state_dict``.

        The sample list is copied, so consuming this buffer afterwards does not
        mutate the caller's ``state_dict``. ``max_buffer_size`` is not part of
        the state and is kept as configured.
        """
        self._buffer = list(state_dict["buffer"])
        self._buffer_size = state_dict["buffer_size"]
|
|