IEIT-Yuan committed
Commit: 08191f1
Parent: 99d0f96

Upload 2 files

Files changed (2):
  1. yuan_hf_model.py +59 -11
  2. yuan_hf_model_cpu.py +60 -12
yuan_hf_model.py CHANGED

@@ -25,7 +25,6 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.models.llama.modeling_llama import LlamaRMSNorm,LlamaRotaryEmbedding
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -58,9 +57,7 @@ class LocalizedFiltering(torch.nn.Module):
 
         self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
         self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
-
-        #Use the same RMSNorm as llama
-        self.output_layernorm = LlamaRMSNorm(self.embed_dim)
+        self.output_layernorm = YuanRMSNorm(self.embed_dim)
 
     def _train_forward(self, inputs):
         inputs = inputs.transpose(0,1)
@@ -197,7 +194,61 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+class YuanRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        YuanRMSNorm is equivalent to LlamaRMSNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+class YuanRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+
+        """
+        YuanRotaryEmbedding is equivalent to LlamaRotaryEmbedding in transformers v4.36
+        """
+
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
 
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
 
 class YuanMLP(nn.Module):
     def __init__(
@@ -240,8 +291,7 @@ class YuanAttention(nn.Module):
         )
         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        #Use the same RoataryEmbedding as llama
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+        self.rotary_emb = YuanRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
         if self.use_shareqk:
             self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
             self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
@@ -393,9 +443,8 @@ class YuanDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        #Use the same RMSNorm as llama
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -583,8 +632,7 @@ class YuanModel(YuanPreTrainedModel):
         self.reset_position_ids = config.reset_position_ids
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        #Use the same RMSNorm as llama
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
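Because yuan_hf_model.py now defines YuanRMSNorm and YuanRotaryEmbedding locally instead of importing the Llama classes, the two replacements can be sanity-checked in isolation. The sketch below is illustrative and not part of the commit; it assumes the updated yuan_hf_model.py is importable on its own (i.e. its remaining imports resolve outside the hub repo) and that YuanRMSNorm's weight is left at its default all-ones initialization.

import torch
# Assumes the updated yuan_hf_model.py is on the import path and imports cleanly standalone.
from yuan_hf_model import YuanRMSNorm, YuanRotaryEmbedding

hidden_size, seq_len, rope_dim = 64, 16, 32
x = torch.randn(2, seq_len, hidden_size)

# With its default all-ones weight, YuanRMSNorm should match a hand-rolled RMS norm.
norm = YuanRMSNorm(hidden_size, eps=1e-6)
reference = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), reference, atol=1e-5)

# YuanRotaryEmbedding returns cos/sin caches shaped [1, 1, seq_len, dim].
rope = YuanRotaryEmbedding(rope_dim, max_position_embeddings=2048)
cos, sin = rope(x, seq_len=seq_len)
assert cos.shape == (1, 1, seq_len, rope_dim)
assert sin.shape == (1, 1, seq_len, rope_dim)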
yuan_hf_model_cpu.py CHANGED

@@ -25,7 +25,6 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.models.llama.modeling_llama import LlamaRMSNorm,LlamaRotaryEmbedding
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -58,9 +57,7 @@ class LocalizedFiltering(torch.nn.Module):
 
         self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
         self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
-
-        #Use the same RMSNorm as llama
-        self.output_layernorm = LlamaRMSNorm(self.embed_dim)
+        self.output_layernorm = YuanRMSNorm(self.embed_dim)
 
     def _train_forward(self, inputs):
         inputs = inputs.transpose(0,1)
@@ -197,7 +194,61 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+class YuanRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        YuanRMSNorm is equivalent to LlamaRMSNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+class YuanRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+
+        """
+        YuanRotaryEmbedding is equivalent to LlamaRotaryEmbedding in transformers v4.36
+        """
+
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
 
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
 
 class YuanMLP(nn.Module):
     def __init__(
@@ -240,8 +291,7 @@ class YuanAttention(nn.Module):
         )
         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        #Use the same RoataryEmbedding as llama
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+        self.rotary_emb = YuanRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
         if self.use_shareqk:
             self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
             self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
@@ -268,7 +318,7 @@ class YuanAttention(nn.Module):
         is_first_step = False
         if use_cache:
             if past_key_value is None:
-                # inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
+                #inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
                 inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
                 is_first_step = True
             else:
@@ -393,9 +443,8 @@ class YuanDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        #Use the same RMSNorm as llama
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -583,8 +632,7 @@ class YuanModel(YuanPreTrainedModel):
         self.reset_position_ids = config.reset_position_ids
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-       #Use the same RMSNorm as llama
-       self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+       self.norm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
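With the transformers.models.llama import removed from both files, the remote code no longer reaches into Llama internals that have moved between transformers releases; the CPU variant, as its context lines show, also allocates the cached inference tensor on the default (CPU) device, with the CUDA allocation left commented out. A minimal loading sketch follows, assuming a checkpoint that ships these files as its remote code; the repo id is a placeholder, and the exact id plus tokenizer setup should be taken from the model card.

from transformers import AutoModelForCausalLM

# Placeholder repo id; substitute the actual Yuan checkpoint that contains these files.
model = AutoModelForCausalLM.from_pretrained(
    "IEIT-Yuan/Yuan2-2B-hf",
    trust_remote_code=True,   # executes yuan_hf_model.py shipped in the hub repo
    torch_dtype="auto",
)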