Guanzheng commited on
Commit
e08a60e
·
verified ·
1 Parent(s): 1f7f9c4

Update clex_layer.py

Browse files
Files changed (1) hide show
  1. clex_layer.py +44 -28
clex_layer.py CHANGED
@@ -1,23 +1,34 @@
1
  import torch
2
- import torch.nn as nn
3
  from torchdiffeq import odeint
4
 
5
-
6
 
7
  import math
8
 
 
 
 
9
  class ODELinear(nn.Module):
10
  def __init__(
11
  self,
12
  dim: int,
13
  factor,
 
 
14
  **kwargs
15
  ):
16
  super().__init__()
17
- self.ode_up_proj = nn.Parameter(torch.empty(dim//2, factor*dim).to(torch.float32))
18
- self.ode_down_proj = nn.Parameter(torch.empty(factor*dim, dim//2).to(torch.float32))
19
  self.dim = dim
20
- self.act = torch.nn.SiLU()
 
 
 
 
 
 
21
  self.reset_parameters()
22
 
23
  def reset_parameters(self):
@@ -36,15 +47,20 @@ class ODELinear(nn.Module):
36
  return delta_ntk_freq.to(device, dtype=dtype), ntk_inv_freq.to(device, dtype=dtype)
37
 
38
  def forward(self, t, x: torch.Tensor):
39
- delta_time, time = self.get_time_embedding(t, device=x.device, dtype=x.dtype)
 
 
40
  x = x + torch.log(time)
41
  time_embed = delta_time / time
42
- delta_inv_freq = self.act(x @ self.ode_up_proj.float()) @ self.ode_down_proj.float() + time_embed
 
43
  return delta_inv_freq
44
 
45
 
46
 
47
- class LlamaCLEXScalingRotaryEmbedding(nn.Module):
 
 
48
 
49
  def __init__(self, dim, max_position_embeddings=2048, rope_scaling=None, base=10000, device=None) -> None:
50
  super().__init__()
@@ -56,22 +72,21 @@ class LlamaCLEXScalingRotaryEmbedding(nn.Module):
56
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
57
  self.register_buffer("inv_freq", inv_freq)
58
 
59
- self.proj_func = ODELinear(dim, rope_scaling["param_factor"])
60
  self.rope_cached = None
61
  self.max_t_cached = 0
62
  self.freq_cached = None
63
- self.time_dt = 0.01
64
  self.ode_args = {
65
  "method": "rk4",
66
  "options": {"step_size": self.time_dt},
67
  }
68
 
69
  def sample_random_times(self, max_t, device):
70
- return torch.randint(2, max_t, (1,), dtype = torch.long, device=device)
71
 
72
  def get_random_position_ids(self, n=2048, max=8192):
73
  positions = torch.randperm(max)[:n].sort().values
74
- # positions = positions.to(device=device)
75
  return positions
76
 
77
 
@@ -80,24 +95,24 @@ class LlamaCLEXScalingRotaryEmbedding(nn.Module):
80
  self.proj_func, torch.log(self.inv_freq.to(device, dtype=torch.float32)), time_grid, **self.ode_args
81
  )
82
  if time_grid.size(0) == 2:
83
- training
84
  scale_inv_freq = torch.exp(solution[1])
85
- # print(time_grid[1].tolist(), torch.sum(scale_inv_freq).tolist(), torch.sum(self.proj_func.ode_down_proj).tolist())
86
  freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
87
  else:
88
  scale_inv_freq = torch.exp(solution)
89
- freqs = torch.einsum('i, kl -> kil', ex_positions, scale_inv_freq)
90
  embed = torch.cat((freqs,freqs), dim=-1)
91
  return embed
92
 
93
 
94
 
95
- def forward(self, device, dtype, seq_len, do_train=False):
96
  device = self.proj_func.ode_up_proj.device
 
97
  scale_factor = seq_len // self.max_position_embeddings
98
  if do_train:
99
  t_val = self.sample_random_times(self.max_t+1, device)[0]
100
- import math
 
101
  sampled_position_ids = self.get_random_position_ids(n=seq_len-2, max=seq_len*t_val-2).float()
102
  ex_positions = torch.cat([
103
  torch.tensor([0]),
@@ -115,24 +130,25 @@ class LlamaCLEXScalingRotaryEmbedding(nn.Module):
115
  scale_inv_freq = self.inv_freq.to(device)
116
  freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
117
  embed = torch.cat((freqs,freqs), dim=-1)
118
- cos, sin = embed.cos()[None, None, :, :], embed.sin()[None, None, :, :]
119
  elif do_train:
120
  time_grid = torch.tensor([1.0, t_val]).float().to(device)
121
  embed = self.get_continuous_freq(time_grid, ex_positions, device)
122
- cos, sin = embed.cos()[None, None, :, :], embed.sin()[None, None, :, :]
123
  else:
124
- if t_val > self.max_t_cached:
125
- time_grid = torch.arange(1.0, self.max_t + 1.0, dtype=torch.float32).to(device)
126
- if self.freq_cached is None:
127
- self.freq_cached = self.get_continuous_freq(time_grid, ex_positions, device)
128
- embed = self.freq_cached[int(t_val)-1.0]
129
- self.rope_cached = torch.cat((embed.cos()[None, None, None, :, :], embed.sin()[None, None, None, :, :]), dim=0)
 
 
130
  self.max_t_cached = t_val
131
  cos, sin = self.rope_cached
132
-
133
  return torch.cat(
134
- (cos[None, :, :, :seq_len, ...].to(dtype=dtype),
135
- sin[None, :, :, :seq_len, ...].to(dtype=dtype)),
136
  dim=0
137
  )
138
 
 
1
  import torch
2
+ from torch import nn
3
  from torchdiffeq import odeint
4
 
5
+ import wandb
6
 
7
  import math
8
 
9
+
10
+
11
+
12
  class ODELinear(nn.Module):
13
  def __init__(
14
  self,
15
  dim: int,
16
  factor,
17
+ act,
18
+ base=10000,
19
  **kwargs
20
  ):
21
  super().__init__()
22
+ self.ode_up_proj = nn.Parameter(torch.empty(dim//2, factor*dim))
23
+ self.ode_down_proj = nn.Parameter(torch.empty(factor*dim, dim//2))
24
  self.dim = dim
25
+ self.base = base
26
+ if act == "tanh":
27
+ self.act = torch.nn.Tanh()
28
+ elif act == "silu":
29
+ self.act = torch.nn.SiLU()
30
+ else:
31
+ raise ValueError(f"act must be one of ['tanh', 'silu'], got {act}")
32
  self.reset_parameters()
33
 
34
  def reset_parameters(self):
 
47
  return delta_ntk_freq.to(device, dtype=dtype), ntk_inv_freq.to(device, dtype=dtype)
48
 
49
  def forward(self, t, x: torch.Tensor):
50
+
51
+ device = x.device
52
+ delta_time, time = self.get_time_embedding(t.to(device), device=device, dtype=x.dtype)
53
  x = x + torch.log(time)
54
  time_embed = delta_time / time
55
+ delta_inv_freq = self.act(x @ self.ode_up_proj.float()) @ self.ode_down_proj.float()
56
+ delta_inv_freq = delta_inv_freq + time_embed
57
  return delta_inv_freq
58
 
59
 
60
 
61
+
62
+
63
+ class CLEXScalingRotaryEmbedding(nn.Module):
64
 
65
  def __init__(self, dim, max_position_embeddings=2048, rope_scaling=None, base=10000, device=None) -> None:
66
  super().__init__()
 
72
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
73
  self.register_buffer("inv_freq", inv_freq)
74
 
75
+ self.proj_func = ODELinear(dim, rope_scaling["param_factor"], rope_scaling["act"], base)
76
  self.rope_cached = None
77
  self.max_t_cached = 0
78
  self.freq_cached = None
79
+ self.time_dt = rope_scaling["time_dt"]
80
  self.ode_args = {
81
  "method": "rk4",
82
  "options": {"step_size": self.time_dt},
83
  }
84
 
85
  def sample_random_times(self, max_t, device):
86
+ return torch.randint(1, max_t, (1,), dtype = torch.long, device=device)
87
 
88
  def get_random_position_ids(self, n=2048, max=8192):
89
  positions = torch.randperm(max)[:n].sort().values
 
90
  return positions
91
 
92
 
 
95
  self.proj_func, torch.log(self.inv_freq.to(device, dtype=torch.float32)), time_grid, **self.ode_args
96
  )
97
  if time_grid.size(0) == 2:
 
98
  scale_inv_freq = torch.exp(solution[1])
 
99
  freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
100
  else:
101
  scale_inv_freq = torch.exp(solution)
102
+ return scale_inv_freq
103
  embed = torch.cat((freqs,freqs), dim=-1)
104
  return embed
105
 
106
 
107
 
108
+ def forward(self, input_embeds, seq_len, do_train=False):
109
  device = self.proj_func.ode_up_proj.device
110
+ dtype = input_embeds.dtype
111
  scale_factor = seq_len // self.max_position_embeddings
112
  if do_train:
113
  t_val = self.sample_random_times(self.max_t+1, device)[0]
114
+ if scale_factor < 1.0:
115
+ scale_factor = 1
116
  sampled_position_ids = self.get_random_position_ids(n=seq_len-2, max=seq_len*t_val-2).float()
117
  ex_positions = torch.cat([
118
  torch.tensor([0]),
 
130
  scale_inv_freq = self.inv_freq.to(device)
131
  freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
132
  embed = torch.cat((freqs,freqs), dim=-1)
133
+ cos, sin = embed.cos(), embed.sin()
134
  elif do_train:
135
  time_grid = torch.tensor([1.0, t_val]).float().to(device)
136
  embed = self.get_continuous_freq(time_grid, ex_positions, device)
137
+ cos, sin = embed.cos(), embed.sin()
138
  else:
139
+ if self.freq_cached is None:
140
+ time_grid = torch.arange(1.0, self.max_t+1.0, dtype=torch.float32).to(device)
141
+ self.freq_cached = self.get_continuous_freq(time_grid, ex_positions, device)
142
+ if t_val != self.max_t_cached:
143
+ scale_inv_freq = self.freq_cached[int(t_val-1.0)]
144
+ freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
145
+ embed = torch.cat((freqs,freqs), dim=-1)
146
+ self.rope_cached = torch.cat((embed.cos()[None, :, :], embed.sin()[None, :, :]), dim=0)
147
  self.max_t_cached = t_val
148
  cos, sin = self.rope_cached
 
149
  return torch.cat(
150
+ (cos[None, :seq_len].to(dtype=dtype),
151
+ sin[None, :seq_len].to(dtype=dtype)),
152
  dim=0
153
  )
154