Upload Whisper.py
Browse files- Whisper.py +263 -0
Whisper.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
import gzip
from dataclasses import dataclass
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv1D, Dense, LayerNormalization, ZeroPadding1D
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
class ModelDimensions:
    """Hyperparameters describing a Whisper checkpoint.

    BUG FIX: the original class carried only annotations with no
    generated ``__init__``, so it could never be instantiated with
    values — yet ``Whisper.__init__`` reads ``dims.n_mels`` etc.
    ``@dataclass`` generates the constructor from the annotations.
    """

    n_mels: int           # number of mel-spectrogram bins in the encoder input
    n_audio_ctx: int      # encoder sequence length (after the stride-2 conv)
    n_audio_state: int    # encoder hidden width
    n_audio_head: int     # encoder attention heads
    n_audio_layer: int    # encoder transformer layers
    n_vocab: int          # decoder vocabulary size
    n_text_ctx: int       # decoder maximum sequence length
    n_text_state: int     # decoder hidden width
    n_text_head: int      # decoder attention heads
    n_text_layer: int     # decoder transformer layers
|
21 |
+
|
22 |
+
|
23 |
+
def sinusoids(length, channels, max_timescale=10000):
    """Return a (length, channels) tensor of sinusoidal positional embeddings.

    The first ``channels // 2`` columns are sines, the rest cosines, over a
    geometric range of timescales up to ``max_timescale``.
    """
    assert channels % 2 == 0
    half = channels // 2
    # Geometric progression of inverse timescales, as in "Attention Is All You Need".
    increment = np.log(max_timescale) / (half - 1)
    inv_timescales = tf.math.exp(-increment * np.arange(half))
    positions = np.arange(length)[:, np.newaxis]
    scaled_time = positions * inv_timescales[np.newaxis, :]
    return tf.concat([tf.math.sin(scaled_time), tf.math.cos(scaled_time)], axis=1)
|
30 |
+
|
31 |
+
|
32 |
+
class LayerNorm:
    """Layer normalization computed in float32 and cast back to the input
    dtype, for numerical stability under float16 activations."""

    def __init__(self, n_state):
        # BUG FIX: the original stored the LayerNormalization *class* itself
        # (`self.layer_norm = LayerNormalization`), so __call__ constructed a
        # new layer object from the tensor instead of normalizing it.
        # Instantiate the layer once here.  `n_state` is kept for interface
        # compatibility; Keras infers the normalized axis size from the input.
        self.layer_norm = LayerNormalization()

    def __call__(self, x):
        # Normalize in float32, then cast the result back to the caller's dtype.
        return tf.cast(self.layer_norm(tf.cast(x, 'float32')), x.dtype)
