# Scraped listing header: file size 2,670 bytes, revision b3c2eb7.
import os
import time
from tensorflow.keras import Sequential
from tensorflow.keras.models import model_from_json
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.initializers import RandomNormal
from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer
class LSTMChem(object):
    """Keras LSTM model wrapper with 'train', 'generate' and 'finetune' sessions.

    In the ``'train'`` session a fresh model is built and compiled by
    :meth:`build_model`. In the other sessions an architecture JSON file and
    a weight checkpoint are loaded by :meth:`load`.

    Config attributes read:
        ``seed``, ``units``, ``optimizer``, ``exp_dir`` (train session);
        ``model_arch_filename``, ``model_weight_filename`` (other sessions).
    Config attributes written:
        ``model_arch_filename`` (set by :meth:`build_model`).
    """

    _SESSIONS = ('train', 'generate', 'finetune')

    def __init__(self, config, session='train'):
        """Create the wrapper and build or load ``self.model``.

        Args:
            config: Configuration object; see the class docstring for the
                attributes that are read.
            session: One of ``'train'``, ``'generate'`` or ``'finetune'``.

        Raises:
            ValueError: If ``session`` is not a supported value.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown session fall through to
        # the load branch below.
        if session not in self._SESSIONS:
            raise ValueError(
                f'session must be one of {{train, generate, finetune}}; '
                f'got {session!r}')
        self.config = config
        self.session = session
        self.model = None

        if self.session == 'train':
            self.build_model()
        else:
            self.model = self.load(self.config.model_arch_filename,
                                   self.config.model_weight_filename)

    def build_model(self):
        """Build, serialize and compile a fresh model into ``self.model``.

        Architecture:
            * two stacked ``LSTM`` layers with ``config.units`` units and
              ``return_sequences=True``, using input dropout 0.3 and 0.5;
            * a per-timestep softmax ``Dense`` layer of width
              ``len(SmilesTokenizer().table)``.

        Inputs are shaped ``(batch, time, n_table)`` with a variable time
        length. The feature axis presumably holds one-hot tokens; TODO
        confirm against the data loader.

        Side effects:
            * writes the architecture JSON to ``<exp_dir>/model_arch.json``;
            * stores that path in ``config.model_arch_filename``;
            * compiles with ``categorical_crossentropy``.
        """
        st = SmilesTokenizer()
        n_table = len(st.table)
        # NOTE(review): one seeded initializer instance is shared by all three
        # layers. Depending on the TF/Keras version, a seeded initializer may
        # return the same random stream on every call. Confirm this is
        # intended.
        weight_init = RandomNormal(mean=0.0,
                                   stddev=0.05,
                                   seed=self.config.seed)

        self.model = Sequential()
        self.model.add(
            LSTM(units=self.config.units,
                 input_shape=(None, n_table),
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.3))
        # Keras only uses ``input_shape`` on the first layer of a Sequential
        # model, so it is not repeated on this one.
        self.model.add(
            LSTM(units=self.config.units,
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.5))
        self.model.add(
            Dense(units=n_table,
                  activation='softmax',
                  kernel_initializer=weight_init))

        arch = self.model.to_json(indent=2)
        self.config.model_arch_filename = os.path.join(self.config.exp_dir,
                                                       'model_arch.json')
        with open(self.config.model_arch_filename, 'w',
                  encoding='utf-8') as f:
            f.write(arch)

        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')

    def save(self, checkpoint_path):
        """Write the current model's weights to ``checkpoint_path``.

        Raises:
            RuntimeError: If no model has been built or loaded yet.
        """
        # Explicit check instead of ``assert`` so it also holds under -O.
        if self.model is None:
            raise RuntimeError('You have to build the model first.')
        print('Saving model ...')
        self.model.save_weights(checkpoint_path)
        print('model saved.')

    def load(self, model_arch_file, checkpoint_file):
        """Rebuild a model from a JSON architecture file and load its weights.

        Args:
            model_arch_file: Path to a JSON architecture, e.g. the file that
                :meth:`build_model` writes.
            checkpoint_file: Path passed to ``Model.load_weights``.

        Returns:
            The reconstructed model. This method does not compile it.
        """
        # NOTE(review): the returned model is uncompiled. If the 'finetune'
        # session trains it, confirm the caller compiles it first.
        print(f'Loading model architecture from {model_arch_file} ...')
        with open(model_arch_file, encoding='utf-8') as f:
            model = model_from_json(f.read())
        print(f'Loading model checkpoint from {checkpoint_file} ...')
        model.load_weights(checkpoint_file)
        print('Loaded the Model.')
        return model