# bn_multi_tribe_mt/src/seq2seq_trainer.py
# Author: MasumBhuiyan
# Commit c5dc1d4: "Seq2Seq model implemented"
from pipes import utils
from pipes import const
from pipes import models
from pipes.data import Dataset
import tensorflow as tf
if __name__ == "__main__":
    # Script entry point: prepare the parallel dataset, build the Seq2Seq
    # model, and train it with early stopping.

    # Language codes; they are also the keys used to index the dataset
    # dictionary below.
    input_lang = 'gr'
    output_lang = 'bn'

    # Dataset preparation. pack()/process() are project helpers defined in
    # pipes.data -- presumably they load and preprocess the corpora
    # (TODO confirm against pipes.data.Dataset).
    dataset_object = Dataset([input_lang, output_lang])
    dataset_object.pack()
    dataset_object.process()
    train_ds, val_ds = dataset_object.pull()
    dataset_dict = dataset_object.get_dict()

    # The vocabulary sizes come from the per-language entries of the
    # dataset dictionary.
    model_object = models.Seq2Seq(
        input_vocab_size=dataset_dict[input_lang]["vocab_size"],
        output_vocab_size=dataset_dict[output_lang]["vocab_size"],
        embedding_dim=256,
        hidden_units=512
    )
    model_object.build()
    model = model_object.get()

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        # NOTE(review): from_logits defaults to False, so this assumes the
        # model's output layer already applies a softmax -- confirm in
        # pipes.models.Seq2Seq.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        # Only real metric identifiers belong here. 'val_accuracy' is not a
        # metric: it is the log key Keras derives automatically from
        # 'accuracy' whenever validation_data is passed to fit(), so listing
        # it made Keras fail to resolve the metric name.
        metrics=['accuracy'],
    )

    history = model.fit(
        # The training set is repeated indefinitely, so steps_per_epoch is
        # what defines the length of an epoch.
        train_ds.repeat(),
        epochs=10,
        steps_per_epoch=100,
        # NOTE(review): val_ds is not repeated; presumably it yields at
        # least 20 batches -- verify, otherwise validation ends early.
        validation_steps=20,
        validation_data=val_ds,
        # EarlyStopping monitors 'val_loss' by default and stops after
        # 3 epochs without improvement.
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=3)]
    )