TomRB22 commited on
Commit
ea295a1
1 Parent(s): 3bf476b

Made in-script documentation of model.py more readable

Browse files
Files changed (1) hide show
  1. model.py +27 -15
model.py CHANGED
@@ -160,13 +160,19 @@ class VAE(tf.keras.Model):
160
  Get a "song map" and make a forward pass through the encoder, in order
161
  to return the latent representation and the distribution's parameters.
162
 
163
- Parameters:
164
- x_input (tf.Tensor): Song map to be encoded by the VAE.
165
-
166
- Returns:
167
- tf.Tensor: The parameters of the distribution which encode the song
168
- (mu, sd) and a sampled latent representation from this
169
- distribution (z_sample).
 
 
 
 
 
 
170
  """
171
 
172
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
@@ -178,15 +184,21 @@ class VAE(tf.keras.Model):
178
  """
179
  Decode a latent representation of a song.
180
 
181
- Parameters:
182
- z_sample (tf.Tensor): Song encoding outputed by the encoder. If
183
- None, this sampling is done over an
184
- unit Gaussian distribution.
 
 
185
 
186
- Returns:
187
- tf.Tensor: Song map corresponding to the encoding.
 
 
188
  """
189
 
190
  if z_sample == None:
191
- z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0)
192
- return self.decoder(z_sample)
 
 
 
160
  Get a "song map" and make a forward pass through the encoder, in order
161
  to return the latent representation and the distribution's parameters.
162
 
163
+ Parameters
164
+ ----------
165
+ x_input : tf.Tensor
166
+ Song map to be encoded by the VAE.
167
+
168
+ Returns
169
+ -------
170
+ z_sample: tf.Tensor
171
+ A sampled latent representation from the distribution which encodes the song.
172
+ mu: tf.Tensor
173
+ The mean parameter of the distribution.
174
+ sd: tf.Tensor
175
+ The standard deviation parameter of the distribution.
176
  """
177
 
178
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
 
184
  """
185
  Decode a latent representation of a song.
186
 
187
+ Parameters
188
+ ----------
189
+ z_sample : tf.Tensor
190
+
191
+ Song encoding outputed by the encoder.
192
+ Default ``None``, for which the sampling is done over an unit Gaussian distribution.
193
 
194
+ Returns
195
+ -------
196
+ song_map: tf.Tensor
197
+ Song map corresponding to the encoding.
198
  """
199
 
200
  if z_sample == None:
201
+ z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0
202
+
203
+ song_map = self.decoder(z_sample)
204
+ return song_map