File size: 2,785 Bytes
9decf80
f7b283c
 
fc5f33b
f7b283c
 
 
 
9b5daff
 
 
 
 
 
 
 
cc2577d
9b5daff
 
cc2577d
9b5daff
 
 
 
 
 
 
f7b283c
cc2577d
9b5daff
cc2577d
5b413d1
9b5daff
cc2577d
f7b283c
 
 
cc2577d
9b5daff
f5346f7
fc5f33b
9b5daff
cc2577d
f7b283c
cc2577d
 
9b5daff
 
cc2577d
d53c64b
 
 
 
 
 
 
 
 
 
5b413d1
cc2577d
5b413d1
f5346f7
f7b283c
 
9decf80
d53c64b
 
 
f7b283c
 
cc2577d
f7b283c
 
 
cc2577d
fc5f33b
f7b283c
cc2577d
f7b283c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from plotter import Plotter

from logger_config import config_logger
logger = config_logger(__name__)

def inspect_tfrecord(tfrecord_file_path, num_examples=3, max_length=512, neg_samples=3):
    """Print the first few parsed examples from a training TFRecord file.

    Useful for sanity-checking that a TFRecord was written with the expected
    schema before starting a training run.

    Args:
        tfrecord_file_path: Path to the TFRecord file to inspect.
        num_examples: How many examples to print.
        max_length: Token length of each 'query_ids'/'positive_ids' sequence.
            Must match the max_length used when the file was written.
        neg_samples: Number of negatives packed into 'negative_ids'
            (stored flat as neg_samples * max_length ids). Must match the writer.
    """
    def parse_example(example_proto):
        # Fixed-length int64 features; lengths must agree with the writer's
        # config or parsing will fail at runtime.
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path).map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i+1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)
        
def _resume_epoch_from_checkpoint(checkpoint_dir):
    """Return the epoch to resume training from, based on the latest checkpoint.

    Checkpoint paths are expected to end in 'ckpt-<N>', where N is the number
    of completed epochs. Returns 0 when there is no checkpoint or the name
    cannot be parsed.
    """
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if not latest_checkpoint:
        return 0
    try:
        initial_epoch = int(latest_checkpoint.split('ckpt-')[-1])
    except (IndexError, ValueError):
        logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
        return 0
    logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
    return initial_epoch

def main():
    """Train the retrieval chatbot end-to-end: resume, train, save, plot."""
    # Drop any state left over from a previous run in this process.
    tf.keras.backend.clear_session()

    # Validate TFRecord schema before a run if in doubt:
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Init env (directories, device setup, etc. — see EnvironmentSetup)
    env = EnvironmentSetup()
    env.initialize()

    # Training config
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
    CHECKPOINT_DIR = 'checkpoints/'
    BATCH_SIZE = 32

    # Initialize config and chatbot model
    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config, mode='training')

    # Resume from the latest checkpoint when one exists.
    initial_epoch = _resume_epoch_from_checkpoint(CHECKPOINT_DIR)

    # Train.
    # NOTE(review): test_mode=True is left enabled here — confirm this is
    # intentional before launching a full training job.
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        use_lr_schedule=True,
        test_mode=True,
        checkpoint_dir=CHECKPOINT_DIR,
        initial_epoch=initial_epoch
    )

    # Save final model weights under the run's base directory.
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Plot training history for the completed run.
    plotter = Plotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)
    
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()