|
38 |
+
|
39 |
+
|
40 |
+
class MultiHeadAttention:
    """Multi-head attention with optional cross-attention and key/value caching.

    Used for both decoder self-attention (``xa is None``) and decoder→encoder
    cross-attention (``xa`` is the encoded audio).
    """

    def __init__(self, n_state: int, n_head: int):
        self.n_head = n_head
        self.query = Dense(n_state)
        # No bias on the key projection, matching the Whisper architecture.
        self.key = Dense(n_state, use_bias=False)
        self.value = Dense(n_state)
        self.out = Dense(n_state)

    def __call__(
        self,
        x,
        xa=None,
        mask=None,
        kv_cache=None,
    ):
        """Return (output, (k, v) cache, attention logits).

        x : query source; xa : optional cross-attention key/value source;
        mask : additive attention mask; kv_cache : (k, v) from prior steps.
        """
        q = self.query(x)

        if xa is None:
            # Self-attention: keys/values from x, extended with any cached
            # keys/values along the time axis for incremental decoding.
            k = self.key(x)
            v = self.value(x)
            if kv_cache is not None:
                k = tf.concat([kv_cache[0], k], axis=1)
                v = tf.concat([kv_cache[1], v], axis=1)
        elif kv_cache is None:
            # Cross-attention, first pass: project the audio features.
            k = self.key(xa)
            v = self.value(xa)
        else:
            # Cross-attention, later passes: the audio k/v never change,
            # so reuse the cache verbatim.
            k, v = kv_cache

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), (k, v), qk

    def qkv_attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over ``self.n_head`` heads."""
        n_batch, n_ctx, n_state = q.shape
        # Scale q and k each by dim**-0.25 (equivalent to scaling qk by dim**-0.5).
        scale = (n_state // self.n_head) ** -0.25
        # BUG FIX: TF tensors have no NumPy-style .reshape()/.transpose()
        # methods — the original `q.reshape(...).transpose(...)` raised
        # AttributeError.  Use tf.reshape / tf.transpose instead.
        q = tf.transpose(tf.reshape(q, (*q.shape[:2], self.n_head, -1)), (0, 2, 1, 3)) * scale
        # k is additionally transposed to (batch, head, head_dim, time) for matmul.
        k = tf.transpose(tf.reshape(k, (*k.shape[:2], self.n_head, -1)), (0, 2, 3, 1)) * scale
        v = tf.transpose(tf.reshape(v, (*v.shape[:2], self.n_head, -1)), (0, 2, 1, 3))

        qk = tf.matmul(q, k)
        if mask is not None:
            # Additive causal mask (-inf above the diagonal).
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = tf.cast(qk, tf.float32)

        # Softmax in float32, then back to the query dtype for the weighted sum.
        w = tf.cast(tf.nn.softmax(qk, axis=-1), q.dtype)
        out = tf.transpose(tf.matmul(w, v), (0, 2, 1, 3))
        out = tf.reshape(out, (n_batch, n_ctx, n_state))
        return out, qk
|
88 |
+
|
89 |
+
|
90 |
+
class ResidualAttentionBlock:
    """Pre-norm transformer block: self-attention, optional cross-attention,
    and a GELU MLP, each wrapped in a residual connection."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        # Cross-attention (decoder attending to encoder output) only when requested.
        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        # Feed-forward expands to 4x the model width, as in the original Transformer.
        n_mlp = n_state * 4
        self.mlp1 = Dense(n_mlp)
        self.mlp2 = Dense(n_state)
        self.mlp_ln = LayerNorm(n_state)

    def __call__(self, x, xa=None, mask=None, kv_cache=None):
        """Return (x, (self_kv, cross_kv), cross-attention logits or None)."""
        # Unpack the per-block cache: (self-attention kv, cross-attention kv).
        kv, cross_kv = kv_cache if kv_cache else (None, None)
        y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
        x += y
        cross_qk = None
        if self.cross_attn:
            y, cross_kv, cross_qk = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=cross_kv
            )
            x += y
        # BUG FIX: the original had a misplaced parenthesis —
        # `tf.cast(self.mlp2(..., x.dtype))` passed x.dtype as a second
        # argument to mlp2 and gave tf.cast only one argument.  Cast the MLP
        # output back to x's dtype instead.
        x = x + tf.cast(self.mlp2(tf.nn.gelu(self.mlp1(self.mlp_ln(x)))), x.dtype)
        return x, (kv, cross_kv), cross_qk
|
117 |
+
|
118 |
+
|
119 |
+
class AudioEncoder:
    """Encode a mel spectrogram into audio features.

    Conv stem (two 3-wide convolutions, the second with stride 2 halving the
    time axis) + sinusoidal positional embedding + ``n_layer`` transformer
    blocks + final layer norm.
    """

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype = tf.float16,
    ):
        # `n_mels` is kept for interface compatibility; Keras Conv1D infers
        # its input channel count from the first call.
        # ZeroPadding1D(1) reproduces torch's Conv1d(padding=1).
        self.zeropadding1d1 = ZeroPadding1D(padding=1)
        # BUG FIX: the layer is `Conv1D` (capital D) in tensorflow.keras.layers;
        # `Conv1d` does not exist and failed at import time.
        self.conv1 = Conv1D(filters=n_state, kernel_size=3)
        self.zeropadding1d2 = ZeroPadding1D(padding=1)
        self.conv2 = Conv1D(filters=n_state, kernel_size=3, strides=2)
        # Fixed (non-trainable) sinusoidal positional embedding.
        self._positional_embedding = tf.cast(sinusoids(n_ctx, n_state), dtype)

        self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.ln_post = LayerNorm(n_state)

    def __call__(self, x):
        """x: mel spectrogram batch; returns encoded audio features.

        NOTE(review): the cast back to `x.dtype` after each GELU assumes the
        input dtype matches the conv output dtype policy — confirm against
        the caller's mixed-precision setup.
        """
        x = self.zeropadding1d1(x)
        x = tf.cast(tf.nn.gelu(self.conv1(x)), x.dtype)
        x = self.zeropadding1d2(x)
        x = tf.cast(tf.nn.gelu(self.conv2(x)), x.dtype)
        assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
        x = x + self._positional_embedding

        for block in self.blocks:
            x, _, _ = block(x)

        x = self.ln_post(x)
        return x
