from typing import Any, Optional, Callable, List, Tuple
import os
import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from accelerate import init_empty_weights
from transformers.activations import ACT2FN
from transformers.generation import GenerationConfig
from transformers.models.opt.modeling_opt import (
    OPTAttention,
    OPTDecoder,
    OPTDecoderLayer,
    OPTForCausalLM,
    OPTModel,
)
from transformers.models.opt.configuration_opt import OPTConfig
from huggingface_hub import snapshot_download

from configuration_tricksy import TricksyConfig
from util import batch_copy, compute_index_diffs, load_mlp_sparsity_predictor, mmap_to_tensor, topk_and_threshold

# Local fallback directory used when the preprocessed weights cannot be downloaded.
TRICKSY_WEIGHTS_PATH = 'tricksy-weights/'


class SparseMLPCache:
    """Holder for the CPU-side MLP weights and the indices currently cached on the GPU."""

    def __init__(
        self,
        indexed_fc1_weight: Optional[torch.Tensor] = None,
        indexed_fc1_bias: Optional[torch.Tensor] = None,
        indexed_fc2_weight: Optional[torch.Tensor] = None,
        gpu_cached_mlp_indices: Optional[torch.Tensor] = None,
    ):
        # [ffn_embed_dim * min_mlp_sparsity, hidden_size]
        self.indexed_fc1_weight = indexed_fc1_weight
        # [ffn_embed_dim * min_mlp_sparsity]
        self.indexed_fc1_bias = indexed_fc1_bias
        # [ffn_embed_dim * min_mlp_sparsity, hidden_size] (stored in transpose for efficient indexing)
        self.indexed_fc2_weight = indexed_fc2_weight
        # Indices that are already on GPU (this tensor is stored on the CPU)
        # [ffn_embed_dim * min_mlp_sparsity]
        self.gpu_cached_mlp_indices = gpu_cached_mlp_indices


class SparseIndices:
    """GPU and pinned-CPU buffers for the predicted MLP neuron indices."""

    def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
        num_indices = int(opt_config.ffn_dim * tricksy_config.min_mlp_sparsity_gpu)
        self.mlp_indices_buffer_gpu = torch.empty((num_indices,), dtype=torch.int32, device='cuda')
        self.mlp_indices_buffer_cpu = torch.empty(
            (num_indices,),
            dtype=torch.int32,
            device='cpu',
            pin_memory=True,
        )

        # Default stream blocks until indices are copied to CPU
        self.index_copy_stream = torch.cuda.default_stream()

    def copy_mlp_indices_to_cpu(self):
        """Replace the CPU index buffer with a copy of the GPU index buffer."""
        self.mlp_indices_buffer_cpu = batch_copy(
            [self.mlp_indices_buffer_gpu], self.index_copy_stream, device='cpu'
        )[0]


