import argparse
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

# Define the function to create the first model
def create_simple_model(instance_shape, max_length):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)
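
# Shape sketch for create_simple_model above (batch dimension omitted, read off the
# layer definitions): the input is a bag of max_length instances, each a feature
# vector of length instance_shape[-1]; the TimeDistributed Dense layers map every
# instance to 256 and then 64 features, GlobalAveragePooling1D averages over the
# instance axis, and the final Dense yields a single sigmoid bag-level probability.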

# Define the function to create the second model with attention
def create_simple_model2(instance_shape, max_length, num_heads=4, key_dim=64):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    attention_output, attention_scores = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        value_dim=64,
        dropout=0.1,
        use_bias=True,
    )(query=dropout_2, value=dropout_2, key=dropout_2, return_attention_scores=True)
    aggregated = layers.GlobalAveragePooling1D()(attention_output)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)
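
# Note on create_simple_model2 above: passing the same tensor as query, key and value
# makes MultiHeadAttention act as self-attention across the instances of a bag, so
# each instance's 64-d embedding is re-weighted by its relation to the others before
# the average pooling. attention_scores is returned but unused here; it could be
# exposed as a second model output if per-instance interpretability is needed.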

# Function to compute class weights
def compute_class_weights(labels):
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
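
# Worked example of the weighting above (illustrative counts, not from the data):
# with 90 negative and 10 positive bags, total_count = 100, so the weights are
# {0: (1/90) * 50 ≈ 0.56, 1: (1/10) * 50 = 5.0}, i.e. each positive bag contributes
# roughly nine times as much to the loss, offsetting the class imbalance.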

# Function to generate batches of data
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights
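
# Keras treats the (inputs, targets, weights) triples yielded above as per-sample
# loss weights. Because the generator loops forever, fit() relies on steps_per_epoch
# (set in train below) to delimit one pass over the data per epoch.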

# Learning rate scheduler
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr
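
# Example of the schedule above: starting from Adam's default learning rate of 1e-3,
# the rate drops to 1e-4 at epoch 10, 1e-5 at epoch 20, and so on (a tenfold decay
# every 10 epochs, with epoch 0 left unchanged).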

# Function to train the model
def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
    model_path = os.path.join(save_dir, "risk_classifier_model.h5")
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    # batch_size is not passed to fit(): it must not be specified when the input is a
    # generator, and the generator already yields one bag per step.
    model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=epochs, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
    return model
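
# With the default batch_size of 1 in data_generator, steps_per_epoch=len(train_data)
# and validation_steps=len(val_data) mean every training and validation bag is seen
# once per epoch; EarlyStopping stops after 10 epochs without improvement in
# val_loss, and ModelCheckpoint keeps only the best model on disk.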

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--model_type', type=str, default='model1', choices=['model1', 'model2'], help='Type of model to use: model1 (default) or model2.')
    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']

    # Create the model based on the selected type
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    if args.model_type == 'model2':
        model = create_simple_model2(instance_shape, max_length)
    else:
        model = create_simple_model(instance_shape, max_length)

    # Train the model for the requested number of epochs
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir, epochs=args.epochs)

    # Final message after training and saving the model
    print(f"Model saved successfully to {os.path.join(args.save_dir, 'risk_classifier_model.h5')}")