import base64
import gzip
from dataclasses import dataclass
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv1D, Dense, LayerNormalization, ZeroPadding1D


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


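# A hedged example of concrete values: the numbers below are the published
# dimensions of the multilingual "tiny" Whisper checkpoint (quoted here for
# illustration; this file does not pin any particular checkpoint):
#
#     dims = ModelDimensions(
#         n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6,
#         n_audio_layer=4, n_vocab=51865, n_text_ctx=448, n_text_state=384,
#         n_text_head=6, n_text_layer=4,
#     )

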
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding."""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
    scaled_time = np.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return tf.concat([tf.math.sin(scaled_time), tf.math.cos(scaled_time)], axis=1)


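# Shape sketch: sinusoids(1500, 384) yields a (1500, 384) table whose first
# 192 columns are sines and last 192 are cosines; the encoder below casts it
# to the working dtype and adds it to the convolved features.

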
class LayerNorm:
    def __init__(self, n_state):
        # epsilon=1e-5 matches PyTorch's nn.LayerNorm default (Keras uses 1e-3),
        # which matters when loading weights converted from PyTorch checkpoints.
        self.layer_norm = LayerNormalization(axis=-1, epsilon=1e-5)

    def __call__(self, x):
        # Normalize in float32 for numerical stability, then restore the input dtype.
        return tf.cast(self.layer_norm(tf.cast(x, 'float32')), x.dtype)


class MultiHeadAttention:
    def __init__(self, n_state: int, n_head: int):
        self.n_head = n_head
        self.query = Dense(n_state)
        self.key = Dense(n_state, use_bias=False)
        self.value = Dense(n_state)
        self.out = Dense(n_state)

    def __call__(
        self,
        x,
        xa=None,
        mask=None,
        kv_cache=None,
    ):
        q = self.query(x)

        if xa is None:
            # Self-attention: keys/values come from x; with a cache, append the
            # new positions to the cached ones.
            k = self.key(x)
            v = self.value(x)
            if kv_cache is not None:
                k = tf.concat([kv_cache[0], k], axis=1)
                v = tf.concat([kv_cache[1], v], axis=1)
        elif kv_cache is None:
            # Cross-attention, first call: project the audio features.
            k = self.key(xa)
            v = self.value(xa)
        else:
            # Cross-attention with a cache: the audio keys/values never change,
            # so reuse them as-is.
            k, v = kv_cache

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), (k, v), qk

    def qkv_attention(self, q, k, v, mask=None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        # Split heads: (batch, ctx, state) -> (batch, head, ctx, head_dim);
        # k is transposed to (batch, head, head_dim, ctx) for the matmul.
        q = tf.transpose(tf.reshape(q, (*q.shape[:2], self.n_head, -1)), (0, 2, 1, 3)) * scale
        k = tf.transpose(tf.reshape(k, (*k.shape[:2], self.n_head, -1)), (0, 2, 3, 1)) * scale
        v = tf.transpose(tf.reshape(v, (*v.shape[:2], self.n_head, -1)), (0, 2, 1, 3))

        qk = tf.matmul(q, k)
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = tf.cast(qk, tf.float32)

        w = tf.cast(tf.nn.softmax(qk, axis=-1), q.dtype)
        out = tf.transpose(tf.matmul(w, v), (0, 2, 1, 3))
        out = tf.reshape(out, (n_batch, n_ctx, n_state))
        return out, qk


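# Shape sketch (illustrative numbers, assuming n_state=384 and n_head=6):
# q, k, v enter qkv_attention as (batch, ctx, 384), are split into
# (batch, 6, ctx, 64), qk is (batch, 6, ctx, ctx), and the output is merged
# back to (batch, ctx, 384) before the final Dense projection.

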
class ResidualAttentionBlock:
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp1 = Dense(n_mlp)
        self.mlp2 = Dense(n_state)
        self.mlp_ln = LayerNorm(n_state)

    def __call__(self, x, xa=None, mask=None, kv_cache=None):
        kv, cross_kv = kv_cache if kv_cache else (None, None)
        y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
        x = x + y
        cross_qk = None
        if self.cross_attn:
            y, cross_kv, cross_qk = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=cross_kv
            )
            x = x + y
        x = x + tf.cast(self.mlp2(tf.nn.gelu(self.mlp1(self.mlp_ln(x)))), x.dtype)
        return x, (kv, cross_kv), cross_qk


class AudioEncoder:
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype=tf.float16,
    ):
        # ZeroPadding1D(1) followed by a 'valid' Conv1D(kernel_size=3) reproduces
        # a kernel-3 convolution with padding=1; the second conv downsamples by 2.
        self.zeropadding1d1 = ZeroPadding1D(padding=1)
        self.conv1 = Conv1D(filters=n_state, kernel_size=3)
        self.zeropadding1d2 = ZeroPadding1D(padding=1)
        self.conv2 = Conv1D(filters=n_state, kernel_size=3, strides=2)
        self._positional_embedding = tf.cast(sinusoids(n_ctx, n_state), dtype)

        self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.ln_post = LayerNorm(n_state)

    def __call__(self, x):
        # x: (batch, 2 * n_ctx, n_mels) mel spectrogram, channels-last.
        x = self.zeropadding1d1(x)
        x = tf.cast(tf.nn.gelu(self.conv1(x)), x.dtype)
        x = self.zeropadding1d2(x)
        x = tf.cast(tf.nn.gelu(self.conv2(x)), x.dtype)
        assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
        x = x + self._positional_embedding

        for block in self.blocks:
            x, _, _ = block(x)

        x = self.ln_post(x)
        return x


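# Usage sketch: for the standard 30-second window the encoder consumes a
# (batch, 3000, n_mels) mel input (the stride-2 conv halves 3000 to n_ctx=1500)
# and returns (batch, n_ctx, n_state) audio features for the decoder.

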
class TextDecoder(tf.keras.layers.Layer):
    def __init__(
        self,
        n_vocab: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype=tf.float16,
    ):
        super().__init__()
        self.token_embedding = self.add_weight(
            name='token_embedding',
            shape=[n_vocab, n_state],
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[n_ctx, n_state],
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
        )

        self.blocks = [
            ResidualAttentionBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ]
        self.ln = LayerNorm(n_state)
        # Causal mask: -inf strictly above the diagonal, 0 on and below it.
        self._mask = tf.fill((n_ctx, n_ctx), float("-inf"))
        self._mask = tf.linalg.band_part(self._mask, 0, -1)
        self._mask = tf.linalg.set_diag(self._mask, tf.zeros(n_ctx))
        self._mask = tf.cast(self._mask, dtype)

    def __call__(self, x, xa, kv_cache=None):
        """
        x : shape = (batch_size, <= n_ctx)
            the text tokens
        xa : shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        # With a cache, new tokens are embedded at positions after the cached ones.
        offset = kv_cache[0][0][0].shape[1] if kv_cache else 0
        x = (
            tf.gather(self.token_embedding, x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )

        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)
        cross_qk = [None] * len(self.blocks)
        for e, block in enumerate(self.blocks):
            x, kv_cache[e], cross_qk[e] = block(
                x, xa, mask=self._mask, kv_cache=kv_cache[e]
            )

        x = self.ln(x)
        # Tied output projection: logits via the transposed token embedding.
        return tf.matmul(x, tf.transpose(self.token_embedding)), kv_cache, cross_qk


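# Incremental decoding sketch (an assumption that mirrors the cache contract
# above): call once with the full prompt and kv_cache=None, then feed one new
# token per step, passing back the kv_cache returned by the previous call so
# self-attention keys/values are extended rather than recomputed.

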
class Whisper(Model):
    def __init__(self, dims: ModelDimensions, dtype=tf.float16):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
            dtype,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
            dtype,
        )

        # By default, use the second half of the decoder layers for time alignment.
        all_heads = np.zeros(
            (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.alignment_heads = tf.transpose(tf.cast(tf.where(all_heads != 0), dtype=tf.int32))

    def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
        if isinstance(dump, np.ndarray):
            self.alignment_heads = tf.convert_to_tensor(dump)
        elif isinstance(dump, bytes):
            array = np.frombuffer(
                gzip.decompress(base64.b85decode(dump)), dtype=bool
            ).copy()
            mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
            self.alignment_heads = tf.transpose(tf.cast(tf.where(mask != 0), dtype=tf.int32))
        else:
            raise ValueError(
                f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
                " alignment_head information"
            )

    def embed_audio(self, mel):
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        return self.decoder(tokens, audio_features)[0]

    def forward_with_cross_qk(self, mel, tokens):
        logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
        return logits, cross_qk

    def __call__(self, mel, tokens):
        return self.decoder(tokens, self.encoder(mel))[0]

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
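

# Minimal smoke test, a sketch and not part of the original module: it builds
# the model with the (assumed) multilingual "tiny" checkpoint dimensions, uses
# float32 so no mixed-precision policy is required, and feeds random inputs.
# The token ids are arbitrary values below n_vocab, chosen for illustration.
if __name__ == "__main__":
    dims = ModelDimensions(
        n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6,
        n_audio_layer=4, n_vocab=51865, n_text_ctx=448, n_text_state=384,
        n_text_head=6, n_text_layer=4,
    )
    model = Whisper(dims, dtype=tf.float32)
    mel = tf.random.normal((1, 2 * dims.n_audio_ctx, dims.n_mels))  # (1, 3000, 80)
    tokens = tf.constant([[50258, 50259, 50359]])  # arbitrary ids < n_vocab
    logits = model(mel, tokens)
    print(logits.shape)  # expected: (1, 3, 51865)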