brianling16 committed
Commit 4151f9d · verified · 1 Parent(s): a2ef3ab

Delete shared_attention.py

Files changed (1)
  1. shared_attention.py +0 -142
shared_attention.py DELETED
@@ -1,142 +0,0 @@
- import copy
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import math
- from typing import List, Optional
-
- from transformer import MultiheadSelfAttention, MLP, TransformerLayer
- from lora_layer import LoRALinear, LoRAAdapter, LoRAConv1D
-
- class SharedAttention(nn.Module):
-     def __init__(self, base_attn, num_repeats: int, lora_rank: int, lora_alpha: float):
-         super().__init__()
-         self.n_heads = base_attn.n_heads
-         self.d_head = base_attn.d_head
-         self.d_model = base_attn.d_model
-
-         self.q_proj = LoRALinear(base_attn.q_proj, lora_rank, lora_alpha, num_repeats)
-         self.k_proj = LoRALinear(base_attn.k_proj, lora_rank, lora_alpha, num_repeats)
-         self.v_proj = LoRALinear(base_attn.v_proj, lora_rank, lora_alpha, num_repeats)
-         self.out_proj = LoRALinear(base_attn.out_proj, lora_rank, lora_alpha, num_repeats)
-
-     def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
-         B, T, C = x.shape
-         H, D = self.n_heads, self.d_head
-
-         q = self.q_proj(x, repeat_idx).view(B, T, H, D).transpose(1,2)
-         k = self.k_proj(x, repeat_idx).view(B, T, H, D).transpose(1,2)
-         v = self.v_proj(x, repeat_idx).view(B, T, H, D).transpose(1,2)
-
-         att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
-         if attn_mask is not None:
-             att = att + attn_mask
-         att = F.softmax(att, dim=-1)
-         y = att @ v
-         y = y.transpose(1,2).contiguous().view(B, T, C)
-         return self.out_proj(y, repeat_idx)
-
- class SharedMLP(nn.Module):
-     def __init__(self, base_mlp, num_repeats: int, lora_rank: int, lora_alpha: float):
-         super().__init__()
-         self.fc1 = LoRALinear(base_mlp.fc1, lora_rank, lora_alpha, num_repeats)
-         self.fc2 = LoRALinear(base_mlp.fc2, lora_rank, lora_alpha, num_repeats)
-         self.act = base_mlp.act
-
-     def forward(self, x, repeat_idx: int):
-         return self.fc2(self.act(self.fc1(x, repeat_idx)), repeat_idx)
-
- class SharedTransformerLayer(nn.Module):
-     def __init__(self, base_layer, num_repeats: int, lora_rank: int, lora_alpha: float):
-         super().__init__()
-         self.ln1 = base_layer.ln1
-         self.ln2 = base_layer.ln2
-         self.dropout1 = base_layer.dropout1
-         self.dropout2 = base_layer.dropout2
-         self.attn = SharedAttention(base_layer.attn, num_repeats, lora_rank, lora_alpha)
-         self.mlp = SharedMLP(base_layer.mlp, num_repeats, lora_rank, lora_alpha)
-
-     def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
-         y = self.attn(self.ln1(x), repeat_idx, attn_mask)
-         x = x + self.dropout1(y)
-         y = self.mlp(self.ln2(x), repeat_idx)
-         x = x + self.dropout2(y)
-         return x
-
- # ---- Conversion Utilities ----
- def average_weights(layers, attr):
-     weights = [getattr(layer, attr).weight.data for layer in layers]
-     return torch.stack(weights, dim=0).mean(dim=0)
-
-
- def initialize_lora_with_svd(lora_layer, original_weights, repeat_indices, rank):
-     """
-     original_weights: list of original weights for each repeat index
-     repeat_indices: which repeat indices these weights correspond to
-     """
-     shared_weight = lora_layer.base_layer.weight.data.clone()
-
-     for idx, orig_weight in zip(repeat_indices, original_weights):
-         residual = orig_weight - shared_weight
-         U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
-
-         # Truncate to rank
-         U = U[:, :rank]
-         S = S[:rank]
-         Vh = Vh[:rank, :]
-
-         # Initialize LoRA weights
-         lora_layer.lora_A[idx].weight.data = Vh  # A = Vᵣᵀ
-         lora_layer.lora_B[idx].weight.data = U @ torch.diag(S)  # B = UᵣΣᵣ
-
- def convert_to_recursive(model, K=2, rank=8, lora_alpha=1.0):
-     n_layers = len(model.transformer.h)
-     new_blocks = []
-
-     for b in range(n_layers // K):
-         block_layers = model.transformer.h[b*K:(b+1)*K]
-         base_layer = copy.deepcopy(block_layers[0])
-
-         # Average weights across the block for shared parameters
-         with torch.no_grad():
-             if hasattr(base_layer.attn, 'c_attn'):
-                 shared_weight = average_weights([l.attn for l in block_layers], 'c_attn')
-                 base_layer.attn.c_attn.weight.data = shared_weight
-
-             if hasattr(base_layer.attn, 'c_proj'):
-                 shared_weight = average_weights([l.attn for l in block_layers], 'c_proj')
-                 base_layer.attn.c_proj.weight.data = shared_weight
-
-             if hasattr(base_layer.mlp, 'c_fc'):
-                 shared_weight = average_weights([l.mlp for l in block_layers], 'c_fc')
-                 base_layer.mlp.c_fc.weight.data = shared_weight
-
-             if hasattr(base_layer.mlp, 'c_proj'):
-                 shared_weight = average_weights([l.mlp for l in block_layers], 'c_proj')
-                 base_layer.mlp.c_proj.weight.data = shared_weight
-
-         # Convert to LoRA
-         if hasattr(base_layer.attn, 'c_attn'):
-             base_layer.attn.c_attn = LoRAConv1D(
-                 base_layer.attn.c_attn, rank, lora_alpha, K
-             )
-
-         if hasattr(base_layer.attn, 'c_proj'):
-             base_layer.attn.c_proj = LoRAConv1D(
-                 base_layer.attn.c_proj, rank, lora_alpha, K
-             )
-
-         if hasattr(base_layer.mlp, 'c_fc'):
-             base_layer.mlp.c_fc = LoRAConv1D(
-                 base_layer.mlp.c_fc, rank, lora_alpha, K
-             )
-
-         if hasattr(base_layer.mlp, 'c_proj'):
-             base_layer.mlp.c_proj = LoRAConv1D(
-                 base_layer.mlp.c_proj, rank, lora_alpha, K
-             )
-
-         new_blocks.append(base_layer)
-
-     model.transformer.h = nn.ModuleList(new_blocks)
-     return model
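
For reference, below is a minimal, self-contained sketch of the low-rank initialization that the deleted initialize_lora_with_svd performed: the residual between one layer's original weight and the block-averaged shared weight is factored with a truncated SVD, and the factors B = UᵣΣᵣ and A = Vᵣᵀ give a rank-r approximation of that residual. The dimensions and random matrices are placeholders for illustration, not values from this repository.

import torch

torch.manual_seed(0)
d_out, d_in, rank = 64, 64, 8

# Placeholder weights: one layer's original weight and the block-averaged
# shared weight it would be tied to after conversion.
original = torch.randn(d_out, d_in)
shared = torch.randn(d_out, d_in)
residual = original - shared

# Truncated SVD of the residual, mirroring initialize_lora_with_svd.
U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
A = Vh[:rank, :]                        # A = Vᵣᵀ, shape (rank, d_in)
B = U[:, :rank] @ torch.diag(S[:rank])  # B = UᵣΣᵣ, shape (d_out, rank)

# shared + B @ A recovers the original weight up to the rank truncation error.
rel_err = (torch.linalg.norm(residual - B @ A) / torch.linalg.norm(residual)).item()
print(f"rank-{rank} relative reconstruction error: {rel_err:.3f}")

Under the same reading of the code, the conversion entry point would have been called as convert_to_recursive(model, K=2, rank=8, lora_alpha=1.0) on a GPT-2 style model whose blocks expose attn.c_attn, attn.c_proj, mlp.c_fc, and mlp.c_proj, since those are the attributes the hasattr checks look for.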