class OPTDiskWeights:
    """Memory-mapped, float16 on-disk view of an OPT checkpoint.

    On construction the preprocessed weights are downloaded from the
    ``austinsilveria/tricksy-<model>`` hub repo; if that fails, a local
    directory under ``TRICKSY_WEIGHTS_PATH`` is used and, when it does not yet
    contain the embedding file, the original checkpoint is downloaded and
    converted by :meth:`cache_weights`.

    Attributes set here: ``model_name``, ``model_suffix``, ``config``,
    ``weight_path``, ``state_dict`` (meta tensors describing shapes) and
    ``memmap_weights`` (key -> tensor backed by a read-only memmap).
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model_suffix = model_name.split('/')[-1]
        self.config = OPTConfig.from_pretrained(model_name)

        try:
            print(f'downloading from austinsilveria/tricksy-{self.model_suffix}')
            self.weight_path = snapshot_download(repo_id=f'austinsilveria/tricksy-{self.model_suffix}') + '/'
        except Exception as err:
            # Best-effort download: any hub/network failure falls back to the local
            # cache directory. Catching Exception (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            print(f'failed to download from austinsilveria/tricksy-{self.model_suffix}: {err!r}')
            self.weight_path = f'{TRICKSY_WEIGHTS_PATH}{self.model_suffix}/'

        # Build the model on the meta device only to learn parameter names and shapes.
        with init_empty_weights():
            model = OPTModel(self.config)
        self.state_dict = model.state_dict()

        if not os.path.exists(f'{self.weight_path}decoder.embed_tokens.weight'):
            # Download original weights and write memmap files
            print('downloading and preprocessing original weights')
            self.cache_weights()

        # Replace the per-projection attention entries with the concatenated per-head
        # layout and record fc2 with the transposed shape used on disk.
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        for i in range(self.config.num_hidden_layers):
            layer_prefix = f'decoder.layers.{i}.'
            self.delete_weights([
                f'{layer_prefix}self_attn.q_proj.weight',
                f'{layer_prefix}self_attn.k_proj.weight',
                f'{layer_prefix}self_attn.v_proj.weight',
                f'{layer_prefix}self_attn.out_proj.weight',
                f'{layer_prefix}self_attn.q_proj.bias',
                f'{layer_prefix}self_attn.k_proj.bias',
                f'{layer_prefix}self_attn.v_proj.bias',
            ])
            self.add_weights([
                (f'{layer_prefix}fc2.weight', (self.config.ffn_dim, self.config.hidden_size)),
                (
                    f'{layer_prefix}self_attn.catted_head_weights',
                    (self.config.num_attention_heads, head_dim * 4, self.config.hidden_size),
                ),
                (f'{layer_prefix}self_attn.catted_head_biases', (self.config.num_attention_heads, 3, head_dim)),
            ])

        self.memmap_weights = {key: self.load_memmap_weight(key) for key in self.state_dict}

    def load_memmap_weight(self, key: str):
        """Return a tensor backed by a read-only float16 memmap of ``key``.

        The shape is taken from ``self.state_dict[key]``.
        """
        return torch.from_numpy(
            np.memmap(f'{self.weight_path}{key}', dtype='float16', mode='r', shape=(self.state_dict[key].shape))
        )

    def add_weights(self, weights: List[Tuple[str, torch.Size]]):
        """Register (or overwrite) ``state_dict`` entries as float16 meta tensors of the given shapes."""
        for key, shape in weights:
            self.state_dict[key] = torch.empty(shape, dtype=torch.float16, device='meta')

    def delete_weights(self, keys: List[str]):
        """Drop ``keys`` from ``state_dict`` and remove their files under ``weight_path`` if present."""
        for key in keys:
            if key in self.state_dict:
                del self.state_dict[key]
            path = f'{self.weight_path}{key}'
            if os.path.exists(path):
                os.remove(path)

    def _write_memmap(self, key: str, array: np.ndarray) -> None:
        """Write ``array`` as a float16 memmap file named ``key`` under ``weight_path`` and flush it."""
        memmap = np.memmap(f'{self.weight_path}{key}', dtype='float16', mode='w+', shape=array.shape)
        memmap[:] = array
        # Flush explicitly so the data is on disk before the file is re-opened for reading.
        memmap.flush()

    def cache_weights(self):
        """Download the original checkpoint and convert it to per-tensor memmap files.

        Every tensor of every ``pytorch_model*.bin`` shard is written to
        ``weight_path`` (with ``"model."`` stripped from the key). Then, per
        layer, fc2 is rewritten transposed, the q/k/v/out projection weights are
        concatenated per head into ``self_attn.catted_head_weights``, the q/k/v
        biases into ``self_attn.catted_head_biases``, and the individual
        projection files are deleted.
        """
        os.makedirs(self.weight_path, exist_ok=True)
        weights_location = snapshot_download(repo_id=self.model_name, ignore_patterns=['flax*', 'tf*'])
        shards = [
            file for file in os.listdir(weights_location)
            if file.startswith("pytorch_model") and file.endswith(".bin")
        ]
        for shard in shards:
            print(f'caching {shard}')
            shard_path = os.path.join(weights_location, shard)
            # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
            # map_location='cpu' avoids needing a GPU just to convert the shard.
            shard_state_dict = torch.load(shard_path, map_location='cpu')
            for key, tensor in shard_state_dict.items():
                self._write_memmap(key.replace("model.", ""), tensor.cpu().numpy())

        # Store weights in shape for efficient indexing
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        for i in range(self.config.num_hidden_layers):
            layer_prefix = f'decoder.layers.{i}.'

            # FC2 in transpose
            fc2t = torch.from_numpy(
                np.array(self.load_memmap_weight(f'{layer_prefix}fc2.weight')[:])
            ).t().contiguous().clone()
            self._write_memmap(f'{layer_prefix}fc2.weight', fc2t.numpy())

            # Attention weights by head
            qw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.weight')[:])
            kw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.weight')[:])
            vw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.weight')[:])
            ow = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.out_proj.weight')[:])
            pre_cat_shape = (self.config.num_attention_heads, head_dim, self.config.hidden_size)
            # [head, head_dim * 4, hidden_size]
            catted_head_weights = torch.cat(
                [
                    qw.view(pre_cat_shape).clone(),
                    kw.view(pre_cat_shape).clone(),
                    vw.view(pre_cat_shape).clone(),
                    ow.T.view(pre_cat_shape).clone(),
                ],
                dim=1,
            ).contiguous().clone()
            self._write_memmap(f'{layer_prefix}self_attn.catted_head_weights', catted_head_weights.numpy())

            # Attention biases by head
            qb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.bias')[:])
            kb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.bias')[:])
            vb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.bias')[:])
            pre_cat_shape = (self.config.num_attention_heads, 1, head_dim)
            # [head, 3, head_dim]
            catted_head_biases = torch.cat(
                # Don't index out bias since we need all dims after projecting back up to hidden size
                [qb.view(pre_cat_shape).clone(), kb.view(pre_cat_shape).clone(), vb.view(pre_cat_shape).clone()],
                dim=1,
            ).contiguous().clone()
            self._write_memmap(f'{layer_prefix}self_attn.catted_head_biases', catted_head_biases.numpy())

            self.delete_weights([
                f'{layer_prefix}self_attn.q_proj.weight',
                f'{layer_prefix}self_attn.k_proj.weight',
                f'{layer_prefix}self_attn.v_proj.weight',
                f'{layer_prefix}self_attn.out_proj.weight',
                f'{layer_prefix}self_attn.q_proj.bias',
                f'{layer_prefix}self_attn.k_proj.bias',
                f'{layer_prefix}self_attn.v_proj.bias',
            ])
            self.add_weights([
                (f'{layer_prefix}self_attn.catted_head_weights', catted_head_weights.shape),
                (f'{layer_prefix}self_attn.catted_head_biases', catted_head_biases.shape),
            ])


class TricksyContext:
    """Mutable state shared by all layers during a forward pass."""

    def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
        self.indices = SparseIndices(tricksy_config, opt_config)
        # Side stream passed to batch_copy for weight transfers.
        self.load_weight_stream = torch.cuda.Stream()
        # Incremented by each decoder layer's forward.
        self.layer = 0
        # True until the decoder has completed its first forward pass.
        self.is_prompt_phase = True
        # Per-forward durations in seconds.
        self.forward_times = []


class TricksyLayer:
    """Minimal layer protocol: callable via ``forward`` and able to (pre)load its weights."""

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.forward(*args, **kwds)

    def load_weights(self, tricksy_context: TricksyContext):
        """Load this layer's weights; the base implementation does nothing."""
        pass


