import tensorflow as tf

from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from plotter import Plotter
from logger_config import config_logger

logger = config_logger(__name__)


def inspect_tfrecord(tfrecord_file_path, num_examples=3):
    """Parse and print the first few examples from a TFRecord file for validation."""
    def parse_example(example_proto):
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([512], tf.int64),         # Adjust max_length if different
            'positive_ids': tf.io.FixedLenFeature([512], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),  # Adjust neg_samples if different
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i + 1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)


def main():
    tf.keras.backend.clear_session()

    # Validate TFRecord
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Init env
    env = EnvironmentSetup()
    env.initialize()

    # Training config
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
    CHECKPOINT_DIR = 'checkpoints/'
    batch_size = 32

    # Initialize config and chatbot model
    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config, mode='training')

    # Check for an existing checkpoint and resume from its epoch if found
    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    initial_epoch = 0
    if latest_checkpoint:
        try:
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            initial_epoch = ckpt_number
            logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
        except (IndexError, ValueError):
            logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
            initial_epoch = 0

    # Train
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=True,
        checkpoint_dir=CHECKPOINT_DIR,
        initial_epoch=initial_epoch
    )

    # Save
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Plot
    plotter = Plotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)


if __name__ == "__main__":
    main()