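"""Training entry point for the RetrievalChatbot.

Loads preprocessed TFRecord training data, resumes from the latest
checkpoint if one exists, trains the model, saves the final weights,
and plots the training history.
"""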
import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from plotter import Plotter
from logger_config import config_logger
logger = config_logger(__name__)

def inspect_tfrecord(tfrecord_file_path, num_examples=3):
    """Print the first few parsed examples from a TFRecord file for sanity checking."""
    def parse_example(example_proto):
        # These feature lengths must match the values used when the TFRecord
        # was written: max_length=512 and neg_samples=3. Adjust if yours differ.
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([512], tf.int64),
            'positive_ids': tf.io.FixedLenFeature([512], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i + 1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)

def main():
    tf.keras.backend.clear_session()

    # Optionally validate the TFRecord before training:
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Set up the environment (provides the training directories used below)
    env = EnvironmentSetup()
    env.initialize()

    # Training config
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
    CHECKPOINT_DIR = 'checkpoints/'
    BATCH_SIZE = 32

    # Initialize config and chatbot model
    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config, mode='training')

    # Resume from an existing checkpoint if one is found. This assumes one
    # checkpoint per epoch, with names ending in 'ckpt-<epoch>'.
    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    initial_epoch = 0
    if latest_checkpoint:
        try:
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            initial_epoch = ckpt_number
            logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
        except (IndexError, ValueError):
            logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
            initial_epoch = 0

    # Train
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        use_lr_schedule=True,
        test_mode=True,
        checkpoint_dir=CHECKPOINT_DIR,
        initial_epoch=initial_epoch,
    )

    # Save the final model
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Plot training history
    plotter = Plotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)


if __name__ == "__main__":
    main()