import torch.nn as nn
from transformers import LlamaForCausalLM

from .configuration_pruned_llama import LlamaPrunedConfig


class LlamaPrunedForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM whose layers in the half-open range
    [config.begin_pruned_layer, config.end_pruned_layer) are rebuilt with
    narrower attention and MLP projections."""

    config_class = LlamaPrunedConfig

    def __init__(self, config: LlamaPrunedConfig):
        super().__init__(config)

        # Rebuild the projections of the pruned layers with narrower dimensions.
        # For a base model with hidden_size 4096 and head_dim 128 (e.g. a
        # Llama-3-8B-style checkpoint), the new shapes correspond to 24 query
        # heads (3072), 6 key/value heads (768), and an MLP intermediate size
        # of 10752.
        for layer in self.model.layers[config.begin_pruned_layer:config.end_pruned_layer]:
            layer.self_attn.hidden_size = 3072
            layer.self_attn.q_proj = nn.Linear(4096, 3072, bias=False)
            layer.self_attn.k_proj = nn.Linear(4096, 768, bias=False)
            layer.self_attn.v_proj = nn.Linear(4096, 768, bias=False)
            layer.self_attn.o_proj = nn.Linear(3072, 4096, bias=False)
            layer.mlp.gate_proj = nn.Linear(4096, 10752, bias=False)
            layer.mlp.up_proj = nn.Linear(4096, 10752, bias=False)
            layer.mlp.down_proj = nn.Linear(10752, 4096, bias=False)
        
        # Recompute head counts from the projection weight shapes so every
        # layer's attention stays consistent with its (possibly replaced) weights.
        for layer in self.model.layers:
            head_dim = layer.self_attn.head_dim
            layer.self_attn.num_heads = layer.self_attn.q_proj.weight.shape[0] // head_dim
            layer.self_attn.num_key_value_heads = layer.self_attn.k_proj.weight.shape[0] // head_dim
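

# Usage sketch (illustrative, not part of the original file): one way the pruned
# classes might be registered for AutoModel loading and instantiated. It assumes
# LlamaPrunedConfig is a PretrainedConfig subclass that accepts
# begin_pruned_layer/end_pruned_layer as keyword arguments; the values and the
# output directory name below are placeholders.
#
#   LlamaPrunedConfig.register_for_auto_class()
#   LlamaPrunedForCausalLM.register_for_auto_class("AutoModelForCausalLM")
#
#   config = LlamaPrunedConfig(begin_pruned_layer=4, end_pruned_layer=28)
#   model = LlamaPrunedForCausalLM(config)
#   model.save_pretrained("pruned-llama")  # hypothetical output path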