import tensorflow as tf

from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from plotter import Plotter
from logger_config import config_logger

logger = config_logger(__name__)


def inspect_tfrecord(tfrecord_file_path, num_examples=3):
    """Print the first few parsed examples from a TFRecord file for sanity checking."""
    def parse_example(example_proto):
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([512], tf.int64),        # Adjust max_length if different
            'positive_ids': tf.io.FixedLenFeature([512], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64), # Adjust neg_samples if different
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i + 1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)


def main():
    tf.keras.backend.clear_session()

    # Optionally inspect the TFRecord file before training
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    # Training configuration
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
    CHECKPOINT_DIR = 'checkpoints/'
    batch_size = 32

    # Initialize config and chatbot model
    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config, mode='training')

    # Resume from the latest checkpoint if one exists
    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    initial_epoch = 0
    if latest_checkpoint:
        try:
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            initial_epoch = ckpt_number
            logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
        except (IndexError, ValueError):
            logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
            initial_epoch = 0

    # Train
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=True,
        checkpoint_dir=CHECKPOINT_DIR,
        initial_epoch=initial_epoch
    )

    # Save the trained model
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Plot training history
    plotter = Plotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)


if __name__ == "__main__":
    main()