"""Train a transformer chatbot, save it, reload it and run sample prompts."""

import argparse

import tensorflow as tf

import model
from dataset import get_dataset, preprocess_sentence


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule combining a linear warmup with rsqrt decay.

    For a given step the returned rate is
    ``rsqrt(d_model) * min(rsqrt(step), step * warmup_steps ** -1.5)``.

    Args:
        d_model: Value whose reciprocal square root scales the rate.
        warmup_steps: Step count used in the ``step * warmup_steps ** -1.5``
            term.
    """

    def __init__(self, d_model: int, warmup_steps: int = 4000):
        super().__init__()
        # Keep the raw value so get_config() can return plain Python data.
        self._raw_d_model = d_model
        self.d_model = tf.cast(d_model, dtype=tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor."""
        # Optimizers may hand over their integer iteration counter, and
        # tf.math.rsqrt only accepts floating-point input, so cast first.
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * self.warmup_steps**-1.5
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {"d_model": self._raw_d_model, "warmup_steps": self.warmup_steps}


def inference(hparams, chatbot, tokenizer, sentence):
    """Generate a reply to ``sentence`` by greedy token-by-token decoding.

    Args:
        hparams: Namespace providing ``start_token``, ``end_token`` and
            ``max_length``. The two tokens are concatenated with the encoded
            sentence using ``+``, so they are presumably lists of ids.
        chatbot: Callable model invoked as
            ``chatbot(inputs=[sentence, output], training=False)``.
        tokenizer: Object exposing ``encode``.
        sentence: Raw input text.

    Returns:
        A 1-D tensor holding the start token followed by the generated ids
        (the end token is not appended).
    """
    sentence = preprocess_sentence(sentence)

    # Wrap the encoded sentence in start/end tokens and add a batch axis.
    sentence = tf.expand_dims(
        hparams.start_token + tokenizer.encode(sentence) + hparams.end_token, axis=0
    )

    # The decoder input starts with just the start token.
    output = tf.expand_dims(hparams.start_token, 0)

    for _ in range(hparams.max_length):
        predictions = chatbot(inputs=[sentence, output], training=False)

        # Only the prediction for the last position is needed.
        predictions = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # Stop as soon as the model emits the end token.
        if tf.equal(predicted_id, hparams.end_token[0]):
            break

        # Feed the chosen id back in as decoder input for the next step.
        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0)


def predict(hparams, chatbot, tokenizer, sentence):
    """Return the decoded text reply for ``sentence``.

    Ids that are not below ``tokenizer.vocab_size`` are dropped before
    decoding (presumably the start/end markers -- verify against dataset).
    """
    prediction = inference(hparams, chatbot, tokenizer, sentence)
    predicted_sentence = tokenizer.decode(
        [i for i in prediction if i < tokenizer.vocab_size]
    )
    return predicted_sentence


def evaluate(hparams, chatbot, tokenizer):
    """Print the model's replies to a few fixed sample prompts."""
    print("\nDeğerlendir")

    sentence = "Merhaba nasılsın?"
    output = predict(hparams, chatbot, tokenizer, sentence)
    print(f"input: {sentence}\noutput: {output}")

    sentence = "Sence de gökyüzü çok güzel değil mi?"
    output = predict(hparams, chatbot, tokenizer, sentence)
    print(f"\ninput: {sentence}\noutput: {output}")

    # Let the model talk to itself: each reply becomes the next prompt.
    sentence = "Sanırım uzaklara gideceğim."
    for _ in range(5):
        output = predict(hparams, chatbot, tokenizer, sentence)
        print(f"\ninput: {sentence}\noutput: {output}")
        sentence = output


def main(hparams):
    """Train the chatbot, save it, reload it from disk and evaluate it.

    Args:
        hparams: Parsed command-line namespace. It is also passed to
            ``get_dataset``, ``model.transformer`` and ``evaluate``.
    """
    tf.keras.utils.set_random_seed(1234)

    data, token = get_dataset(hparams)

    chatbot = model.transformer(hparams)

    optimizer = tf.keras.optimizers.Adam(
        CustomSchedule(d_model=hparams.d_model), beta_1=0.9, beta_2=0.98, epsilon=1e-9
    )

    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction="none"
    )

    def loss_function(y_true, y_pred):
        """Cross-entropy with positions whose target id is 0 zeroed out."""
        y_true = tf.reshape(y_true, shape=(-1, hparams.max_length - 1))
        loss = cross_entropy(y_true, y_pred)
        mask = tf.cast(tf.not_equal(y_true, 0), dtype=tf.float32)
        loss = tf.multiply(loss, mask)
        # NOTE(review): the mean is taken over every position, so masked
        # positions still count in the denominator -- confirm this is intended.
        return tf.reduce_mean(loss)

    def accuracy(y_true, y_pred):
        """Sparse categorical accuracy over all positions (no masking)."""
        y_true = tf.reshape(y_true, shape=(-1, hparams.max_length - 1))
        return tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)

    chatbot.compile(optimizer, loss=loss_function, metrics=[accuracy])

    chatbot.fit(data, epochs=hparams.epochs)

    print(f"\nmodel {hparams.save_model}'a kaydediliyor...")
    tf.keras.models.save_model(
        chatbot, filepath=hparams.save_model, include_optimizer=False
    )

    print(
        f"\nclear TensorFlow backend session and load model from {hparams.save_model}..."
    )
    # Drop the in-memory model so evaluation runs on the copy loaded from disk.
    del chatbot
    tf.keras.backend.clear_session()
    chatbot = tf.keras.models.load_model(
        hparams.save_model,
        custom_objects={
            "PositionalEncoding": model.PositionalEncoding,
            "MultiHeadAttention": model.MultiHeadAttention,
        },
        compile=False,
    )
    evaluate(hparams, chatbot, token)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_model", default="model.h5", type=str, help="path save the model"
    )
    parser.add_argument(
        "--max_samples",
        default=25000,
        type=int,
        help="maximum number of conversation pairs to use",
    )
    parser.add_argument(
        "--max_length", default=40, type=int, help="maximum sentence length"
    )
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--num_layers", default=2, type=int)
    parser.add_argument("--num_units", default=512, type=int)
    parser.add_argument("--d_model", default=512, type=int)
    parser.add_argument("--num_heads", default=8, type=int)
    parser.add_argument("--dropout", default=0.1, type=float)
    parser.add_argument("--activation", default="relu", type=str)
    parser.add_argument("--epochs", default=70, type=int)

    main(parser.parse_args())