Crystalcareai commited on
Commit
8f32857
·
verified ·
1 Parent(s): fe54712

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +6 -17
modeling_gemmoe.py CHANGED
@@ -85,14 +85,9 @@ class GemmoeRMSNorm(nn.Module):
85
  self.eps = eps
86
  self.weight = nn.Parameter(torch.zeros(dim))
87
 
88
- def _norm(self, x):
89
  x_float = x.float()
90
  normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
91
- return normed_x
92
-
93
- def forward(self, x):
94
- normed_x = self._norm(x)
95
- # Downcast the result to the original dtype at the end
96
  normed_x = normed_x.type_as(x)
97
  return normed_x * (self.weight + 1)
98
 
@@ -108,11 +103,10 @@ class GemmoeRotaryEmbedding(nn.Module):
108
 
109
  def _set_cos_sin_cache(self, seq_len, device, dtype):
110
  self.max_seq_len_cached = seq_len
111
- freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
112
  timescale = self.base ** freq_exponents
113
- positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
114
- radians_new = positions[..., None] / timescale[None, None, :]
115
- radians_new = radians_new.squeeze(0)
116
  emb = torch.cat((radians_new, radians_new), dim=-1)
117
  cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
118
  sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
@@ -120,20 +114,15 @@ class GemmoeRotaryEmbedding(nn.Module):
120
  self.register_buffer("sin_cached", sin, persistent=False)
121
 
122
  def forward(self, x, position_ids=None, seq_len=None):
123
- if seq_len is None:
124
- seq_len = x.size(2)
125
  if seq_len > self.max_seq_len_cached:
126
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
127
  return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
128
 
129
- # Copied from transformers.models.llama.modeling_llama.rotate_half
130
  def rotate_half(x):
131
- """Rotates half the hidden dims of the input."""
132
- x1 = x[..., : x.shape[-1] // 2]
133
- x2 = x[..., x.shape[-1] // 2 :]
134
  return torch.cat((-x2, x1), dim=-1)
135
 
136
-
137
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
138
  seq_len, dim = q.shape[-2], q.shape[-1]
139
  cos = cos[:seq_len].view(1, 1, seq_len, dim)
 
85
  self.eps = eps
86
  self.weight = nn.Parameter(torch.zeros(dim))
87
 
88
+ def forward(self, x):
89
  x_float = x.float()
90
  normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
 
 
 
 
 
91
  normed_x = normed_x.type_as(x)
92
  return normed_x * (self.weight + 1)
93
 
 
103
 
104
  def _set_cos_sin_cache(self, seq_len, device, dtype):
105
  self.max_seq_len_cached = seq_len
106
+ freq_exponents = (2.0 / self.dim) * torch.arange(self.dim // 2, dtype=torch.float32, device="cpu")
107
  timescale = self.base ** freq_exponents
108
+ positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32)
109
+ radians_new = positions.view(-1, 1) / timescale.view(1, -1)
 
110
  emb = torch.cat((radians_new, radians_new), dim=-1)
111
  cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
112
  sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
 
114
  self.register_buffer("sin_cached", sin, persistent=False)
115
 
116
  def forward(self, x, position_ids=None, seq_len=None):
117
+ seq_len = x.size(2) if seq_len is None else seq_len
 
118
  if seq_len > self.max_seq_len_cached:
119
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
120
  return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
121
 
 
122
  def rotate_half(x):
123
+ x1, x2 = x.chunk(2, dim=-1)
 
 
124
  return torch.cat((-x2, x1), dim=-1)
125
 
 
126
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
127
  seq_len, dim = q.shape[-2], q.shape[-1]
128
  cos = cos[:seq_len].view(1, 1, seq_len, dim)