NoteDance commited on
Commit
eea7db7
1 Parent(s): 94b3ab9

Upload Whisper.py

Browse files
Files changed (1) hide show
  1. Whisper.py +263 -0
Whisper.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.layers import Dense,Conv1d,ZeroPadding1D,LayerNormalization
3
+ from tensorflow.keras import Model
4
+ import base64
5
+ import gzip
6
+ import numpy as np
7
+ from typing import Union
8
+
9
+
10
+ class ModelDimensions:
11
+ n_mels: int
12
+ n_audio_ctx: int
13
+ n_audio_state: int
14
+ n_audio_head: int
15
+ n_audio_layer: int
16
+ n_vocab: int
17
+ n_text_ctx: int
18
+ n_text_state: int
19
+ n_text_head: int
20
+ n_text_layer: int
21
+
22
+
23
+ def sinusoids(length, channels, max_timescale=10000):
24
+ """Returns sinusoids for positional embedding"""
25
+ assert channels % 2 == 0
26
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
27
+ inv_timescales = tf.math.exp(-log_timescale_increment * np.arange(channels // 2))
28
+ scaled_time = np.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
29
+ return tf.concat([tf.math.sin(scaled_time), tf.math.cos(scaled_time)], axis=1)
30
+
31
+
32
+ class LayerNorm:
33
+ def __init__(self, n_state):
34
+ self.layer_norm = LayerNormalization
35
+
36
+ def __call__(self, x):
37
+ return tf.cast(self.layer_norm(tf.cast(x, 'float32')), x.dtype)
38
+
39
+
40
+ class MultiHeadAttention:
41
+ def __init__(self, n_state: int, n_head: int):
42
+ self.n_head = n_head
43
+ self.query = Dense(n_state)
44
+ self.key = Dense(n_state, use_bias=False)
45
+ self.value = Dense(n_state)
46
+ self.out = Dense(n_state)
47
+
48
+ def __call__(
49
+ self,
50
+ x,
51
+ xa=None,
52
+ mask=None,
53
+ kv_cache=None,
54
+ ):
55
+ q = self.query(x)
56
+
57
+ if xa is None:
58
+ k = self.key(x)
59
+ v = self.value(x)
60
+ if kv_cache is not None:
61
+ k = tf.concat([kv_cache[0], k], axis=1)
62
+ v = tf.concat([kv_cache[1], v], axis=1)
63
+ elif kv_cache is None:
64
+ k = self.key(xa)
65
+ v = self.value(xa)
66
+ else:
67
+ k, v = kv_cache
68
+
69
+ wv, qk = self.qkv_attention(q, k, v, mask)
70
+ return self.out(wv), (k, v), qk
71
+
72
+ def qkv_attention(self, q, k, v, mask=None):
73
+ n_batch, n_ctx, n_state = q.shape
74
+ scale = (n_state // self.n_head) ** -0.25
75
+ q = q.reshape(*q.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3) * scale
76
+ k = k.reshape(*k.shape[:2], self.n_head, -1).transpose(0, 2, 3, 1) * scale
77
+ v = v.reshape(*v.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3)
78
+
79
+ qk = tf.matmul(q, k)
80
+ if mask is not None:
81
+ qk = qk + mask[:n_ctx, :n_ctx]
82
+ qk = tf.cast(qk, tf.float32)
83
+
84
+ w = tf.cast(tf.nn.softmax(qk, axis=-1), q.dtype)
85
+ out = tf.transpose(tf.matmul(w, v), (0, 2, 1, 3))
86
+ out = tf.reshape(out, (n_batch, n_ctx, n_state))
87
+ return out, qk
88
+
89
+
90
+ class ResidualAttentionBlock:
91
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
92
+ self.attn = MultiHeadAttention(n_state, n_head)
93
+ self.attn_ln = LayerNorm(n_state)
94
+
95
+ self.cross_attn = (
96
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
97
+ )
98
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
99
+
100
+ n_mlp = n_state * 4
101
+ self.mlp1 = Dense(n_mlp)
102
+ self.mlp2 = Dense(n_state)
103
+ self.mlp_ln = LayerNorm(n_state)
104
+
105
+ def __call__(self, x, xa=None, mask=None, kv_cache=None):
106
+ kv, cross_kv = kv_cache if kv_cache else (None, None)
107
+ y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
108
+ x += y
109
+ cross_qk = None
110
+ if self.cross_attn:
111
+ y, cross_kv, cross_qk = self.cross_attn(
112
+ self.cross_attn_ln(x), xa, kv_cache=cross_kv
113
+ )
114
+ x += y
115
+ x = x + tf.cast(self.mlp2(tf.nn.gelu(self.mlp1(self.mlp_ln(x))), x.dtype))
116
+ return x, (kv, cross_kv), cross_qk
117
+
118
+
119
+ class AudioEncoder:
120
+ def __init__(
121
+ self,
122
+ n_mels: int,
123
+ n_ctx: int,
124
+ n_state: int,
125
+ n_head: int,
126
+ n_layer: int,
127
+ dtype = tf.float16,
128
+ ):
129
+ self.zeropadding1d1 = ZeroPadding1D(padding=1)
130
+ self.conv1 = Conv1d(filters=n_state, kernel_size=3)
131
+ self.zeropadding1d2 = ZeroPadding1D(padding=1)
132
+ self.conv2 = Conv1d(filters=n_state, kernel_size=3, strides=2)
133
+ self._positional_embedding = tf.cast(sinusoids(n_ctx, n_state), dtype)
134
+
135
+ self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
136
+ self.ln_post = LayerNorm(n_state)
137
+
138
+ def __call__(self, x):
139
+ x = self.zeropadding1d1(x)
140
+ x = tf.cast(tf.nn.gelu(self.conv1(x)), x.dtype)
141
+ x = self.zeropadding1d2(x)
142
+ x = tf.cast(tf.nn.gelu(self.conv2(x)), x.dtype)
143
+ assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
144
+ x = x + self._positional_embedding
145
+
146
+ for block in self.blocks:
147
+ x, _, _ = block(x)
148
+
149
+ x = self.ln_post(x)
150
+ return x
151
+
152
+
153
+ class TextDecoder:
154
+ def __init__(
155
+ self,
156
+ n_vocab: int,
157
+ n_ctx: int,
158
+ n_state: int,
159
+ n_head: int,
160
+ n_layer: int,
161
+ dtype = tf.float16,
162
+ ):
163
+ self.token_embedding = tf.Variable(tf.random.normal([n_vocab, n_state]))
164
+ self.positional_embedding = tf.Variable(tf.zeros([n_ctx, n_state]))
165
+
166
+ self.blocks = [
167
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
168
+ for _ in range(n_layer)
169
+ ]
170
+ self.ln = LayerNorm(n_state)
171
+ self._mask = tf.fill((n_ctx, n_ctx), float("-inf"))
172
+ self._mask = tf.linalg.band_part(self._mask, 0, -1)
173
+ self._mask = tf.linalg.set_diag(self._mask, tf.zeros(n_ctx))
174
+ self._mask = tf.cast(self._mask, dtype)
175
+
176
+ def __call__(self, x, xa, kv_cache=None):
177
+ """
178
+ x : shape = (batch_size, <= n_ctx)
179
+ the text tokens
180
+ xa : shape = (batch_size, n_audio_ctx, n_audio_state)
181
+ the encoded audio features to be attended on
182
+ """
183
+ offset = kv_cache[0][0][0].shape[1] if kv_cache else 0
184
+ x = (
185
+ tf.gather(self.token_embedding, x)
186
+ + self.positional_embedding[offset : offset + x.shape[-1]]
187
+ )
188
+
189
+ if kv_cache is None:
190
+ kv_cache = [None] * len(self.blocks)
191
+ cross_qk = [None] * len(self.blocks)
192
+ for e, block in enumerate(self.blocks):
193
+ x, kv_cache[e], cross_qk[e] = block(
194
+ x, xa, mask=self._mask, kv_cache=kv_cache[e]
195
+ )
196
+
197
+ x = self.ln(x)
198
+ return tf.matmul(x, tf.transpose(self.token_embedding)), kv_cache, cross_qk
199
+
200
+
201
+ class Whisper(Model):
202
+ def __init__(self, dims: ModelDimensions, dtype = tf.float16):
203
+ super(Whisper, self).__init__()
204
+ self.dims = dims
205
+ self.encoder = AudioEncoder(
206
+ self.dims.n_mels,
207
+ self.dims.n_audio_ctx,
208
+ self.dims.n_audio_state,
209
+ self.dims.n_audio_head,
210
+ self.dims.n_audio_layer,
211
+ dtype,
212
+ )
213
+ self.decoder = TextDecoder(
214
+ self.dims.n_vocab,
215
+ self.dims.n_text_ctx,
216
+ self.dims.n_text_state,
217
+ self.dims.n_text_head,
218
+ self.dims.n_text_layer,
219
+ dtype,
220
+ )
221
+ # use the last half among the decoder layers for time alignment by default;
222
+ # to use a specific set of heads, see `set_alignment_heads()` below.
223
+ all_heads = np.zeros(
224
+ (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
225
+ )
226
+ all_heads[self.dims.n_text_layer // 2 :] = True
227
+ self.alignment_heads = tf.transpose(tf.cast(tf.where(all_heads != 0), dtype=tf.int32))
228
+
229
+ def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
230
+ if isinstance(dump, np.ndarray):
231
+ self.alignment_heads = tf.convert_to_tensor(dump)
232
+ elif isinstance(dump, bytes):
233
+ array = np.frombuffer(
234
+ gzip.decompress(base64.b85decode(dump)), dtype=bool
235
+ ).copy()
236
+ mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
237
+ self.alignment_heads = tf.transpose(tf.cast(tf.where(mask != 0), dtype=tf.int32))
238
+ else:
239
+ raise ValueError(
240
+ f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
241
+ " alignment_head information"
242
+ )
243
+
244
+ def embed_audio(self, mel):
245
+ return self.encoder(mel)
246
+
247
+ def logits(self, tokens, audio_features):
248
+ return self.decoder(tokens, audio_features)[0]
249
+
250
+ def forward_with_cross_qk(self, mel, tokens):
251
+ logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
252
+ return logits, cross_qk
253
+
254
+ def __call__(self, mel, tokens):
255
+ return self.decoder(tokens, self.encoder(mel))[0]
256
+
257
+ @property
258
+ def is_multilingual(self):
259
+ return self.dims.n_vocab >= 51865
260
+
261
+ @property
262
+ def num_languages(self):
263
+ return self.dims.n_vocab - 51765 - int(self.is_multilingual)