class TricksyLayerInputs:
    """Bundle of construction-time inputs a layer needs to fetch weights and chain to the next layer."""

    def __init__(
        self,
        disk_weights: OPTDiskWeights,
        layer_key_prefix: Optional[str] = None,
        next_layer: Optional[TricksyLayer] = None,
        sparsity_predictors: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
    ) -> None:
        self.disk_weights = disk_weights
        # Look weights up in the memmaps opened once by OPTDiskWeights instead of
        # re-opening a memmap (disk_weights.load_memmap_weight) on every access.
        self.get_weight = lambda key: self.disk_weights.memmap_weights[(f'{layer_key_prefix}{key}')]
        self.layer_key_prefix = layer_key_prefix
        # Layer whose load_weights is triggered while this layer computes.
        self.next_layer = next_layer
        self.sparsity_predictors = sparsity_predictors


class TricksyOPTLearnedPositionalEmbedding(TricksyLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.

    ``weight`` starts as ``None`` and is assigned externally (by the decoder's
    ``load_weights``). Calling the instance dispatches to ``forward`` through
    the inherited ``TricksyLayer.__call__``.
    """

    def __init__(self, tricksy_context):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        self.tricksy_context = tricksy_context
        self.weight = None

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """Embed positions derived from ``attention_mask``.

        `attention_mask` is expected to be [bsz x seqlen]. Positions are the
        cumulative sum of the mask (times the mask) minus one, cut to the part
        after ``past_key_values_length``, then shifted by ``self.offset``.
        """
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        out = F.embedding(positions + self.offset, self.weight)
        return out


class TricksyOPTAttention(OPTAttention, TricksyLayer):
    """OPT attention whose projection weights are plain tensors loaded on demand.

    The projections and the pre-attention layer norm are lambdas over tensor
    attributes that ``load_weights`` fills in. ``forward`` also predicts the MLP
    sparsity indices and kicks off the next layer's weight load.
    """

    def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext, is_decoder: bool = False, **kwargs):
        nn.Module.__init__(self)
        self.tricksy_config = tricksy_config
        self.config = tricksy_config.opt_config

        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
            """
            If a the deprecated argument `fn_arg_name` is passed, raise a deprecation
            warning and return that value, otherwise take the equivalent config.config_arg_name
            """
            if fn_arg_name in kwargs:
                # f-string so the argument and class names are actually interpolated.
                print(
                    f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
                    " Please set it in the config instead"
                )
                return kwargs.pop(fn_arg_name)
            return getattr(config, config_arg_name)

        self.embed_dim = _handle_deprecated_argument("hidden_size", self.config, "embed_dim", kwargs)
        self.num_heads = _handle_deprecated_argument("num_attention_heads", self.config, "num_heads", kwargs)
        self.dropout = _handle_deprecated_argument("attention_dropout", self.config, "dropout", kwargs)
        self.enable_bias = _handle_deprecated_argument("enable_bias", self.config, "bias", kwargs)

        self.head_dim = self.embed_dim // self.num_heads
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        # [Tricksy]
        self.tricksy_context = tricksy_context
        self.inputs = inputs
        # Effective head_dim always comes from the config (overrides the value above).
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads

        self.qw = self.kw = self.vw = self.ow = self.qb = self.kb = self.vb = None
        self.out_proj_bias = self.layer_norm_weight = self.layer_norm_bias = None

        # Lambdas read the tensor attributes at call time, so they pick up whatever
        # load_weights most recently assigned.
        self.q_proj = lambda x: F.linear(x, self.qw, self.qb)
        self.k_proj = lambda x: F.linear(x, self.kw, self.kb)
        self.v_proj = lambda x: F.linear(x, self.vw, self.vb)
        self.out_proj = lambda x: F.linear(x, self.ow, self.out_proj_bias)
        self.layer_norm = lambda x: F.layer_norm(x, (self.config.hidden_size,), self.layer_norm_weight, self.layer_norm_bias)

    def clear(self):
        """Drop references to all attention and layer-norm tensors."""
        self.qw = self.kw = self.vw = self.ow = self.qb = self.kb = self.vb = None
        self.out_proj_bias = self.layer_norm_weight = self.layer_norm_bias = None

    def load_weights(self, tricksy_context: TricksyContext):
        """Load the full attention weights, but only during the prompt phase.

        Outside the prompt phase this is a no-op, so whatever was loaded earlier
        stays in place.
        """
        if not self.tricksy_context.is_prompt_phase:
            return

        # Full weights for prompt phase
        self.catted_weights, self.catted_biases, self.out_proj_bias, self.layer_norm_weight, self.layer_norm_bias = batch_copy(
            [
                mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_weights')[:], pin_memory=True),
                mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_biases')[:], pin_memory=True),
                mmap_to_tensor(self.inputs.get_weight('self_attn.out_proj.bias')[:], pin_memory=True),
                mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.weight')[:], pin_memory=True),
                mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.bias')[:], pin_memory=True),
            ],
            tricksy_context.load_weight_stream,
        )
        torch.cuda.synchronize()

        hidden = self.config.hidden_size
        hd = self.head_dim
        # Weights stored in shape for efficient indexing to support offloading attention heads (not currently being done)
        self.qw = self.catted_weights[:, :hd, :].reshape(hidden, hidden).contiguous()
        self.kw = self.catted_weights[:, hd:hd * 2, :].reshape(hidden, hidden).contiguous()
        self.vw = self.catted_weights[:, hd * 2:hd * 3, :].reshape(hidden, hidden).contiguous()
        # out_proj was stored transposed in the concatenated layout, so transpose it back.
        self.ow = self.catted_weights[:, hd * 3:, :].reshape(hidden, hidden).t().contiguous()
        self.catted_weights = None

        self.qb = self.catted_biases[:, 0, :].reshape(hidden).contiguous()
        self.kb = self.catted_biases[:, 1, :].reshape(hidden).contiguous()
        self.vb = self.catted_biases[:, 2, :].reshape(hidden).contiguous()
        self.catted_biases = None

    def forward(self, hidden_states, **kwargs):
        """Predict MLP indices, start the MLP weight load, then run attention on the normed input."""
        # Wait for attention weights to get to GPU
        torch.cuda.synchronize()

        # Predict MLP sparsity based on attention input
        # (uses the last position of the first batch element only)
        self.tricksy_context.indices.mlp_indices_buffer_gpu = topk_and_threshold(
            self.inputs.sparsity_predictors[0](hidden_states)[0, -1, :],
            int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu),
        )
        self.tricksy_context.indices.copy_mlp_indices_to_cpu()
        torch.cuda.synchronize()

        # Load MLP weights while computing attention
        self.inputs.next_layer.load_weights(self.tricksy_context)

        out = super().forward(self.layer_norm(hidden_states), **kwargs)

        # Wait for MLP weights to get to GPU
        torch.cuda.synchronize()

        return out


class TricksyOPTDecoderLayer(OPTDecoderLayer):
    """OPT decoder layer whose MLP keeps only a ring buffer of neuron rows on the GPU.

    ``fc1``/``fc2`` operate on the concatenation of the resident rows
    (``fc*_weight``) and the rows loaded for the current step
    (``fc*_weight_diff``). After the forward pass the diff rows are written into
    the ring buffer at ``ring_idx``.
    """

    def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.tricksy_config = tricksy_config
        self.config = tricksy_config.opt_config
        self.embed_dim = self.config.hidden_size

        self.tricksy_context = tricksy_context
        self.self_attn_layer_inputs = TricksyLayerInputs(
            disk_weights=inputs.disk_weights,
            layer_key_prefix=inputs.layer_key_prefix,
            # While computing attention, load MLP
            next_layer=self,
            sparsity_predictors=inputs.sparsity_predictors,
        )
        self.self_attn = TricksyOPTAttention(tricksy_config, self.self_attn_layer_inputs, tricksy_context, is_decoder=True)

        self.do_layer_norm_before = self.config.do_layer_norm_before
        self.dropout = self.config.dropout
        self.activation_fn = ACT2FN[self.config.activation_function]

        self.inputs = inputs

        # Start from a random set of cached neuron indices (kept on the CPU).
        num_cached = int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu)
        random_mlp_indices_gpu = torch.randperm(self.config.ffn_dim, device='cpu', dtype=torch.int32)[:num_cached]
        self.index_cache = SparseMLPCache(gpu_cached_mlp_indices=random_mlp_indices_gpu)

        # identity since we move this to attention layer
        # extreme tricksy
        self.self_attn_layer_norm = lambda x: x

        self.fc1_weight = self.fc2_weight = self.final_layer_norm_weight = None
        self.fc1_bias = self.fc2_bias = self.final_layer_norm_bias = None
        self.ring_idx = 0
        self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None

        self.fc1 = lambda x: F.linear(
            x,
            torch.cat([self.fc1_weight, self.fc1_weight_diff]),
            torch.cat([self.fc1_bias, self.fc1_bias_diff]),
        )
        # fc2 rows are stored transposed, hence the .T
        self.fc2 = lambda x: F.linear(x, torch.cat([self.fc2_weight, self.fc2_weight_diff]).T, self.fc2_bias)
        self.final_layer_norm = lambda x: F.layer_norm(x, (self.embed_dim,), self.final_layer_norm_weight, self.final_layer_norm_bias)

    def clear(self):
        """Drop references to all MLP, layer-norm and diff tensors."""
        self.fc1_weight = self.fc2_weight = self.final_layer_norm_weight = None
        self.fc1_bias = self.fc2_bias = self.final_layer_norm_bias = None
        self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None

    def _reset_weight_diffs(self):
        """Set the three diff tensors to empty CUDA tensors (nothing extra to concatenate)."""
        self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
        self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
        self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')

    def load_weights(self, tricksy_context: TricksyContext):
        """Load the MLP weights needed for the upcoming forward pass.

        Prompt phase: copy the full dense MLP and final layer norm to the GPU,
        update the cached index set from the predicted indices, and keep pinned
        CPU copies of fc1/fc2 for later row lookups.

        Otherwise: compute which predicted indices are not in the cached set,
        overwrite the cached index ring with them starting at ``ring_idx``, and
        copy the corresponding rows to the GPU as the diff tensors. (``ring_idx``
        itself is advanced in ``forward``.)
        """
        if self.tricksy_context.is_prompt_phase:
            # Full weights for prompt phase
            fc1w = mmap_to_tensor(self.inputs.get_weight('fc1.weight')[:], pin_memory=True)
            fc1b = mmap_to_tensor(self.inputs.get_weight('fc1.bias')[:], pin_memory=True)
            fc2w = mmap_to_tensor(self.inputs.get_weight('fc2.weight')[:], pin_memory=True)
            fc2b = mmap_to_tensor(self.inputs.get_weight('fc2.bias')[:], pin_memory=True)
            lnw = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True)
            lnb = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True)
            (
                self.fc1_weight,
                self.fc1_bias,
                self.fc2_weight,
                self.fc2_bias,
                self.final_layer_norm_weight,
                self.final_layer_norm_bias,
            ) = batch_copy([fc1w, fc1b, fc2w, fc2b, lnw, lnb], tricksy_context.load_weight_stream)
            self._reset_weight_diffs()

            index_diffs = compute_index_diffs(
                tricksy_context.indices.mlp_indices_buffer_cpu, [self.index_cache.gpu_cached_mlp_indices]
            )
            if len(index_diffs) > 0:
                gpu_index_diff = index_diffs[0]
                self.index_cache.gpu_cached_mlp_indices[gpu_index_diff.off_positions] = gpu_index_diff.off_elements

            self.index_cache.indexed_fc1_weight = fc1w.contiguous().pin_memory()
            self.index_cache.indexed_fc1_bias = fc1b.contiguous().pin_memory()
            self.index_cache.indexed_fc2_weight = fc2w.contiguous().pin_memory()
            return

        if self.fc1_weight is None:
            # Full weights if full offload
            self.fc1_weight, self.fc1_bias, self.fc2_weight = batch_copy(
                [self.index_cache.indexed_fc1_weight, self.index_cache.indexed_fc1_bias, self.index_cache.indexed_fc2_weight],
                tricksy_context.load_weight_stream
            )
            self._reset_weight_diffs()

        # Predicted indices that are not already cached on the GPU.
        predicted = set(tricksy_context.indices.mlp_indices_buffer_cpu.tolist())
        cached = set(self.index_cache.gpu_cached_mlp_indices.tolist())
        off_elements = torch.tensor(
            list(predicted.difference(cached)),
            device='cpu',
            dtype=torch.int32,
            pin_memory=True
        )
        num_off = off_elements.size(0)
        if num_off == 0:
            self._reset_weight_diffs()
            return

        ring_size = self.index_cache.gpu_cached_mlp_indices.size(0)
        new_ring_idx = (self.ring_idx + num_off) % ring_size
        if new_ring_idx > self.ring_idx:
            # single contiguous update
            self.index_cache.gpu_cached_mlp_indices[self.ring_idx:new_ring_idx] = off_elements
        else:
            split = ring_size - self.ring_idx
            # end of ring
            self.index_cache.gpu_cached_mlp_indices[self.ring_idx:] = off_elements[:split]
            # beginning of ring
            self.index_cache.gpu_cached_mlp_indices[:new_ring_idx] = off_elements[split:]

        # Index
        fc1wd = self.index_cache.indexed_fc1_weight[off_elements].pin_memory()
        fc1bd = self.index_cache.indexed_fc1_bias[off_elements].pin_memory()
        fc2wd = self.index_cache.indexed_fc2_weight[off_elements].pin_memory()

        # Copy (batch_copy's results are bound directly; no pre-allocation is needed)
        self.fc1_weight_diff, self.fc1_bias_diff, self.fc2_weight_diff = batch_copy(
            [fc1wd, fc1bd, fc2wd], tricksy_context.load_weight_stream
        )

    def forward(self, *args, **kwargs):
        """Run the decoder layer, then fold this step's diff rows into the GPU ring buffer."""
        # Wait for attention weights to get to GPU
        torch.cuda.synchronize()

        # Load next layer's attention weights
        self.inputs.next_layer.load_weights(self.tricksy_context)

        out = super().forward(*args, **kwargs)

        if self.tricksy_config.full_offload:
            self.fc1_weight = self.fc1_bias = self.fc2_weight = None
        elif self.tricksy_context.is_prompt_phase:
            # Only keep sparse MLP weights on GPU after prompt phase
            gpu_indices = self.index_cache.gpu_cached_mlp_indices.to('cuda')
            self.fc1_weight = self.fc1_weight[gpu_indices]
            self.fc1_bias = self.fc1_bias[gpu_indices]
            self.fc2_weight = self.fc2_weight[gpu_indices]

        # Update ring buffers
        if not self.tricksy_config.full_offload:
            num_diff = self.fc1_weight_diff.size(0)
            ring_size = self.fc1_weight.size(0)
            prev_ring_idx = self.ring_idx
            self.ring_idx = (self.ring_idx + num_diff) % ring_size
            if self.ring_idx > prev_ring_idx:
                # does not wrap around ring
                self.fc1_weight[prev_ring_idx:self.ring_idx] = self.fc1_weight_diff
                self.fc1_bias[prev_ring_idx:self.ring_idx] = self.fc1_bias_diff
                self.fc2_weight[prev_ring_idx:self.ring_idx] = self.fc2_weight_diff
            elif num_diff > 0:
                # wraps around ring: the first `split` rows fill the tail, the rest the head
                # (same split as used for the index ring in load_weights)
                split = ring_size - prev_ring_idx
                self.fc1_weight[prev_ring_idx:] = self.fc1_weight_diff[:split]
                self.fc1_weight[:self.ring_idx] = self.fc1_weight_diff[split:]
                self.fc1_bias[prev_ring_idx:] = self.fc1_bias_diff[:split]
                self.fc1_bias[:self.ring_idx] = self.fc1_bias_diff[split:]
                self.fc2_weight[prev_ring_idx:] = self.fc2_weight_diff[:split]
                self.fc2_weight[:self.ring_idx] = self.fc2_weight_diff[split:]

        self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None
        self.tricksy_context.layer += 1
        return out


class TricksyOPTDecoder(OPTDecoder, TricksyLayer):
    """OPT decoder built from Tricksy layers, owning the input token/position embeddings."""

    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.config = tricksy_config.opt_config
        self.dropout = self.config.dropout
        self.layerdrop = self.config.layerdrop
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.vocab_size = self.config.vocab_size
        self._use_flash_attention_2 = False
        self.gradient_checkpointing = False
        self.project_out = None
        self.project_in = None

        self.embed_tokens_weight = None
        self.embed_positions = TricksyOPTLearnedPositionalEmbedding(tricksy_context)

        self.tricksy_context = tricksy_context

        # Layers are built last-to-first so each one can point at the already-built
        # attention of the layer that follows it; the list is reversed afterwards.
        self.layers: List[TricksyOPTDecoderLayer] = []
        for i in range(self.config.num_hidden_layers):
            pretrained_layer_num = self.config.num_hidden_layers - i - 1
            sparsity_predictors = [
                load_mlp_sparsity_predictor(disk_weights.weight_path, pretrained_layer_num, tricksy_config.dtype)
            ]
            if sparsity_predictors[0] is None:
                # Fallback when no predictor is available: a random linear map.
                # NOTE: a fresh random matrix is drawn on every call.
                sparsity_predictors[0] = lambda x: F.linear(
                    x,
                    torch.rand((self.config.ffn_dim, self.config.hidden_size), device='cuda', dtype=tricksy_config.dtype),
                )
            self.layers.append(TricksyOPTDecoderLayer(
                tricksy_config,
                TricksyLayerInputs(
                    disk_weights=disk_weights,
                    layer_key_prefix=f'decoder.layers.{pretrained_layer_num}.',
                    # While computing MLP, load next attention
                    # While computing last MLP, load output embeddings (stored in TricksyOPTForCausalLM)
                    next_layer=self.layers[i - 1].self_attn if i > 0 else tricksy_opt_for_causal_lm,
                    sparsity_predictors=sparsity_predictors,
                ),
                tricksy_context,
            ))
        self.layers.reverse()

        # Identity here; the final layer norm is applied by the causal-LM head instead.
        self.final_layer_norm = lambda x: x
        self.inputs = TricksyLayerInputs(disk_weights=disk_weights, layer_key_prefix='decoder.')

    def clear(self):
        """Drop the embedding tensors and clear every layer."""
        self.embed_tokens_weight = self.embed_positions.weight = None
        for layer in self.layers:
            layer.clear()

    def embed_tokens(self, x):
        """Embed token ids with the currently loaded token-embedding weight."""
        return F.embedding(x, self.embed_tokens_weight, self.padding_idx)

    def load_weights(self, tricksy_context: TricksyContext):
        """Load token and position embeddings unless the token embedding is already set."""
        if self.embed_tokens_weight is None:
            self.embed_tokens_weight, self.embed_positions.weight = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('embed_positions.weight')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )

    def forward(self, *args, **kwargs):
        """Run the decoder stack and mark the prompt phase as finished."""
        # Wait for input embedding weights to get to GPU
        torch.cuda.synchronize()

        # While computing input embeddings, load first attention
        self.layers[0].self_attn.load_weights(self.tricksy_context)

        out = super().forward(*args, **kwargs)

        # Wait for output embedding weights to get to GPU
        torch.cuda.synchronize()

        # No longer prompt phase after first full pass
        self.tricksy_context.is_prompt_phase = False

        # Load input embeddings while computing output
        self.load_weights(self.tricksy_context)

        return out


