NoteDance committed on
Commit
cfbb81e
1 Parent(s): e6de0bf

Upload Phi2.py

Files changed (1)
Phi2.py +196 -0
Phi2.py ADDED
@@ -0,0 +1,196 @@
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, Embedding
from tensorflow.keras import Model
import math
from dataclasses import dataclass


@dataclass
class ModelArgs:
    n_positions: int = 2048
    vocab_size: int = 51200
    n_embd: int = 2560
    n_head: int = 32
    n_layer: int = 32
    rotary_dim: int = 32


class RoPEAttention:
    def __init__(self, dims: int, n_head: int, rotary_dim: int):
        self.n_head = n_head

        self.q_proj = Dense(dims)
        self.k_proj = Dense(dims)
        self.v_proj = Dense(dims)
        self.dense = Dense(dims)

        self.rope = RoPE(rotary_dim, traditional=False)

    def __call__(self, x, mask=None, cache=None):
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Extract some shapes
        n_head = self.n_head
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = tf.transpose(tf.reshape(queries, (B, L, n_head, -1)), (0, 2, 1, 3))
        keys = tf.transpose(tf.reshape(keys, (B, L, n_head, -1)), (0, 2, 1, 3))
        values = tf.transpose(tf.reshape(values, (B, L, n_head, -1)), (0, 2, 1, 3))

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = tf.concat([key_cache, keys], axis=2)
            values = tf.concat([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        queries = tf.cast(queries, tf.float32)
        keys = tf.cast(keys, tf.float32)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = tf.matmul(queries * scale, tf.transpose(keys, (0, 1, 3, 2)))
        if mask is not None:
            scores = scores + mask

        scores = tf.cast(tf.nn.softmax(scores, axis=-1), values.dtype)
        values_hat = tf.reshape(tf.transpose(tf.matmul(scores, values), (0, 2, 1, 3)), (B, L, -1))

        return self.dense(values_hat), (keys, values)


class MLP:
    def __init__(self, dim, hidden_dim):
        self.fc1 = Dense(hidden_dim)
        self.fc2 = Dense(dim)

    def __call__(self, x):
        # tf.nn.gelu expects a boolean `approximate` flag; True selects the
        # tanh-based approximation, matching Phi-2's NewGELU activation.
        return self.fc2(tf.nn.gelu(self.fc1(x), approximate=True))


class ParallelBlock:
    def __init__(self, config: ModelArgs):
        dims = config.n_embd
        mlp_dims = dims * 4
        self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
        self.input_layernorm = LayerNormalization()
        self.mlp = MLP(dims, mlp_dims)

    def __call__(self, x, mask, cache):
        # Attention and MLP branches both read the same normalized input and
        # are summed with the residual (Phi-2's parallel block layout).
        h = self.input_layernorm(x)
        attn_h, cache = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x, cache


class Transformer:
    def __init__(self, config: ModelArgs):
        self.embed_tokens = Embedding(config.vocab_size, config.n_embd)
        self.layers = [ParallelBlock(config) for _ in range(config.n_layer)]
        self.final_layernorm = LayerNormalization()

    def __call__(self, x, mask, cache):
        x = self.embed_tokens(x)
        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            x, cache[e] = layer(x, mask, cache[e])
        return self.final_layernorm(x), cache


class Phi2(Model):
    def __init__(self, config: ModelArgs):
        super(Phi2, self).__init__()
        self.model = Transformer(config)
        self.lm_head = Dense(config.vocab_size)

    def __call__(
        self,
        x,
        mask=None,
        cache=None,
    ):
        mask = None
        if x.shape[1] > 1:
            # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
            mask = tf.fill((x.shape[1], x.shape[1]), float("-inf"))
            mask = tf.linalg.band_part(mask, 0, -1)
            mask = tf.linalg.set_diag(mask, tf.zeros(x.shape[1]))
            # Keep the mask in float32: the attention scores are computed in
            # float32, and x holds integer token ids, so casting to x.dtype
            # would destroy the -inf entries.
            mask = tf.cast(mask, tf.float32)

        y, cache = self.model(x, mask, cache)
        return self.lm_head(y), cache


class RoPE:
    def __init__(self, dims: int, traditional: bool = False, base=None):
        self.dims = dims
        self.traditional = traditional
        self.base = base

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            # Only the first `dims` features are rotated; the rest pass through.
            rx = tf.concat([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = tf.concat([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = tf.concat([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = tf.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return tf.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000.0,
        dtype=tf.float32,
    ):
        # `base` arrives as None when RoPE is constructed without one, so fall
        # back to the conventional 10000 and take the log in Python rather than
        # passing None (or a Python int) to tf.math.log.
        if base is None:
            base = 10000.0
        D = D // 2
        positions = tf.range(offset, N, dtype=dtype)
        freqs = tf.math.exp(
            -tf.range(0, D, dtype=dtype) * (math.log(base) / D)
        )
        theta = tf.reshape(positions, (-1, 1)) * tf.reshape(freqs, (1, -1))
        costheta = tf.math.cos(theta)
        sintheta = tf.math.sin(theta)

        return costheta, sintheta
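
For reference, a minimal eager-mode smoke test of the file might look like the sketch below. The shrunken ModelArgs values, the random prompt, the short greedy loop, and the `__main__` guard are illustrative assumptions, not part of the uploaded file; they only exercise the prompt pass and the per-layer key/value cache.

if __name__ == "__main__":
    # Tiny configuration so the sketch runs quickly on CPU; the real Phi-2
    # sizes are the defaults in ModelArgs above.
    args = ModelArgs(vocab_size=100, n_embd=64, n_head=4, n_layer=2, rotary_dim=16)
    model = Phi2(args)

    # Prompt pass: process the whole sequence and build the per-layer cache.
    prompt = tf.random.uniform((1, 8), maxval=args.vocab_size, dtype=tf.int32)
    logits, cache = model(prompt)
    next_token = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)

    # Incremental decoding: feed one token at a time and reuse the cache, so
    # each step only attends over the cached keys/values plus the new token.
    for _ in range(5):
        logits, cache = model(next_token[:, None], cache=cache)
        next_token = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)
        print(int(next_token[0]))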