NoteDance commited on
Commit
36504b4
1 Parent(s): cf27903

Update Whisper.py

Browse files
Files changed (1) hide show
  1. Whisper.py +13 -3
Whisper.py CHANGED
@@ -150,7 +150,7 @@ class AudioEncoder:
150
  return x
151
 
152
 
153
- class TextDecoder:
154
  def __init__(
155
  self,
156
  n_vocab: int,
@@ -160,8 +160,18 @@ class TextDecoder:
160
  n_layer: int,
161
  dtype = tf.float16,
162
  ):
163
- self.token_embedding = tf.Variable(tf.random.normal([n_vocab, n_state]))
164
- self.positional_embedding = tf.Variable(tf.zeros([n_ctx, n_state]))
 
 
 
 
 
 
 
 
 
 
165
 
166
  self.blocks = [
167
  ResidualAttentionBlock(n_state, n_head, cross_attention=True)
 
150
  return x
151
 
152
 
153
+ class TextDecoder(tf.keras.layers.Layer):
154
  def __init__(
155
  self,
156
  n_vocab: int,
 
160
  n_layer: int,
161
  dtype = tf.float16,
162
  ):
163
+ self.token_embedding = self.add_weight(
164
+ name='token_embedding',
165
+ shape=[self.n_vocab, self.n_state],
166
+ initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
167
+ trainable=True
168
+ )
169
+ self.positional_embedding = self.add_weight(
170
+ name='positional_embedding',
171
+ shape=[self.n_ctx, self.n_state],
172
+ initializer=tf.keras.initializers.Zeros(), # 初始化为全零
173
+ trainable=True
174
+ )
175
 
176
  self.blocks = [
177
  ResidualAttentionBlock(n_state, n_head, cross_attention=True)