include-full-model

#2
by camocazi - opened

Thanks for your great work @camocazi

vumichien changed pull request status to open
vumichien changed pull request status to merged

Currently the repo contains only the decoder part of the network, which is not very useful on its own:

from huggingface_hub import from_pretrained_keras
model = from_pretrained_keras("keras-io/drug-molecule-generation-with-VAE")
model.summary()

Model: "decoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_6 (InputLayer)           [(None, 435)]        0           []                               
                                                                                                  
 dense_8 (Dense)                (None, 128)          55808       ['input_6[0][0]']                
                                                                                                  
 dropout_5 (Dropout)            (None, 128)          0           ['dense_8[0][0]']                
                                                                                                  
 dense_9 (Dense)                (None, 256)          33024       ['dropout_5[0][0]']              
                                                                                                  
 dropout_6 (Dropout)            (None, 256)          0           ['dense_9[0][0]']                
                                                                                                  
 dense_10 (Dense)               (None, 512)          131584      ['dropout_6[0][0]']              
                                                                                                  
 dropout_7 (Dropout)            (None, 512)          0           ['dense_10[0][0]']               
                                                                                                  
 dense_11 (Dense)               (None, 72000)        36936000    ['dropout_7[0][0]']              
                                                                                                  
 reshape_2 (Reshape)            (None, 5, 120, 120)  0           ['dense_11[0][0]']               
                                                                                                  
 tf.compat.v1.transpose_1 (TFOp  (None, 5, 120, 120)  0          ['reshape_2[0][0]']              
 Lambda)                                                                                          
                                                                                                  
 tf.__operators__.add_1 (TFOpLa  (None, 5, 120, 120)  0          ['reshape_2[0][0]',              
 mbda)                                                            'tf.compat.v1.transpose_1[0][0]'
                                                                 ]                                
                                                                                                  
 dense_12 (Dense)               (None, 1320)         677160      ['dropout_7[0][0]']              
                                                                                                  
 tf.math.truediv_1 (TFOpLambda)  (None, 5, 120, 120)  0          ['tf.__operators__.add_1[0][0]'] 
                                                                                                  
 reshape_3 (Reshape)            (None, 120, 11)      0           ['dense_12[0][0]']               
                                                                                                  
 softmax_2 (Softmax)            (None, 5, 120, 120)  0           ['tf.math.truediv_1[0][0]']      
                                                                                                  
 softmax_3 (Softmax)            (None, 120, 11)      0           ['reshape_3[0][0]']              
                                                                                                  
==================================================================================================
Total params: 37,833,576
Trainable params: 37,833,576
Non-trainable params: 0
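As a sanity check on the decoder summary, its parameter count can be recomputed by hand. Every parameter sits in the five Dense layers, and each one contributes n_in × n_out weights plus n_out biases. The Dropout, Reshape, transpose, add, divide and Softmax rows contribute nothing. The plain-Python sketch below (no Keras needed; the helper `dense_params` is a hypothetical name) reproduces the 37,833,576 total and confirms that the 72000- and 1320-unit layers are exactly the flattened sizes of the (5, 120, 120) and (120, 11) outputs.

```python
# Recompute the parameter counts shown in the "decoder" summary.

def dense_params(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus biases (n_out) of a fully connected layer."""
    return n_in * n_out + n_out

# (layer name, input width, output width), read off the summary
decoder_dense = [
    ("dense_8", 435, 128),
    ("dense_9", 128, 256),
    ("dense_10", 256, 512),
    ("dense_11", 512, 72000),  # reshaped to (5, 120, 120)
    ("dense_12", 512, 1320),   # also fed by dropout_7; reshaped to (120, 11)
]

counts = {name: dense_params(i, o) for name, i, o in decoder_dense}
total = sum(counts.values())

# The two output Dense widths match the Reshape targets exactly.
assert 72000 == 5 * 120 * 120
assert 1320 == 120 * 11

for name, n in counts.items():
    print(f"{name:10s} {n:>10,}")
print(f"{'total':10s} {total:>10,}")
```

The arithmetic also shows where the size comes from: dense_11 alone accounts for roughly 97.6% of the decoder's parameters.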

This PR packages the encoder, the decoder, and the sampling layer into a single model object, together with the weights obtained by training with the published code.

model.summary()

Model: "molecule_generator_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 encoder (Functional)        [(None, 435),             451925    
                              (None, 435)]                       
                                                                 
 decoder (Functional)        [(None, 5, 120, 120),     37833576  
                              (None, 120, 11)]                   
                                                                 
 dense_13 (Dense)            multiple                  436       
                                                                 
=================================================================
Total params: 38,285,941
Trainable params: 38,285,937
Non-trainable params: 4
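The rows of this summary can be checked the same way. The decoder row matches the standalone decoder total exactly, and dense_13's 436 parameters are consistent with a single-unit Dense layer over a 435-wide input (435 weights plus 1 bias), the same width as the encoder outputs. The sketch below also shows that the three listed rows add up to the trainable count; the 4 non-trainable parameters are not attributed to any row in this summary.

```python
# Check how the rows of the "molecule_generator_1" summary add up.
encoder = 451_925
decoder = 37_833_576  # same as the standalone decoder summary
dense_13 = 436        # consistent with Dense(1) over a 435-wide input

rows = encoder + decoder + dense_13
trainable, non_trainable = 38_285_937, 4

assert 435 * 1 + 1 == dense_13
assert trainable + non_trainable == 38_285_941
print(rows, rows == trainable)
```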

It also updates the diagrams of the encoder and decoder graphs.
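In the full model summary, the encoder returns two (None, 435) tensors. In a standard VAE these are the mean and log-variance of the latent Gaussian, and the sampling layer draws z = mean + exp(0.5 * log_var) * eps with eps ~ N(0, 1) (the reparameterization trick), so gradients can flow through both encoder outputs. Assuming that is what the sampling layer in this PR does (the summary does not show its internals), here is a framework-free sketch for one row; `sample_latent` and `LATENT_DIM` are hypothetical names, and a Keras version would apply the same formula to tensors inside a custom layer's `call`.

```python
import math
import random

LATENT_DIM = 435  # width of each encoder output in the full model summary

def sample_latent(z_mean, z_log_var, rng):
    """Reparameterization trick: z = mean + exp(0.5 * log_var) * eps, eps ~ N(0, 1).

    z_mean and z_log_var are equal-length lists of floats standing in for one
    row of the encoder's two outputs (assumed here to be mean and log-variance).
    """
    if len(z_mean) != len(z_log_var):
        raise ValueError("z_mean and z_log_var must have the same length")
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(z_mean, z_log_var)
    ]

# Example: sample from a standard normal latent (mean 0, log-variance 0).
rng = random.Random(0)  # seeded so the draw is reproducible
z_mean = [0.0] * LATENT_DIM
z_log_var = [0.0] * LATENT_DIM
z = sample_latent(z_mean, z_log_var, rng)
print(len(z))
```

Sampling the noise separately from the mean and log-variance is the design choice that keeps the random draw out of the gradient path; with a very negative log-variance the sample collapses onto the mean.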

Super! Thank you for the clear description.
