def memory_for_attention_layer(precession: int, seq_len: int, batch_size: int, hidden_size: int, num_heads: int):
    """
    Estimate the training-time memory of one attention layer, in bytes.

    head_dim = hidden_size // num_heads

    Model parameters:
        q_proj: (hidden_size, num_heads * head_dim)
        k_proj: (hidden_size, num_key_value_heads * head_dim)
        v_proj: (hidden_size, num_key_value_heads * head_dim)
        o_proj: (hidden_size, hidden_size)
        Total parameters = 3 * hidden_size * num_heads * head_dim + hidden_size^2
        (the code below assumes num_key_value_heads == num_heads, i.e. no grouped-query attention)

    Gradients:
        Gradients have the same size as the model parameters.

    Optimizer states:
        Assuming the Adam optimizer with two states per parameter (momentum and variance),
        optimizer states take 2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2).

    Activations (q_len == seq_len):
        query_states: (batch_size, num_heads, q_len, head_dim)
        key_states:   (batch_size, num_key_value_heads, q_len, head_dim)
        value_states: (batch_size, num_key_value_heads, q_len, head_dim)
        attn_weights: (batch_size, num_heads, q_len, q_len)
        attn_output:  (batch_size, q_len, hidden_size)
        Total activations = batch_size * (num_heads * q_len * head_dim
            + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)

    Temporary memory:
        Intermediate computations and buffers can add roughly 20% on top of the total;
        this factor is not applied in the code below.

    Total:
        total_memory = (model_parameters + gradients + optimizer_states + activations) * precession

    Args:
        precession: bytes per element (precision), e.g. 4 for fp32 or 2 for fp16/bf16
        seq_len: sequence length (q_len)
        batch_size: batch size
        hidden_size: model hidden size
        num_heads: number of attention heads

    Returns:
        Estimated memory in bytes.
    """
    head_dim = hidden_size // num_heads

    # Model parameters: q_proj + k_proj + v_proj (3 * hidden_size * num_heads * head_dim) and o_proj (hidden_size^2)
    model_memory = 3 * hidden_size * num_heads * head_dim + hidden_size ** 2
    # Gradients have the same size as the parameters
    gradients = model_memory
    # Adam keeps two states (momentum and variance) per parameter
    optimizer = 2 * model_memory
    # Activations: query/key/value states, attention weights and attention output
    activation = batch_size * (3 * num_heads * seq_len * head_dim
                               + num_heads * seq_len ** 2
                               + seq_len * hidden_size)

    total_memory = (model_memory + gradients + optimizer + activation) * precession
    return total_memory
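
# A minimal usage sketch (illustrative only, not part of the estimator above): the configuration
# is an assumed 7B-class dense model (hidden_size=4096, 32 heads) in fp16/bf16, chosen just to
# show the units; swap in your own model's values.
def _example_attention_memory():
    bytes_per_layer = memory_for_attention_layer(
        precession=2,      # 2 bytes per element for fp16/bf16 (assumption)
        seq_len=4096,
        batch_size=1,
        hidden_size=4096,
        num_heads=32,
    )
    print(f"attention layer: {bytes_per_layer / 1024 ** 3:.2f} GiB")
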
def memory_mlp_layer(precession: int, seq_len: int, batch_size: int, hidden_size: int, intermediate_size: int):
    """
    Estimate the training-time memory of one (gated) MLP layer, in bytes.

    Model parameters:
        gate_proj: (hidden_size, intermediate_size)
        up_proj:   (hidden_size, intermediate_size)
        down_proj: (intermediate_size, hidden_size)
        model_memory = 3 * hidden_size * intermediate_size

    Gradients:
        gradient = model_memory

    Optimizer states (Adam, momentum and variance):
        optimizer = 2 * model_memory

    Activations:
        activations = batch_size * seq_len * hidden_size + 2 * batch_size * seq_len * intermediate_size
                    = batch_size * seq_len * (2 * intermediate_size + hidden_size)

    Total:
        total_memory = (model_memory + gradient + optimizer + activations) * precession
                     = (12 * hidden_size * intermediate_size
                        + batch_size * seq_len * (2 * intermediate_size + hidden_size)) * precession

    Args:
        precession: bytes per element (precision), e.g. 4 for fp32 or 2 for fp16/bf16
        seq_len: sequence length
        batch_size: batch size
        hidden_size: model hidden size
        intermediate_size: MLP intermediate (feed-forward) size

    Returns:
        Estimated memory in bytes.
    """
    # Weights of gate_proj, up_proj and down_proj
    model_memory = 3 * (hidden_size * intermediate_size)
    # Gradients have the same size as the parameters
    gradient = model_memory
    # Adam keeps two states (momentum and variance) per parameter
    optimizer = 2 * model_memory
    # Activations: layer input (hidden_size) plus gate_proj and up_proj outputs (2 * intermediate_size) per token
    activation = batch_size * seq_len * (2 * intermediate_size + hidden_size)

    total_memory = (model_memory + gradient + optimizer + activation) * precession
    return total_memory


def memory_moe_mlp(precession: int, seq_len: int, batch_size: int, hidden_size: int, intermediate_size: int,
                   num_expert: int, top_k: int):
    # Router (gate) weights: (hidden_size, num_expert); gradients and the two Adam states
    # are handled as before (1x gradients + 2x optimizer states). The result is in bytes.
    gate_memory = hidden_size * num_expert * precession
    gate_total = 4 * gate_memory
    # Each expert is a full MLP layer; memory_mlp_layer already includes its weights,
    # gradients, optimizer states and activations, in bytes.
    experts_total = memory_mlp_layer(precession, seq_len, batch_size, hidden_size, intermediate_size) * num_expert
    # Router-level activations (worst case), in bytes
    max_memory_activation = (
        (batch_size * seq_len * num_expert * precession) +   # Router logits
        (batch_size * seq_len * top_k * precession) +        # Routing weights
        (batch_size * seq_len * top_k * precession) +        # Selected experts
        (batch_size * seq_len * hidden_size * precession) +  # Final hidden states
        (batch_size * seq_len * hidden_size * precession) +  # Current state (worst case)
        (batch_size * seq_len * hidden_size * precession)    # Current hidden states (worst case)
    )
    total_memory = gate_total + experts_total + max_memory_activation
    return total_memory
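
# A minimal end-to-end sketch (illustrative only): it calls the three estimators for one decoder
# block of an assumed Mixtral-like configuration (hidden_size=4096, 32 heads, intermediate_size=14336,
# 8 experts, top-2 routing) in fp16/bf16. Every number here is an assumption for demonstration,
# not a value read from a real config file.
if __name__ == "__main__":
    precession = 2            # 2 bytes per element for fp16/bf16 (assumption)
    seq_len, batch_size = 2048, 1
    hidden_size, num_heads = 4096, 32
    intermediate_size = 14336
    num_expert, top_k = 8, 2

    attn_bytes = memory_for_attention_layer(precession, seq_len, batch_size, hidden_size, num_heads)
    mlp_bytes = memory_mlp_layer(precession, seq_len, batch_size, hidden_size, intermediate_size)
    moe_bytes = memory_moe_mlp(precession, seq_len, batch_size, hidden_size, intermediate_size,
                               num_expert, top_k)

    gib = 1024 ** 3
    print(f"attention layer : {attn_bytes / gib:.2f} GiB")
    print(f"dense MLP layer : {mlp_bytes / gib:.2f} GiB")
    print(f"MoE MLP layer   : {moe_bytes / gib:.2f} GiB")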