| import contextlib |
| import dataclasses |
| from collections import defaultdict |
| from typing import DefaultDict, Dict |
| from pipeline import are_two_tensors_similar |
| import torch |
|
|
|
|
|
|
| @dataclasses.dataclass |
| class CacheContext: |
| buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict) |
| incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) |
|
|
| def get_incremental_name(self, name=None): |
| if name is None: |
| name = "default" |
| idx = self.incremental_name_counters[name] |
| self.incremental_name_counters[name] += 1 |
| return f"{name}_{idx}" |
|
|
| def reset_incremental_names(self): |
| self.incremental_name_counters.clear() |
|
|
| @torch.compiler.disable |
| def get_buffer(self, name): |
| return self.buffers.get(name) |
|
|
| @torch.compiler.disable |
| def set_buffer(self, name, buffer): |
| self.buffers[name] = buffer |
|
|
| def clear_buffers(self): |
| self.buffers.clear() |
|
|
|
|
@torch.compiler.disable
def get_buffer(name):
    """Read the buffer called `name` from the active cache context (which must be set)."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_buffer(name)
|
|
|
|
@torch.compiler.disable
def set_buffer(name, buffer):
    """Write `buffer` under `name` into the active cache context (which must be set)."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.set_buffer(name, buffer)
|
|
|
|
# Module-level holder for the active CacheContext; None means no context is installed.
_current_cache_context = None
|
|
|
|
def create_cache_context():
    """Build and return a fresh, empty CacheContext."""
    context = CacheContext()
    return context
|
|
|
|
def get_current_cache_context():
    """Return the currently installed CacheContext, or None if unset."""
    return _current_cache_context
|
|
|
|
def set_current_cache_context(cache_context=None):
    """Install `cache_context` (or None to clear) as the module-wide active context."""
    global _current_cache_context
    _current_cache_context = cache_context
|
|
|
|
@contextlib.contextmanager
def cache_context(cache_context):
    """Temporarily make `cache_context` the active context, restoring the previous one on exit."""
    global _current_cache_context
    previous = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        # Restore even if the body raised, so contexts nest cleanly.
        _current_cache_context = previous
|
|
|
|
| @torch.compiler.disable |
| def are_two_tensors_similar_old(t1, t2, *, threshold, parallelized=False): |
| mean_diff = (t1 - t2).abs().mean() |
| mean_t1 = t1.abs().mean() |
| diff = mean_diff / mean_t1 |
| return diff.item() < threshold |
|
|
|
|
@torch.compiler.disable
def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
    """Add the cached residuals onto the current states; both caches must exist."""
    cached_residual = get_buffer("hidden_states_residual")
    assert cached_residual is not None, "hidden_states_residual must be set before"
    cached_encoder_residual = get_buffer("encoder_hidden_states_residual")
    assert cached_encoder_residual is not None, "encoder_hidden_states_residual must be set before"

    # Apply residuals and force contiguous layouts for downstream consumers.
    hidden_states = (cached_residual + hidden_states).contiguous()
    encoder_hidden_states = (cached_encoder_residual + encoder_hidden_states).contiguous()

    return hidden_states, encoder_hidden_states
|
|
|
|
@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    """True iff a previous first-block residual is cached and similar enough to the current one."""
    prev_residual = get_buffer("first_hidden_states_residual")
    if prev_residual is None:
        # Nothing cached yet, e.g. the first step of a run.
        return False
    return are_two_tensors_similar(
        prev_residual,
        first_hidden_states_residual,
        threshold=threshold,
        parallelized=parallelized,
    )
|
|
|
|
class CachedTransformerBlocks(torch.nn.Module):
    """Runs a sequence of transformer blocks with first-block caching.

    The first block always runs. Its output residual is compared against the
    residual cached from a previous step; if the two are similar (within
    `residual_diff_threshold`), the remaining blocks are skipped and the
    cached output residuals are applied instead.
    """

    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        residual_diff_threshold,
        return_hidden_states_first=True,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        # Optional trailing blocks that take one concatenated [encoder, hidden] stream.
        self.single_transformer_blocks = single_transformer_blocks
        # Similarity threshold for cache reuse; <= 0 disables caching entirely.
        self.residual_diff_threshold = residual_diff_threshold
        # Whether wrapped blocks return (hidden, encoder) or (encoder, hidden).
        self.return_hidden_states_first = return_hidden_states_first

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        # Caching disabled: run every block unconditionally and return.
        if self.residual_diff_threshold <= 0.0:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    # Normalize block outputs to (hidden, encoder) order internally.
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                # Single-stream blocks consume [encoder, hidden] concatenated on dim 1.
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                # Keep only the hidden part; the encoder slice is discarded here.
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
            return (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )

        # Always run the first block; its residual decides whether the rest can be skipped.
        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states, encoder_hidden_states = first_transformer_block(
            hidden_states, encoder_hidden_states, *args, **kwargs
        )
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        first_hidden_states_residual = hidden_states - original_hidden_states
        # Free the pre-block tensor early to reduce peak memory.
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            # NOTE(review): `_is_parallelized` is an attribute expected on the wrapped
            # transformer when it exists — absent means not parallelized.
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        # Graph break: the branch below depends on runtime tensor values,
        # which torch.compile cannot trace through.
        torch._dynamo.graph_break()
        if can_use_cache:
            # Cache hit: skip the remaining blocks and replay the stored residuals.
            del first_hidden_states_residual
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            # Cache miss: store the new first-block residual, run the remaining
            # blocks, and cache their output residuals for the next step.
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        return (
            (hidden_states, encoder_hidden_states)
            if self.return_hidden_states_first
            else (encoder_hidden_states, hidden_states)
        )

    def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        """Run blocks[1:] (and any single-stream blocks); return the new states
        plus their residuals relative to the inputs."""
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        for block in self.transformer_blocks[1:]:
            hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
            if not self.return_hidden_states_first:
                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        if self.single_transformer_blocks is not None:
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            for block in self.single_transformer_blocks:
                hidden_states = block(hidden_states, *args, **kwargs)
            # Split the concatenated stream back into encoder and hidden parts.
            encoder_hidden_states, hidden_states = hidden_states.split(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )

        # Flatten/contiguous round-trip forces a contiguous layout in the original shape.
        hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        # Residuals are cached across steps, so make them contiguous as well.
        hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states_residual = (
            encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
|
|