|
151 |
+
|
152 |
+
|
153 |
+
class TextDecoder:
    """Autoregressive text decoder: token + learned positional embeddings,
    ``n_layer`` transformer blocks with cross-attention over the encoded
    audio, and tied input/output embeddings for the final logits."""

    def __init__(
        self,
        n_vocab: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype = tf.float16,
    ):
        # Token embedding doubles as the output projection (tied weights, see
        # the matmul in __call__).
        self.token_embedding = tf.Variable(tf.random.normal([n_vocab, n_state]))
        # Learned positional embedding (unlike the encoder's fixed sinusoids).
        self.positional_embedding = tf.Variable(tf.zeros([n_ctx, n_state]))

        self.blocks = [
            ResidualAttentionBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ]
        self.ln = LayerNorm(n_state)
        # Build the causal mask: start from all -inf, keep only the upper
        # triangle (band_part(·, 0, -1)), then zero the diagonal — giving
        # -inf strictly above the diagonal and 0 elsewhere, so each position
        # may attend to itself and the past only.
        self._mask = tf.fill((n_ctx, n_ctx), float("-inf"))
        self._mask = tf.linalg.band_part(self._mask, 0, -1)
        self._mask = tf.linalg.set_diag(self._mask, tf.zeros(n_ctx))
        self._mask = tf.cast(self._mask, dtype)

    def __call__(self, x, xa, kv_cache=None):
        """
        x : shape = (batch_size, <= n_ctx)
            the text tokens
        xa : shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        # kv_cache is a list of per-block ((k, v), cross_kv) entries; the
        # cached self-attention key length gives the position offset for the
        # new tokens during incremental decoding.
        offset = kv_cache[0][0][0].shape[1] if kv_cache else 0
        x = (
            tf.gather(self.token_embedding, x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )

        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)
        cross_qk = [None] * len(self.blocks)
        # Each block returns its updated cache entry and cross-attention
        # logits; both are collected per layer.
        for e, block in enumerate(self.blocks):
            x, kv_cache[e], cross_qk[e] = block(
                x, xa, mask=self._mask, kv_cache=kv_cache[e]
            )

        x = self.ln(x)
        # Output logits via the tied (transposed) token embedding.
        return tf.matmul(x, tf.transpose(self.token_embedding)), kv_cache, cross_qk
|
199 |
+
|
200 |
+
|
201 |
+
class Whisper(Model):
    """Full Whisper model: AudioEncoder + TextDecoder, with alignment-head
    bookkeeping for word-level timestamp extraction."""

    def __init__(self, dims: ModelDimensions, dtype = tf.float16):
        super(Whisper, self).__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
            dtype,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
            dtype,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = np.zeros(
            (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        # Stored as a transposed int32 tensor of (layer, head) index pairs.
        self.alignment_heads = tf.transpose(tf.cast(tf.where(all_heads != 0), dtype=tf.int32))

    def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
        """Override the default alignment heads.

        `dump` is either a ready-made index array, or gzip-compressed,
        base85-encoded bytes of a boolean (n_text_layer, n_text_head) mask.
        Raises ValueError for any other type.
        """
        if isinstance(dump, np.ndarray):
            # NOTE(review): this branch stores the array as-is, while the
            # bytes branch stores transposed indices — confirm callers pass
            # the already-converted form here.
            self.alignment_heads = tf.convert_to_tensor(dump)
        elif isinstance(dump, bytes):
            # .copy() because np.frombuffer returns a read-only view.
            array = np.frombuffer(
                gzip.decompress(base64.b85decode(dump)), dtype=bool
            ).copy()
            mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
            self.alignment_heads = tf.transpose(tf.cast(tf.where(mask != 0), dtype=tf.int32))
        else:
            raise ValueError(
                f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
                " alignment_head information"
            )

    def embed_audio(self, mel):
        # Encode a mel spectrogram into audio features.
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        # Decoder logits only; drops the kv_cache and cross_qk outputs.
        return self.decoder(tokens, audio_features)[0]

    def forward_with_cross_qk(self, mel, tokens):
        # Full pass returning both logits and per-layer cross-attention
        # logits (used for timestamp alignment).
        logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
        return logits, cross_qk

    def __call__(self, mel, tokens):
        # Standard forward pass: encode audio, decode tokens, return logits.
        return self.decoder(tokens, self.encoder(mel))[0]

    @property
    def is_multilingual(self):
        # Multilingual vocabularies are 51865 tokens or larger.
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        # 51765 non-language tokens; multilingual models add one extra token.
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|