Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files
- modeling_gemmoe.py +6 -17
modeling_gemmoe.py
CHANGED
@@ -85,14 +85,9 @@ class GemmoeRMSNorm(nn.Module):
|
|
85 |
self.eps = eps
|
86 |
self.weight = nn.Parameter(torch.zeros(dim))
|
87 |
|
88 |
-
def
|
89 |
x_float = x.float()
|
90 |
normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
|
91 |
-
return normed_x
|
92 |
-
|
93 |
-
def forward(self, x):
|
94 |
-
normed_x = self._norm(x)
|
95 |
-
# Downcast the result to the original dtype at the end
|
96 |
normed_x = normed_x.type_as(x)
|
97 |
return normed_x * (self.weight + 1)
|
98 |
|
@@ -108,11 +103,10 @@ class GemmoeRotaryEmbedding(nn.Module):
|
|
108 |
|
109 |
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
110 |
self.max_seq_len_cached = seq_len
|
111 |
-
freq_exponents = (2.0 / self.dim) *
|
112 |
timescale = self.base ** freq_exponents
|
113 |
-
positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32)
|
114 |
-
radians_new = positions
|
115 |
-
radians_new = radians_new.squeeze(0)
|
116 |
emb = torch.cat((radians_new, radians_new), dim=-1)
|
117 |
cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
|
118 |
sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
|
@@ -120,20 +114,15 @@ class GemmoeRotaryEmbedding(nn.Module):
|
|
120 |
self.register_buffer("sin_cached", sin, persistent=False)
|
121 |
|
122 |
def forward(self, x, position_ids=None, seq_len=None):
|
123 |
-
if seq_len is None
|
124 |
-
seq_len = x.size(2)
|
125 |
if seq_len > self.max_seq_len_cached:
|
126 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
127 |
return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
128 |
|
129 |
-
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
130 |
def rotate_half(x):
|
131 |
-
|
132 |
-
x1 = x[..., : x.shape[-1] // 2]
|
133 |
-
x2 = x[..., x.shape[-1] // 2 :]
|
134 |
return torch.cat((-x2, x1), dim=-1)
|
135 |
|
136 |
-
|
137 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
138 |
seq_len, dim = q.shape[-2], q.shape[-1]
|
139 |
cos = cos[:seq_len].view(1, 1, seq_len, dim)
|
|
|
85 |
self.eps = eps
|
86 |
self.weight = nn.Parameter(torch.zeros(dim))
|
87 |
|
88 |
+
def forward(self, x):
|
89 |
x_float = x.float()
|
90 |
normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
|
|
|
|
|
|
|
|
|
91 |
normed_x = normed_x.type_as(x)
|
92 |
return normed_x * (self.weight + 1)
|
93 |
|
|
|
103 |
|
104 |
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
105 |
self.max_seq_len_cached = seq_len
|
106 |
+
freq_exponents = (2.0 / self.dim) * torch.arange(self.dim // 2, dtype=torch.float32, device="cpu")
|
107 |
timescale = self.base ** freq_exponents
|
108 |
+
positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32)
|
109 |
+
radians_new = positions.view(-1, 1) / timescale.view(1, -1)
|
|
|
110 |
emb = torch.cat((radians_new, radians_new), dim=-1)
|
111 |
cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
|
112 |
sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
|
|
|
114 |
self.register_buffer("sin_cached", sin, persistent=False)
|
115 |
|
116 |
def forward(self, x, position_ids=None, seq_len=None):
|
117 |
+
seq_len = x.size(2) if seq_len is None else seq_len
|
|
|
118 |
if seq_len > self.max_seq_len_cached:
|
119 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
120 |
return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
121 |
|
|
|
122 |
def rotate_half(x):
|
123 |
+
x1, x2 = x.chunk(2, dim=-1)
|
|
|
|
|
124 |
return torch.cat((-x2, x1), dim=-1)
|
125 |
|
|
|
126 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
127 |
seq_len, dim = q.shape[-2], q.shape[-1]
|
128 |
cos = cos[:seq_len].view(1, 1, seq_len, dim)
|