Commit 455129a
Parent(s): 76e8ee6

Edited comments

Files changed:
- attention.py +2 -10
- phi2_model.py +3 -3
attention.py CHANGED

@@ -19,9 +19,7 @@ except ImportError:


 class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding (RoPE)
-    See https://www.youtube.com/watch?v=C6rV8BsrrCc
-    """
+    """Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc"""

     def __init__(
         self,
@@ -129,8 +127,6 @@ class RotaryEmbedding(nn.Module):


 class SelfAttention(nn.Module):
-    """Self-attention layer, taken from Phi2 model."""
-
     def __init__(
         self,
         qk_scale: float | None = None,  # will use 1/sqrt(d) if set to None
@@ -174,8 +170,6 @@ class SelfAttention(nn.Module):


 class CrossAttention(nn.Module):
-    """Cross-attention layer, taken from Phi2 model."""
-
     def __init__(
         self,
         qk_scale: float | None = None,  # will use 1/sqrt(d) if set to None
@@ -225,8 +219,6 @@ class CrossAttention(nn.Module):


 class MLP(nn.Module):
-    """Taken from Phi2 as well."""
-
     def __init__(
         self,
         d_embedding: int,
@@ -489,7 +481,7 @@ class MHA(nn.Module):


 class ParallelAttentionBlock(nn.Module):
-    """
+    """Calculates attention and MLP in parallel."""

     def __init__(
         self,
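Context for the RotaryEmbedding docstring in the first hunk above: rotary embeddings encode position by rotating each (query, key) channel pair through an angle proportional to the token index, so the q·k dot product depends only on the relative offset between tokens. A minimal sketch of the idea, not this repo's RotaryEmbedding; the split-halves rotation and all names below are illustrative assumptions:

import torch

def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply a rotary position embedding to x of shape (batch, seq_len, n_heads, head_dim)."""
    _, seq_len, _, head_dim = x.shape
    half = head_dim // 2
    # Per-pair inverse frequencies, as in the original RoPE formulation.
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.einsum("s,d->sd", positions, inv_freq)   # (seq_len, half)
    cos = angles.cos()[None, :, None, :]                     # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Standard 2D rotation applied pair-wise: the angle grows with position.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(2, 16, 8, 64)   # (batch, seq_len, n_heads, head_dim)
q_rot = rope_rotate(q)          # same shape; position is now encoded in the rotation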
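On the new ParallelAttentionBlock docstring in the last hunk of attention.py ("Calculates attention and MLP in parallel"): in Phi-2-style blocks this usually means a single pre-norm feeds both the attention branch and the MLP branch, and their outputs are added to the residual instead of being chained. A hedged sketch of that layout; the module names and the use of nn.MultiheadAttention are illustrative, not this repo's implementation:

import torch
import torch.nn as nn

class ParallelBlockSketch(nn.Module):
    """Illustrative parallel block: x + attn(ln(x)) + mlp(ln(x))."""

    def __init__(self, d_model: int, n_heads: int, d_hidden: int) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln(x)                                        # one LayerNorm shared by both branches
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        # Both branches see the same normalized input; their outputs are summed, not chained.
        return x + attn_out + self.mlp(h)

GPT-J and PaLM use the same parallel layout; it saves one LayerNorm per block and lets the two branches be computed concurrently.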
phi2_model.py CHANGED

@@ -37,7 +37,7 @@ class Phi2PreTrainedModel(PreTrainedModel):
         input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
-        **kwargs,
+        **kwargs,  # has to be here
     ) -> dict[str, Any]:
         if not kv_cache:
             kv_cache = KVCache(
@@ -61,7 +61,7 @@ class Phi2PreTrainedModel(PreTrainedModel):


 class Embedding(nn.Module):
-    """Token embedding with dropout
+    """Token embedding with dropout."""

     def __init__(
         self,
@@ -150,7 +150,7 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
         labels: torch.LongTensor | None = None,
-        **kwargs,
+        **kwargs,  # has to be here
     ) -> CausalLMOutputWithPast:
         x = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
         x = self.lm_head_layer_norm(x)
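The Embedding docstring fixed above ("Token embedding with dropout.") describes a common pattern: look up the token vectors, then apply dropout before the first transformer block. A minimal sketch under that assumption (the class and argument names are illustrative, not necessarily this repo's):

import torch
import torch.nn as nn

class TokenEmbeddingSketch(nn.Module):
    """Token embedding followed by dropout on the embedded vectors."""

    def __init__(self, vocab_size: int, d_embedding: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_embedding)
        self.drop = nn.Dropout(dropout)

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # (batch_size, seq_len) -> (batch_size, seq_len, d_embedding)
        return self.drop(self.wte(input_ids))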
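About the "**kwargs,  # has to be here" comments added to both forward signatures in phi2_model.py: a likely reason is call-site compatibility, since wrappers such as Hugging Face's generation utilities pass keyword arguments the model does not use, and a forward() without a catch-all would raise TypeError. A small illustration of that failure mode (plain functions, not the repo's code):

def forward_strict(input_ids, kv_cache=None):
    return input_ids

def forward_lenient(input_ids, kv_cache=None, **kwargs):
    # Unused extras (e.g. use_cache, return_dict) are silently swallowed.
    return input_ids

call_kwargs = {"kv_cache": None, "use_cache": True}   # one extra key the model ignores
forward_lenient([1, 2, 3], **call_kwargs)             # fine
try:
    forward_strict([1, 2, 3], **call_kwargs)          # TypeError: unexpected keyword argument
except TypeError as err:
    print(err)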