class TricksyOPTModel(OPTModel):
    """Thin OPTModel wrapper around a TricksyOPTDecoder."""

    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.config = tricksy_config.opt_config
        self.tricksy_context = tricksy_context
        self.decoder = TricksyOPTDecoder(tricksy_config, disk_weights, tricksy_opt_for_causal_lm, tricksy_context)

    def clear(self):
        """Clear the decoder's loaded weights."""
        self.decoder.clear()

    def forward(self, *args, **kwargs):
        """Delegate to ``OPTModel.forward`` unchanged."""
        out = super().forward(*args, **kwargs)
        return out


# who's got the weights?
# [InputEmbedding,    Attention.0,           MLP.0,                    Attention.1,           MLP.1,                 ..., OutputEmbedding]
# [TricksyOPTDecoder, TricksyOPTAttention.0, TricksyOPTDecoderLayer.0, TricksyOPTAttention.1, TricksyDecoderLayer.1, ..., TricksyOPTForCausalLM]
#
# 1. Prompt pass: Before computing layer, send full dense weights to GPU. After computing layer, only keep sparse weights on GPU.
# 2. Generation passes: Before computing layer, compute and send sparse weight diff to GPU.
class TricksyOPTForCausalLM(OPTForCausalLM, TricksyLayer):
    """Causal-LM wrapper that owns the final layer norm and the output embedding.

    The LM head applies the final layer norm and then a linear projection with
    ``lm_head_weight``, which is loaded from ``decoder.embed_tokens.weight``
    (the same key the decoder uses for its input embedding). This object is
    also handed to the decoder as the "next layer" of the last MLP, so its
    ``load_weights`` is triggered while that MLP computes.
    """

    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights):
        nn.Module.__init__(self)
        self.config = disk_weights.config
        self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None

        self.tricksy_context = TricksyContext(tricksy_config, self.config)
        self.model = TricksyOPTModel(tricksy_config, disk_weights, self, self.tricksy_context)

        self.final_layer_norm_weight = self.lm_head_weight = self.final_layer_norm_bias = None
        # double stacking tricksy!
        self.final_layer_norm = lambda x: F.layer_norm(
            x, (self.config.hidden_size,), self.final_layer_norm_weight, self.final_layer_norm_bias
        )
        self.lm_head = lambda x: F.linear(self.final_layer_norm(x), self.lm_head_weight)

        self.inputs = TricksyLayerInputs(
            disk_weights=disk_weights, layer_key_prefix='decoder.', next_layer=self.model.decoder
        )

    def clear(self):
        """Clear the wrapped model's weights (this object's own head tensors are left as they are)."""
        self.model.clear()

    def load_weights(self, tricksy_context: TricksyContext):
        """Load the final layer norm and LM head weights unless the layer-norm weight is already set."""
        if self.final_layer_norm_weight is None:
            self.final_layer_norm_weight, self.lm_head_weight, self.final_layer_norm_bias = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )

    def forward(self, *args, **kwargs):
        """Run one forward pass and append its duration (seconds) to the context.

        CUDA is synchronized before and after so the measured interval covers
        the GPU work. The context's layer counter is reset to 0 afterwards.
        """
        torch.cuda.synchronize()
        # perf_counter is monotonic, so the interval cannot go negative or jump
        # if the system clock is adjusted mid-pass (unlike time.time()).
        start = time.perf_counter()
        out = super().forward(*args, **kwargs)
        torch.cuda.synchronize()
        self.tricksy_context.forward_times.append(time.perf_counter() - start)
        self.tricksy_context.layer = 0
        return out

    def generate(self, *args, **kwargs):
        """Preload the input embeddings, then delegate to ``OPTForCausalLM.generate``."""
        # Load input embeddings for first token
        self.model.decoder.load_weights(self.tricksy_context)
        torch.cuda.synchronize()
        out = super().generate(*args, **kwargs)
        return out