import argparse
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

# Model 1: a per-instance MLP encoder whose outputs are mean-pooled over the bag
# (a simple multiple instance learning baseline).
def create_simple_model(instance_shape, max_length):
    # Each bag is a sequence of max_length instances with instance_shape[-1] features.
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    # Mean-pool the per-instance embeddings into a single bag-level representation.
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

# Model 2: the same per-instance encoder, with multi-head self-attention across the
# instances in a bag before mean pooling.
def create_simple_model2(instance_shape, max_length, num_heads=4, key_dim=64):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    # Self-attention over instances: query, key, and value are all the instance embeddings.
    # The attention scores are returned for potential inspection but are not used in the graph.
    attention_output, attention_scores = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        value_dim=64,
        dropout=0.1,
        use_bias=True
    )(query=dropout_2, value=dropout_2, key=dropout_2, return_attention_scores=True)
    aggregated = layers.GlobalAveragePooling1D()(attention_output)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)
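
# Illustrative usage (the shapes below are hypothetical, not taken from the data file):
# for bags of up to 20 instances with 128 features each,
#   model = create_simple_model2(instance_shape=(128,), max_length=20)
# builds a network whose input is (batch, 20, 128) and whose output is a single
# sigmoid bag-level risk probability; create_simple_model is called the same way.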

# Compute balanced class weights: inverse class frequency, scaled so the average per-sample weight is 1.
def compute_class_weights(labels):
    negative_count = int(np.sum(labels == 0))
    positive_count = int(np.sum(labels == 1))
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
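
# Worked example (hypothetical counts): with 80 negative and 20 positive bags,
# total_count = 100, so the returned weights are
#   {0: (1 / 80) * 50 = 0.625, 1: (1 / 20) * 50 = 2.5}
# i.e. the rarer positive class is up-weighted so both classes contribute equally to the loss.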

# Infinite batch generator yielding (data, labels, per-sample weights); the weights
# carry the class weighting computed above into Keras fit().
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights
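
# Illustrative output (actual shapes depend on the data file): with the default
# batch_size=1, each yield is a tuple (batch_data, batch_labels, batch_weights) with
# shapes (1, max_length, num_features), (1,), and (1,); the third element is consumed
# by Keras fit() as per-sample weights.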

# Step-decay learning rate schedule: multiply the rate by 0.1 every 10 epochs.
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr
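
# Illustrative schedule (assuming Adam's default initial rate of 1e-3): the
# LearningRateScheduler callback calls lr_scheduler(epoch, lr) at the start of each
# epoch, so the rate stays at 1e-3 for epochs 0-9, drops to 1e-4 at epoch 10,
# 1e-5 at epoch 20, and so on.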

# Train the model with checkpointing, early stopping, and the step-decay LR schedule.
def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
    model_path = os.path.join(save_dir, "risk_classifier_model.h5")
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    # The generators yield one bag per step, so steps_per_epoch is the number of bags;
    # batch_size is not passed to fit() because the generators already determine it.
    model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen,
              validation_steps=len(val_data), epochs=epochs,
              callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
    return model

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--model_type', type=str, default='model1', choices=['model1', 'model2'], help='Type of model to use: model1 (default) or model2.')

    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']

    # Create the model based on the selected type.
    # train_X is expected to have shape (num_bags, max_length, num_features): axis 1 is the
    # number of instances per bag and the last axis is the per-instance feature dimension.
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    
    if args.model_type == 'model2':
        model = create_simple_model2(instance_shape, max_length)
    else:
        model = create_simple_model(instance_shape, max_length)

    # Train the model
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir, epochs=args.epochs)

    # Final message after training and saving the model
    print(f"Model saved successfully to {os.path.join(args.save_dir, 'risk_classifier_model.h5')}")