BucketOfFish committed on
Commit 0f3418e
1 Parent(s): 10aca20

Got model running, but results are incorrect

Files changed (4):
  1. attention.py +3 -6
  2. config.json +2 -2
  3. phi2_configuration.py +18 -18
  4. phi2_model.py +1 -1
attention.py CHANGED
@@ -28,7 +28,7 @@ class RotaryEmbedding(nn.Module):
         d_rotary: int,
         rotary_base: float = 10000.0,
         initial_cos_sin_cache_len: int = 2048,
-        device: torch.device | None = None,
+        device: torch.device = "cuda",
     ) -> None:
         super().__init__()
         self.d_rotary = d_rotary
@@ -52,7 +52,6 @@ class RotaryEmbedding(nn.Module):
             torch.arange(
                 start=0,
                 end=self.d_rotary,
-                step=2,
                 device=self.device,
                 dtype=self.dtype,
             ) / self.d_rotary
@@ -61,8 +60,8 @@
         # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
         # TODO: does this matter if I'm disabling torch.autocast?
         m_theta_i = torch.outer(m, theta_i)
-        self._cos_cached = torch.cos(m_theta_i).to(self.dtype)
-        self._sin_cached = torch.sin(m_theta_i).to(self.dtype)
+        self._cos_cached = torch.cos(m_theta_i).to(self.dtype).to(self.device)
+        self._sin_cached = torch.sin(m_theta_i).to(self.dtype).to(self.device)
 
         # TODO: scale_base caching is labelled as not yet done in Phi2
         """
@@ -108,8 +107,6 @@
         if (
             not self._max_seqlen
             or self._max_seqlen < x.shape[1] + seqlen_offset
-            or self._cos_cached.device != x.device
-            or self._cos_cached.dtype != x.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
             self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
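For comparison, below is a minimal sketch of the conventional rotary-embedding cos/sin cache (the function name and defaults are hypothetical, not from this repo). The usual RoPE formulation steps the frequency index by 2, producing d_rotary // 2 frequencies, so dropping step=2 as in the hunk above doubles the width of the cached tables and may be related to the incorrect results noted in the commit message.

import torch

def build_rotary_cache(d_rotary: int, seq_len: int, rotary_base: float = 10000.0,
                       device: str = "cpu", dtype: torch.dtype = torch.float32):
    # Conventional RoPE: one frequency per pair of channels, hence step=2.
    theta_i = 1.0 / (rotary_base ** (
        torch.arange(start=0, end=d_rotary, step=2, device=device, dtype=dtype) / d_rotary
    ))
    m = torch.arange(seq_len, device=device, dtype=dtype)  # token positions
    m_theta_i = torch.outer(m, theta_i)  # shape: (seq_len, d_rotary // 2)
    return torch.cos(m_theta_i).to(dtype), torch.sin(m_theta_i).to(dtype)

cos_cached, sin_cached = build_rotary_cache(d_rotary=32, seq_len=2048)  # example sizes only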
config.json CHANGED
@@ -17,8 +17,8 @@
     "vocab_chunk_for_gpu_efficiency": 64,
     "initial_cos_sin_cache_len": 2048,
     "d_embedding": 2560,
-    "n_blocks": 32,
-    "n_heads": 32,
+    "n_attn_blocks": 32,
+    "n_attn_heads": 32,
     "use_flash_attn": false,
     "use_flash_rotary": false,
     "use_fused_dense": false,
phi2_configuration.py CHANGED
@@ -8,27 +8,27 @@ class Phi2Config(PretrainedConfig):
         "max_position_embeddings": "initial_cos_sin_cache_len",
         "hidden_size": "d_embedding",
         "num_attention_heads": "n_attn_heads",
-        "num_hidden_layers": "n_blocks",
+        "num_hidden_layers": "n_attn_blocks",
     }
 
     def __init__(
         self,
-        vocab_size: int = 50295, # this includes the extra tokens included by Phi2 in tokenizer_config.json
-        vocab_chunk_for_gpu_efficiency: int = 64,
-        initial_cos_sin_cache_len: int = 2048,
-        d_embedding: int = 1024, # 2560?
-        n_blocks: int = 20, # 32?
-        n_attn_heads: int = 16, # 32?
-        use_flash_attn: bool = False,
-        use_flash_rotary: bool = False,
-        use_fused_dense: bool = False,
-        attn_pdrop: float = 0.0,
-        embd_pdrop: float = 0.0,
-        resid_pdrop: float = 0.0,
-        layer_norm_epsilon: float = 1e-5,
-        weight_initialization_range: float = 0.02,
-        tie_word_embeddings: bool = False, # whether embedding weights are shared between the encoder and decoder
-        checkpointing: bool = False, # whether to use gradient checkpointing to reduce memory usage (I think)
+        vocab_size: int, # this includes the extra tokens included by Phi2 in tokenizer_config.json
+        vocab_chunk_for_gpu_efficiency: int,
+        initial_cos_sin_cache_len: int,
+        d_embedding: int,
+        n_attn_blocks: int,
+        n_attn_heads: int,
+        use_flash_attn: bool,
+        use_flash_rotary: bool,
+        use_fused_dense: bool,
+        attn_pdrop: float,
+        embd_pdrop: float,
+        resid_pdrop: float,
+        layer_norm_epsilon: float,
+        weight_initialization_range: float,
+        tie_word_embeddings: bool, # whether embedding weights are shared between the encoder and decoder
+        checkpointing: bool, # whether to use gradient checkpointing to reduce memory usage (I think)
         **kwargs
     ) -> None:
         self.vocab_size = (
@@ -38,7 +38,7 @@ class Phi2Config(PretrainedConfig):
         )
         self.initial_cos_sin_cache_len = initial_cos_sin_cache_len
         self.d_embedding = d_embedding
-        self.n_blocks = n_blocks
+        self.n_attn_blocks = n_attn_blocks
         self.n_attn_heads = n_attn_heads
         self.use_flash_attn = use_flash_attn
         self.use_flash_rotary = use_flash_rotary
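Since the defaults were removed, every argument now has to be supplied explicitly (e.g. from config.json). A minimal sketch of constructing the config by hand, using only values that appear in this commit; the attribute_map means the standard Hugging Face names still resolve to the renamed fields:

from phi2_configuration import Phi2Config

config = Phi2Config(
    vocab_size=50295,
    vocab_chunk_for_gpu_efficiency=64,
    initial_cos_sin_cache_len=2048,
    d_embedding=2560,
    n_attn_blocks=32,
    n_attn_heads=32,
    use_flash_attn=False,
    use_flash_rotary=False,
    use_fused_dense=False,
    attn_pdrop=0.0,
    embd_pdrop=0.0,
    resid_pdrop=0.0,
    layer_norm_epsilon=1e-5,
    weight_initialization_range=0.02,
    tie_word_embeddings=False,
    checkpointing=False,
)
assert config.num_hidden_layers == config.n_attn_blocks == 32  # resolved via attribute_map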
phi2_model.py CHANGED
@@ -106,7 +106,7 @@ class Phi2Model(Phi2PreTrainedModel):
                 use_fused_dense=config.use_fused_dense,
                 checkpointing=config.checkpointing,
             )
-            for i in range(config.n_blocks)
+            for i in range(config.n_attn_blocks)
         ])
         self.gradient_checkpointing_disable() # https://github.com/cybertronai/gradient-checkpointing - I think this is turned off due to flash attention?
         self.post_init() # calls self._init_weights() for all modules
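The rename only touches the block-count lookup; the surrounding pattern is a ModuleList sized by config.n_attn_blocks. A stripped-down sketch of that pattern (class and layer names here are stand-ins, not the repo's actual Phi2 block):

import torch.nn as nn

class BlockStack(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        # one entry per transformer block; n_attn_blocks comes from config.json
        self.blocks = nn.ModuleList([
            nn.Linear(config.d_embedding, config.d_embedding)  # stand-in for a Phi2 decoder block
            for _ in range(config.n_attn_blocks)
        ])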