BucketOfFish committed
Commit 455129a
1 Parent(s): 76e8ee6

Edited comments

Files changed (2)
  1. attention.py +2 -10
  2. phi2_model.py +3 -3
attention.py CHANGED
@@ -19,9 +19,7 @@ except ImportError:
 
 
 class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding (RoPE) from Phi2.
-    See https://www.youtube.com/watch?v=C6rV8BsrrCc
-    """
+    """Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc"""
 
     def __init__(
         self,
@@ -129,8 +127,6 @@ class RotaryEmbedding(nn.Module):
 
 
 class SelfAttention(nn.Module):
-    """Self-attention layer, taken from Phi2 model."""
-
     def __init__(
         self,
         qk_scale: float | None = None, # will use 1/sqrt(d) if set to None
@@ -174,8 +170,6 @@ class SelfAttention(nn.Module):
 
 
 class CrossAttention(nn.Module):
-    """Cross-attention layer, taken from Phi2 model."""
-
     def __init__(
         self,
         qk_scale: float | None = None, # will use 1/sqrt(d) if set to None
@@ -225,8 +219,6 @@ class CrossAttention(nn.Module):
 
 
 class MLP(nn.Module):
-    """Taken from Phi2 as well."""
-
     def __init__(
         self,
         d_embedding: int,
@@ -489,7 +481,7 @@ class MHA(nn.Module):
 
 
 class ParallelAttentionBlock(nn.Module):
-    """From Phi2. Calculates attention and MLP in parallel. See 'Simplifying Transformer Blocks', Fig. 1 'Parallel'."""
+    """Calculates attention and MLP in parallel."""
 
     def __init__(
         self,
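For orientation, since the trimmed docstring now only links a video: RoPE rotates each pair of query/key channels by an angle that grows linearly with the token position, so relative position shows up as a phase difference in the attention dot product. The sketch below is illustrative only and is not code from attention.py; the function name, the (batch, seq_len, n_heads, d_head) layout, and the half-split channel pairing are assumptions.

import torch


def apply_rope_sketch(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of x by position-dependent angles. x: (batch, seq_len, n_heads, d_head)."""
    _, seq_len, _, d_head = x.shape
    half = d_head // 2  # assumes d_head is even
    # One frequency per channel pair: base ** (-i / half) for i = 0 .. half - 1.
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) / half)
    # One rotation angle per (position, frequency) combination.
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq  # (seq_len, half)
    cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Standard 2D rotation applied to each (x1_i, x2_i) channel pair.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

Applied to queries and keys of shape (1, 16, 8, 64), for example, the shapes are unchanged; position is encoded purely through the rotation.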
 
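The shortened ParallelAttentionBlock docstring also drops the 'Simplifying Transformer Blocks' reference, so a reminder of the layout: attention and the MLP both read the same normalized input, and their outputs are added to the residual stream in one step rather than sequentially. Below is a minimal sketch under assumed names and sizes (nn.MultiheadAttention, a 4x MLP width, one shared pre-norm); it is not the repo's class.

import torch
import torch.nn as nn


class ParallelBlockSketch(nn.Module):
    """Illustrative parallel transformer block: attention and MLP share one residual add."""

    def __init__(self, d_embedding: int, n_heads: int):
        super().__init__()
        self.ln = nn.LayerNorm(d_embedding)
        self.attn = nn.MultiheadAttention(d_embedding, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_embedding, 4 * d_embedding),
            nn.GELU(),
            nn.Linear(4 * d_embedding, d_embedding),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln(x)  # one shared pre-norm over the block input
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        # Parallel formulation: both branch outputs join the residual together.
        return x + attn_out + self.mlp(h)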
phi2_model.py CHANGED
@@ -37,7 +37,7 @@ class Phi2PreTrainedModel(PreTrainedModel):
         input_ids: torch.LongTensor, # dim: (batch_size, seq_len)
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
-        **kwargs,
+        **kwargs, # has to be here
     ) -> dict[str, Any]:
         if not kv_cache:
             kv_cache = KVCache(
@@ -61,7 +61,7 @@ class Phi2PreTrainedModel(PreTrainedModel):
 
 
 class Embedding(nn.Module):
-    """Token embedding with dropout from Phi2."""
+    """Token embedding with dropout."""
 
     def __init__(
         self,
@@ -150,7 +150,7 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
         labels: torch.LongTensor | None = None,
-        **kwargs,
+        **kwargs, # has to be here
     ) -> CausalLMOutputWithPast:
         x = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
         x = self.lm_head_layer_norm(x)
 
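The new "has to be here" comments flag that **kwargs is load-bearing: during generation, Hugging Face's utilities forward extra keyword arguments (for example use_cache or attention_mask) into model methods like these even when the model ignores them, and a signature without a catch-all would raise a TypeError. A minimal, hypothetical illustration (function names made up):

def forward_strict(input_ids, kv_cache=None):
    return input_ids


def forward_permissive(input_ids, kv_cache=None, **kwargs):
    # Generation-time arguments this model does not use are silently swallowed here.
    return input_ids


forward_permissive(input_ids=[1, 2, 3], use_cache=True)   # works
# forward_strict(input_ids=[1, 2, 3], use_cache=True)     # TypeError: unexpected keyword argument 'use_cache'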