brianling16 commited on
Commit
a2ef3ab
·
verified ·
1 Parent(s): 9cbbe9c

Delete lora_layer.py

Browse files
Files changed (1) hide show
  1. lora_layer.py +0 -139
lora_layer.py DELETED
@@ -1,139 +0,0 @@
1
- import copy
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import math
6
- from typing import Optional, List
7
-
8
- # ---- LoRA ----
9
- class LoRAAdapter(nn.Module):
10
- def __init__(self, in_features: int, out_features: int, rank: int, alpha: float = 1.0,
11
- weight: Optional[torch.Tensor] = None):
12
- super().__init__()
13
- self.rank = rank
14
- self.alpha = alpha
15
- if rank > 0:
16
- self.A = nn.Parameter(torch.zeros((rank, in_features)))
17
- self.B = nn.Parameter(torch.zeros((out_features, rank)))
18
-
19
- # Initialize with SVD if base weight is provided
20
- if weight is not None:
21
- U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
22
- U = U[:, :rank]
23
- S = S[:rank]
24
- Vh = Vh[:rank, :]
25
- self.A.data = Vh # (rank, in_features)
26
- self.B.data = U @ torch.diag(S) # (out_features, rank)
27
- else:
28
- nn.init.normal_(self.A, std=1/rank)
29
- nn.init.zeros_(self.B)
30
- else:
31
- self.register_parameter('A', None)
32
- self.register_parameter('B', None)
33
-
34
- def delta(self) -> Optional[torch.Tensor]:
35
- if self.rank == 0 or self.A is None or self.B is None:
36
- return None
37
- return (self.B @ self.A) * (self.alpha / self.rank) # (out, in)
38
-
39
- def lora_parameters(self):
40
- if self.A is not None:
41
- yield self.A
42
- if self.B is not None:
43
- yield self.B
44
-
45
- class LoRALinear(nn.Module):
46
- def __init__(self, linear: nn.Linear, rank: int, alpha: float = 1.0, num_repeats: int = 1):
47
- super().__init__()
48
- self.linear = linear # base frozen linear
49
- self.rank = rank
50
- self.num_repeats = num_repeats
51
-
52
- if rank > 0:
53
- self.loras = nn.ModuleList([
54
- LoRAAdapter(linear.in_features, linear.out_features, rank, alpha)
55
- for _ in range(num_repeats)
56
- ])
57
- else:
58
- self.loras = nn.ModuleList([])
59
-
60
- def forward(self, x, repeat_idx: int = 0):
61
- out = self.linear(x) # [batch, ..., out_features]
62
- if self.rank == 0:
63
- return out
64
- delta = self.loras[repeat_idx].delta() # (out, in)
65
- if delta is not None:
66
- delta_t = delta # nn.Linear expects (out, in)
67
- return out + F.linear(x, delta_t)
68
- return out
69
-
70
- def lora_parameters(self):
71
- for lora in self.loras:
72
- yield from lora.lora_parameters()
73
-
74
-
75
- class LoRAConv1D(nn.Module):
76
- """GPT-2 style Conv1D with LoRA support."""
77
- def __init__(self, conv1d, rank: int, alpha: float = 1.0, num_repeats: int = 1):
78
- super().__init__()
79
- self.conv1d = conv1d # base GPT-2 Conv1D
80
- self.rank = rank
81
- self.num_repeats = num_repeats
82
- in_features, out_features = conv1d.weight.shape # GPT-2 Conv1D: [in, out]
83
-
84
- # Special handling for c_attn layer which has 3x output features
85
- self.is_c_attn = (out_features % 3 == 0) and ("c_attn" in str(conv1d))
86
- self.split_size = out_features // 3 if self.is_c_attn else out_features
87
-
88
- if rank > 0:
89
- if self.is_c_attn:
90
- # Create separate LoRA adapters for Q, K, V projections
91
- self.loras = nn.ModuleList([
92
- nn.ModuleList([
93
- LoRAAdapter(in_features, self.split_size, rank, alpha)
94
- for _ in range(3) # Q, K, V
95
- ]) for _ in range(num_repeats)
96
- ])
97
- else:
98
- self.loras = nn.ModuleList([
99
- LoRAAdapter(in_features, out_features, rank, alpha)
100
- for _ in range(num_repeats)
101
- ])
102
- else:
103
- self.loras = nn.ModuleList([])
104
-
105
- def forward(self, x, repeat_idx: int = 0):
106
- """
107
- x: [batch, seq_len, in_features]
108
- returns: [batch, seq_len, out_features]
109
- """
110
- out = self.conv1d(x)
111
- if self.rank == 0 or len(self.loras) == 0:
112
- return out
113
-
114
- if self.is_c_attn:
115
- # Handle Q, K, V projections separately
116
- deltas = []
117
- for i in range(3):
118
- delta = self.loras[repeat_idx][i].delta() # (split_size, in)
119
- if delta is not None:
120
- delta_t = delta.T # (in, split_size)
121
- deltas.append(torch.matmul(x, delta_t))
122
- if deltas:
123
- return out + torch.cat(deltas, dim=-1)
124
- return out
125
- else:
126
- delta = self.loras[repeat_idx].delta() # (out, in)
127
- if delta is not None:
128
- delta_t = delta.T # (in, out)
129
- return out + torch.matmul(x, delta_t)
130
- return out
131
-
132
- def lora_parameters(self):
133
- if self.is_c_attn:
134
- for lora_group in self.loras:
135
- for lora in lora_group:
136
- yield from lora.lora_parameters()
137
- else:
138
- for lora in self.loras:
139
- yield from lora.lora_parameters()