"""
This file defines a mixin class for sparse transformers that enables elastic memory management.
It provides functionality to dynamically adjust memory usage by controlling gradient checkpointing
across transformer blocks, allowing computation to be traded for memory efficiency.
"""
from contextlib import contextmanager
from typing import *
import math

from ..modules import sparse as sp
from ..utils.elastic_utils import ElasticModuleMixin


class SparseTransformerElasticMixin(ElasticModuleMixin):
"""
A mixin class for sparse transformers that provides elastic memory management capabilities.
Extends the base ElasticModuleMixin with sparse tensor-specific functionality.
"""
def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
"""
Determines the input size from a sparse tensor.
Args:
x: A SparseTensor input
*args, **kwargs: Additional arguments (unused)
Returns:
The size of the feature dimension of the sparse tensor
"""
return x.feats.shape[0]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """
        Context manager that temporarily adjusts memory usage by enabling gradient checkpointing
        for a portion of the transformer blocks, based on the specified memory ratio.

        Args:
            mem_ratio: A value between 0 and 1 indicating the desired memory ratio.
                1.0 means use all available memory (no checkpointing).
                Lower values enable checkpointing on more blocks to reduce memory usage.

        Yields:
            The exact memory ratio achievable at block granularity.
        """
        if mem_ratio == 1.0:
            # No memory optimization needed if the ratio is 1.0
            yield 1.0
            return
        # Calculate how many blocks should use checkpointing
        num_blocks = len(self.blocks)
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        # Calculate the actual memory ratio based on the number of checkpointed blocks
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
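        # Illustrative arithmetic (an assumed example, not from the original source): with
        # num_blocks = 12 and mem_ratio = 0.5, num_checkpoint_blocks = min(ceil(0.5 * 12) + 1, 12) = 7,
        # so exact_mem_ratio = 1 - (7 - 1) / 12 = 0.5, which is the value yielded below.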
        # Enable checkpointing for the calculated number of blocks
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
        yield exact_mem_ratio
        # Restore all blocks to not use checkpointing after context exit
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = False
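

# Usage sketch (illustrative, not part of the original module; the class name and input
# construction below are assumptions). A sparse transformer that mixes in
# SparseTransformerElasticMixin and exposes `self.blocks`, each with a `use_checkpoint`
# flag, could trade compute for activation memory during training roughly like this:
#
#     model = MySparseTransformer(...)               # hypothetical subclass using this mixin
#     with model.with_mem_ratio(0.5) as exact_ratio:
#         # exact_ratio is the ratio actually achievable at block granularity
#         out = model(sparse_input)                  # forward pass with part of the blocks checkpointed
#         loss = out.feats.sum()
#         loss.backward()                            # checkpointed blocks recompute activations here
#     # After the context exits, all blocks run without checkpointing again.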