VatsalPatel18 committed on
Commit 70884da · verified · 1 Parent(s): 9a4b2a2

Upload 19 files

Dockerfile ADDED
@@ -0,0 +1,26 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.8-slim-buster
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && \
+ apt-get install -y build-essential openslide-tools libgl1-mesa-glx && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ # Copy the entire genomic_plip_model directory contents into the container at /app
+ COPY ./ /app/
+
+ # Install Python dependencies (as root, before dropping privileges)
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Create a non-root user and switch to it for security
+ RUN adduser --disabled-password --gecos '' myuser
+ USER myuser
+
+ EXPOSE 8888
+
+ # Set the entrypoint to a shell command
+ ENTRYPOINT ["/bin/bash"]
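
Note: with this layout the image can be built and started with the usual Docker commands, e.g. `docker build -t genomic-plip .` followed by `docker run -it --rm -p 8888:8888 genomic-plip` (illustrative tag and port mapping; port 8888 simply mirrors the EXPOSE line above, for example for a JupyterLab session launched from the bash entrypoint).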
ReadME.pdf ADDED
Binary file (68.2 kB).
 
benchmark_train_inception.py ADDED
@@ -0,0 +1,112 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import json
5
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG
6
+ from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
7
+ import argparse
8
+ import pandas as pd
9
+
10
+ # Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
11
+ def compute_additional_metrics(generator, model):
12
+ y_true = generator.classes
13
+ y_pred_prob = model.predict(generator)
14
+ y_pred = np.argmax(y_pred_prob, axis=1)
15
+ auc = roc_auc_score(y_true, y_pred_prob[:, 1])
16
+ precision = precision_score(y_true, y_pred, average='macro')
17
+ recall = recall_score(y_true, y_pred, average='macro')
18
+ f1 = f1_score(y_true, y_pred, average='macro')
19
+ accuracy = accuracy_score(y_true, y_pred)
20
+ return auc, precision, recall, f1, accuracy, y_pred_prob
21
+
22
+ # Function to save evaluation metrics
23
+ def save_evaluation_metrics(generator, model, dataset_name, save_dir):
24
+ auc, precision, recall, f1, accuracy, y_pred_prob = compute_additional_metrics(generator, model)
25
+ metrics = {
26
+ 'auc': auc,
27
+ 'precision': precision,
28
+ 'recall': recall,
29
+ 'f1_score': f1,
30
+ 'accuracy': accuracy
31
+ }
32
+ # Save predictions
33
+ np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=y_pred_prob, labels=generator.classes)
34
+ return metrics
35
+
36
+ if __name__ == "__main__":
37
+ parser = argparse.ArgumentParser(description='Train and evaluate InceptionV3 on benchmark datasets.')
38
+ parser.add_argument('--dataset_dir', type=str, required=True, help='Directory containing train, validate, test, and test2 directories.')
39
+ parser.add_argument('--save_dir', type=str, default='./results/', help='Directory to save the model and evaluation results.')
40
+ parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
41
+
42
+ args = parser.parse_args()
43
+
44
+ train_dir = os.path.join(args.dataset_dir, 'train')
45
+ validate_dir = os.path.join(args.dataset_dir, 'validate')
46
+ test_dir = os.path.join(args.dataset_dir, 'test')
47
+ test2_dir = os.path.join(args.dataset_dir, 'test2')
48
+
49
+ os.makedirs(args.save_dir, exist_ok=True)
50
+
51
+ # Set up InceptionV3 model
52
+ with tf.device('GPU:0'):
53
+ inception = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet', input_shape=(299, 299, 3))
54
+ last_layer = inception.get_layer('mixed10')
55
+ last_output = last_layer.output
56
+ x = tf.keras.layers.GlobalAveragePooling2D()(last_output)
57
+ x = tf.keras.layers.Dense(2, activation='softmax')(x) # Assuming binary classification
58
+ model = tf.keras.Model(inputs=inception.input, outputs=x)
59
+ model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', 'Recall', 'Precision'])
60
+
61
+ # Image data generators
62
+ train_datagen = IDG(rescale=1/255.0, horizontal_flip=True)
63
+ validate_datagen = IDG(rescale=1/255.0)
64
+ test_datagen = IDG(rescale=1/255.0)
65
+
66
+ train_generator = train_datagen.flow_from_directory(train_dir, target_size=(299, 299),
67
+ class_mode='categorical', batch_size=64)
68
+ validate_generator = validate_datagen.flow_from_directory(validate_dir, target_size=(299, 299),
69
+ class_mode='categorical', batch_size=64)
70
+ test_generator = test_datagen.flow_from_directory(test_dir, target_size=(299, 299),
71
+ class_mode='categorical', batch_size=64)
72
+ test2_generator = test_datagen.flow_from_directory(test2_dir, target_size=(299, 299),
73
+ class_mode='categorical', batch_size=64)
74
+
75
+ # Training the model
76
+ hist = model.fit(train_generator, epochs=args.epochs, validation_data=validate_generator, verbose=1, shuffle=True)
77
+
78
+ # Save the trained model
79
+ model.save(os.path.join(args.save_dir, 'inception_model.hdf5'))
80
+
81
+ # Save training history separately
82
+ training_log = {
83
+ 'loss': hist.history['loss'],
84
+ 'val_loss': hist.history['val_loss'],
85
+ 'accuracy': hist.history['accuracy'],
86
+ 'val_accuracy': hist.history['val_accuracy'],
87
+ 'recall': hist.history['recall'],
88
+ 'val_recall': hist.history['val_recall'],
89
+ 'precision': hist.history['precision'],
90
+ 'val_precision': hist.history['val_precision']
91
+ }
92
+ with open(os.path.join(args.save_dir, 'training_log.json'), 'w') as f:
93
+ json.dump(training_log, f)
94
+
95
+ # Evaluate the model on each dataset and save metrics
96
+ train_metrics = save_evaluation_metrics(train_generator, model, "train", args.save_dir)
97
+ validate_metrics = save_evaluation_metrics(validate_generator, model, "validate", args.save_dir)
98
+ test_metrics = save_evaluation_metrics(test_generator, model, "test", args.save_dir)
99
+ test2_metrics = save_evaluation_metrics(test2_generator, model, "test2", args.save_dir)
100
+
101
+ # Save the evaluation metrics in a JSON file
102
+ evaluation_metrics = {
103
+ 'train_metrics': train_metrics,
104
+ 'validate_metrics': validate_metrics,
105
+ 'test_metrics': test_metrics,
106
+ 'test2_metrics': test2_metrics
107
+ }
108
+
109
+ with open(os.path.join(args.save_dir, 'evaluation_metrics.json'), 'w') as f:
110
+ json.dump(evaluation_metrics, f)
111
+
112
+ print("Training and evaluation metrics saved.")
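
One caveat with the evaluation above: `compute_additional_metrics` compares `generator.classes` (stored in directory order) against `model.predict(generator)`, yet the generators are created with Keras' default `shuffle=True`, so the prediction order may not line up with the label order. A minimal sketch of a workaround, assuming the same ImageDataGenerator setup as in the script (the `eval`-prefixed names are illustrative, not part of the upload):

# Build an evaluation-only generator with shuffle=False so that the order of
# model.predict(generator) matches generator.classes.
eval_datagen = IDG(rescale=1/255.0)
test_eval_generator = eval_datagen.flow_from_directory(
    test_dir, target_size=(299, 299), class_mode='categorical',
    batch_size=64, shuffle=False)
test_metrics = save_evaluation_metrics(test_eval_generator, model, "test", args.save_dir)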
benchmark_train_resnet50.py ADDED
@@ -0,0 +1,114 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import json
5
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG
6
+ from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
7
+ import argparse
8
+ import pandas as pd
9
+
10
+ # Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
11
+ def compute_additional_metrics(generator, model):
12
+ y_true = generator.classes
13
+ y_pred_prob = model.predict(generator)
14
+ y_pred = np.argmax(y_pred_prob, axis=1)
15
+ auc = roc_auc_score(y_true, y_pred_prob[:, 1])
16
+ precision = precision_score(y_true, y_pred, average='macro')
17
+ recall = recall_score(y_true, y_pred, average='macro')
18
+ f1 = f1_score(y_true, y_pred, average='macro')
19
+ accuracy = accuracy_score(y_true, y_pred)
20
+ return auc, precision, recall, f1, accuracy, y_pred_prob
21
+
22
+ # Function to save evaluation metrics
23
+ def save_evaluation_metrics(generator, model, dataset_name, save_dir):
24
+ auc, precision, recall, f1, accuracy, y_pred_prob = compute_additional_metrics(generator, model)
25
+ metrics = {
26
+ 'auc': auc,
27
+ 'precision': precision,
28
+ 'recall': recall,
29
+ 'f1_score': f1,
30
+ 'accuracy': accuracy
31
+ }
32
+ # Save predictions
33
+ np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=y_pred_prob, labels=generator.classes)
34
+ return metrics
35
+
36
+ if __name__ == "__main__":
37
+ parser = argparse.ArgumentParser(description='Train and evaluate ResNet50 on benchmark datasets.')
38
+ parser.add_argument('--dataset_dir', type=str, required=True, help='Directory containing train, validate, test, and test2 directories.')
39
+ parser.add_argument('--save_dir', type=str, default='./results/', help='Directory to save the model and evaluation results.')
40
+ parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.')
41
+
42
+ args = parser.parse_args()
43
+
44
+ train_dir = os.path.join(args.dataset_dir, 'train')
45
+ validate_dir = os.path.join(args.dataset_dir, 'validate')
46
+ test_dir = os.path.join(args.dataset_dir, 'test')
47
+ test2_dir = os.path.join(args.dataset_dir, 'test2')
48
+
49
+ os.makedirs(args.save_dir, exist_ok=True)
50
+
51
+ # Set up ResNet50 model
52
+ with tf.device('GPU:0'):
53
+ resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
54
+ last_layer = resnet.get_layer('conv5_block3_out')
55
+ last_output = last_layer.output
56
+ x = tf.keras.layers.GlobalAveragePooling2D()(last_output)
57
+ x = tf.keras.layers.Dense(2, activation='softmax')(x) # Assuming binary classification
58
+ model = tf.keras.Model(inputs=resnet.input, outputs=x)
59
+ model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', 'Recall', 'Precision'])
60
+
61
+ # Image data generators
62
+ train_datagen = IDG(rescale=1/255.0, horizontal_flip=True)
63
+ validate_datagen = IDG(rescale=1/255.0)
64
+ test_datagen = IDG(rescale=1/255.0)
65
+
66
+ batch_size = 64
67
+
68
+ train_generator = train_datagen.flow_from_directory(train_dir, target_size=(224, 224),
69
+ class_mode='categorical', batch_size=batch_size)
70
+ validate_generator = validate_datagen.flow_from_directory(validate_dir, target_size=(224, 224),
71
+ class_mode='categorical', batch_size=batch_size)
72
+ test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224),
73
+ class_mode='categorical', batch_size=batch_size)
74
+ test2_generator = test_datagen.flow_from_directory(test2_dir, target_size=(224, 224),
75
+ class_mode='categorical', batch_size=batch_size)
76
+
77
+ # Training the model
78
+ hist = model.fit(train_generator, epochs=args.epochs, validation_data=validate_generator, verbose=1, shuffle=True)
79
+
80
+ # Save the trained model
81
+ model.save(os.path.join(args.save_dir, 'risk_classifier_resnet_model.hdf5'))
82
+
83
+ # Save training history separately
84
+ training_log = {
85
+ 'loss': hist.history['loss'],
86
+ 'val_loss': hist.history['val_loss'],
87
+ 'accuracy': hist.history['accuracy'],
88
+ 'val_accuracy': hist.history['val_accuracy'],
89
+ 'recall': hist.history['recall'],
90
+ 'val_recall': hist.history['val_recall'],
91
+ 'precision': hist.history['precision'],
92
+ 'val_precision': hist.history['val_precision']
93
+ }
94
+ with open(os.path.join(args.save_dir, 'resnet_training_log.json'), 'w') as f:
95
+ json.dump(training_log, f)
96
+
97
+ # Evaluate the model on each dataset and save metrics
98
+ train_metrics = save_evaluation_metrics(train_generator, model, "train", args.save_dir)
99
+ validate_metrics = save_evaluation_metrics(validate_generator, model, "validate", args.save_dir)
100
+ test_metrics = save_evaluation_metrics(test_generator, model, "test", args.save_dir)
101
+ test2_metrics = save_evaluation_metrics(test2_generator, model, "test2", args.save_dir)
102
+
103
+ # Save the evaluation metrics in a JSON file
104
+ evaluation_metrics = {
105
+ 'train_metrics': train_metrics,
106
+ 'validate_metrics': validate_metrics,
107
+ 'test_metrics': test_metrics,
108
+ 'test2_metrics': test2_metrics
109
+ }
110
+
111
+ with open(os.path.join(args.save_dir, 'resnet_evaluation_metrics.json'), 'w') as f:
112
+ json.dump(evaluation_metrics, f)
113
+
114
+ print("Training and evaluation metrics saved.")
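
This script mirrors the InceptionV3 benchmark above with a 224x224 input size and the `conv5_block3_out` feature map, and the same ordering caveat about `generator.classes` versus shuffled predictions applies here. An illustrative invocation with placeholder paths: `python benchmark_train_resnet50.py --dataset_dir Datasets/benchmark --save_dir ./results/ --epochs 10`.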
evaluate_on_test_cohort2.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import json
6
+ import tensorflow as tf
7
+ from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
8
+ import argparse
9
+
10
+ # Function to load and preprocess the dataset
11
+ def load_and_preprocess_data(metadata_file, data_dir):
12
+ dff = pd.read_csv(metadata_file, skiprows=0)
13
+ if 'Unnamed: 0' in dff.columns:
14
+ del dff['Unnamed: 0']
15
+
16
+ # Filter and map classes to 0 and 1
17
+ classified_df = dff[dff['Class'].isin([1, 3])]
18
+ classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
19
+ df = classified_df.set_index('PatientID')
20
+
21
+ # Filter for patients that have corresponding WSI data
22
+ available_patients = set(os.listdir(data_dir))
23
+ df = df.loc[df.index.intersection(available_patients)]
24
+ df = df.sample(frac=1)
25
+
26
+ return df
27
+
28
+ # Function to create bags of tiles
29
+ def create_bags(df, data_dir):
30
+ data = {'test2': {'X': [], 'Y': []}}
31
+ for pID, row in df.iterrows():
32
+ fol_p = os.path.join(data_dir, pID)
33
+ tiles = os.listdir(fol_p)
34
+ tile_data = []
35
+ for tile in tiles:
36
+ tile_p = os.path.join(fol_p, tile)
37
+ np1 = torch.load(tile_p).numpy()
38
+ tile_data.append(np1)
39
+ bag = np.squeeze(tile_data, axis=1)
40
+ bag_label = row['Class']
41
+ data['test2']['X'].append(bag)
42
+ data['test2']['Y'].append(np.array([bag_label]))
43
+ data['test2']['X'] = np.array(data['test2']['X'])
44
+ data['test2']['Y'] = np.array(data['test2']['Y'])
45
+ print(f"Data[test2]['X'] shape: {data['test2']['X'].shape}, dtype: {data['test2']['X'].dtype}")
46
+ return data
47
+
48
+ # Function to pad the data to ensure uniform bag length
49
+ def prepare_data_with_padding(data, max_length=2000):
50
+ padded_data = []
51
+ for bag in data:
52
+ if len(bag) < max_length:
53
+ padding = np.zeros((max_length - len(bag), bag.shape[1]))
54
+ padded_bag = np.vstack((bag, padding))
55
+ else:
56
+ padded_bag = bag[:max_length]  # truncate bags longer than max_length so the stacked array stays uniform
57
+ padded_data.append(padded_bag)
58
+ return np.array(padded_data)
59
+
60
+ # Function to compute additional metrics using sklearn
61
+ def compute_additional_metrics(X, Y, model):
62
+ predictions = model.predict(X).flatten()
63
+ predictions_binary = (predictions > 0.5).astype(int) # Convert probabilities to class labels (0 or 1)
64
+ auc = roc_auc_score(Y, predictions)
65
+ precision = precision_score(Y, predictions_binary)
66
+ recall = recall_score(Y, predictions_binary)
67
+ f1 = f1_score(Y, predictions_binary)
68
+ return auc, precision, recall, f1, predictions
69
+
70
+ # Function to evaluate the model on a given dataset using sklearn metrics
71
+ def evaluate_dataset(model, X, Y, dataset_name, save_dir):
72
+ # Evaluate using TensorFlow's model.evaluate() for loss and accuracy
73
+ eval_metrics = model.evaluate(X, Y, verbose=0)
74
+
75
+ # Compute additional metrics using sklearn
76
+ auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
77
+ metrics = {
78
+ 'loss': eval_metrics[0],
79
+ 'accuracy': eval_metrics[1],
80
+ 'auc': auc,
81
+ 'precision': precision,
82
+ 'recall': recall,
83
+ 'f1_score': f1
84
+ }
85
+
86
+ # Save the predictions for each sample
87
+ np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)
88
+
89
+ return metrics
90
+
91
+ if __name__ == "__main__":
92
+ # Command line arguments
93
+ parser = argparse.ArgumentParser(description='Evaluate a trained model on a secondary test dataset (test2).')
94
+ parser.add_argument('--metadata_file', type=str, required=True, help='Path to the metadata CSV file for test2.')
95
+ parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the extracted tissue features.')
96
+ parser.add_argument('--model_path', type=str, required=True, help='Path to the saved model file.')
97
+ parser.add_argument('--save_dir', type=str, default='./evaluation_results_test2/', help='Directory to save evaluation results.')
98
+
99
+ args = parser.parse_args()
100
+
101
+ if not os.path.exists(args.save_dir):
102
+ os.makedirs(args.save_dir)
103
+
104
+ # Load and preprocess the test2 data
105
+ df_test2 = load_and_preprocess_data(args.metadata_file, args.data_dir)
106
+ data_test2 = create_bags(df_test2, args.data_dir)
107
+
108
+ # Prepare the test2 data with padding
109
+ test2_X = prepare_data_with_padding(data_test2['test2']['X'], max_length=2000)
110
+ test2_Y = np.array(data_test2['test2']['Y']).flatten()
111
+
112
+ # Load the saved model
113
+ model = tf.keras.models.load_model(args.model_path)
114
+
115
+ # Evaluate the model on the test2 dataset
116
+ test2_metrics = evaluate_dataset(model, test2_X, test2_Y, "test2", args.save_dir)
117
+
118
+ # Save the metrics to a JSON file
119
+ with open(os.path.join(args.save_dir, 'evaluation_metrics_test2.json'), 'w') as f:
120
+ json.dump(test2_metrics, f, indent=4)
121
+
122
+ print("Evaluation metrics saved to evaluation_metrics_test2.json")
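
For reference, `--data_dir` is expected to hold one sub-directory of saved `.pt` feature tensors per patient (the layout produced by extract_omics_aligned_tiles_features.py below), and `--model_path` should point at a Keras classifier trained on the padded 2000-tile bags. An illustrative call with placeholder file names: `python evaluate_on_test_cohort2.py --metadata_file data/cohort2_metadata.csv --data_dir omics_align_features/ --model_path ./model_save/risk_classifier.h5`.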
evaluate_risk_classifier.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
5
+ import argparse
6
+ import json
7
+ import pandas as pd
8
+
9
+ # Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
10
+ def compute_additional_metrics(X, Y, model):
11
+ predictions = model.predict(X).flatten()
12
+ predictions_binary = (predictions > 0.5).astype(int) # Convert probabilities to class labels (0 or 1)
13
+ auc = roc_auc_score(Y, predictions)
14
+ precision = precision_score(Y, predictions_binary)
15
+ recall = recall_score(Y, predictions_binary)
16
+ f1 = f1_score(Y, predictions_binary)
17
+ return auc, precision, recall, f1, predictions
18
+
19
+ # Function to evaluate the model on a given dataset
20
+ def evaluate_dataset(model, X, Y, dataset_name, save_dir):
21
+ eval_metrics = model.evaluate(X, Y, verbose=0)
22
+ auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
23
+ metrics = {
24
+ 'loss': eval_metrics[0],
25
+ 'accuracy': eval_metrics[1],
26
+ 'auc': auc,
27
+ 'precision': precision,
28
+ 'recall': recall,
29
+ 'f1_score': f1
30
+ }
31
+
32
+ # Save the predictions for each sample
33
+ np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)
34
+
35
+ return metrics
36
+
37
+ # Function to evaluate the model on train, validate, and test datasets
38
+ def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
39
+ train_metrics = evaluate_dataset(model, train_X, train_Y, "train", save_dir)
40
+ validate_metrics = evaluate_dataset(model, validate_X, validate_Y, "validate", save_dir)
41
+ test_metrics = evaluate_dataset(model, test_X, test_Y, "test", save_dir)
42
+
43
+ metrics = {
44
+ 'train': train_metrics,
45
+ 'validate': validate_metrics,
46
+ 'test': test_metrics
47
+ }
48
+
49
+ # Display the metrics in a tabular format
50
+ metrics_df = pd.DataFrame(metrics).T
51
+ print(metrics_df.to_string())
52
+
53
+ # Save metrics to a JSON file
54
+ with open(os.path.join(save_dir, 'evaluation_metrics.json'), 'w') as f:
55
+ json.dump(metrics, f, indent=4)
56
+
57
+ print("Evaluation metrics saved to evaluation_metrics.json")
58
+
59
+ return metrics
60
+
61
+ if __name__ == "__main__":
62
+ # Command line arguments
63
+ parser = argparse.ArgumentParser(description='Evaluate a trained multiple instance learning classifier on risk data.')
64
+ parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training, validation, and test data.')
65
+ parser.add_argument('--model_path', type=str, required=True, help='Path to the saved model file.')
66
+ parser.add_argument('--save_dir', type=str, default='./evaluation_results/', help='Directory to save the evaluation results.')
67
+
68
+ args = parser.parse_args()
69
+
70
+ if not os.path.exists(args.save_dir):
71
+ os.makedirs(args.save_dir)
72
+
73
+ # Load the preprocessed data
74
+ data = np.load(args.data_file)
75
+ train_X, train_Y = data['train_X'], data['train_Y']
76
+ validate_X, validate_Y = data['validate_X'], data['validate_Y']
77
+ test_X, test_Y = data['test_X'], data['test_Y']
78
+
79
+ # Load the saved model
80
+ model = tf.keras.models.load_model(args.model_path)
81
+
82
+ # Evaluate the model
83
+ metrics = evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)
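
The `--data_file` argument here is the compressed archive written by make_train_data_for_risk_classification.py (keys `train_X`, `train_Y`, `validate_X`, `validate_Y`, `test_X`, `test_Y`), and `--model_path` is any Keras model saved after training on those padded bags.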
extract_omics_aligned_tiles_features.py ADDED
@@ -0,0 +1,55 @@
1
+ import os
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from pathlib import Path
5
+ import argparse
6
+ from scripts.genomic_plip_model import GenomicPLIPModel
7
+ from transformers import CLIPVisionModel
8
+
9
+ class PatientTileDataset(Dataset):
10
+ def __init__(self, data_dir, model, save_dir):
11
+ super().__init__()
12
+ self.data_dir = data_dir
13
+ self.model = model
14
+ self.save_dir = Path(save_dir)
15
+ self.files = []
16
+ for patient_id in os.listdir(data_dir):
17
+ patient_dir = os.path.join(data_dir, patient_id)
18
+ if os.path.isdir(patient_dir):
19
+ for f in os.listdir(patient_dir):
20
+ if f.endswith('.pt'):
21
+ self.files.append((os.path.join(patient_dir, f), patient_id))
22
+
23
+ def __len__(self):
24
+ return len(self.files)
25
+
26
+ def __getitem__(self, idx):
27
+ file_path, patient_id = self.files[idx]
28
+ data = torch.load(file_path)
29
+ tile_data = torch.from_numpy(data['tile_data'][0]).unsqueeze(0) # Add batch dimension
30
+ with torch.no_grad():
31
+ vision_features, _ = self.model(pixel_values=tile_data, score_vector=torch.zeros(1, 4))
32
+ feature_path = self.save_dir / patient_id / os.path.basename(file_path)
33
+ feature_path.parent.mkdir(parents=True, exist_ok=True)
34
+ torch.save(vision_features, feature_path)
35
+ return feature_path
36
+
37
+ def extract_features(data_dir, save_dir, model_path):
38
+ original_model = CLIPVisionModel.from_pretrained("./plip/")
39
+ custom_model = GenomicPLIPModel(original_model)
40
+ custom_model.load_state_dict(torch.load(model_path))
41
+ custom_model.eval()
42
+
43
+ dataset = PatientTileDataset(data_dir=data_dir, model=custom_model, save_dir=save_dir)
44
+ for _ in dataset:
45
+ pass
46
+
47
+ if __name__ == "__main__":
48
+ parser = argparse.ArgumentParser(description="Extract features from genomic aligned tiles.")
49
+ parser.add_argument('--data_dir', type=str, default='plip_preprocess/', help='Directory containing the pre processed patient data.')
50
+ parser.add_argument('--save_dir', type=str, default='omics_align_features/', help='Directory to save the extracted features.')
51
+ parser.add_argument('--model_path', type=str, default='./save_model/omics_plip.pth', help='Path to the trained model file.')
52
+
53
+ args = parser.parse_args()
54
+
55
+ extract_features(args.data_dir, args.save_dir, args.model_path)
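
The tensors written here (one `.pt` file per tile, grouped per patient under `--save_dir`) are what the bag-building steps below load back with `torch.load`. Note that the script also expects a local `./plip/` checkpoint for `CLIPVisionModel.from_pretrained`, in addition to the fine-tuned weights passed via `--model_path`.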
extract_tiles_from_wsi.py ADDED
@@ -0,0 +1,37 @@
1
+ import argparse
2
+ from scripts.slide_processor_parallel import SlideProcessor
3
+
4
+ def main():
5
+ parser = argparse.ArgumentParser(description='Process whole slide images from a directory.')
6
+
7
+ # Required arguments
8
+ parser.add_argument('-d', '--directory', type=str, required=True,
9
+ help='Directory containing whole slide image files.')
10
+ parser.add_argument('-o', '--output_dir', type=str, required=True,
11
+ help='Directory to save the processed tiles.')
12
+
13
+ # Optional arguments with defaults
14
+ parser.add_argument('-t', '--tile_size', type=int, default=1024,
15
+ help='Size of the tile in pixels (default: 1024).')
16
+ parser.add_argument('-v', '--overlap', type=int, default=0,
17
+ help='Overlap of tiles in pixels (default: 0).')
18
+ parser.add_argument('-th', '--tissue_threshold', type=float, default=0.65,
19
+ help='Threshold for tissue detection as a float (default: 0.65).')
20
+ parser.add_argument('-w', '--max_workers', type=int, default=30,
21
+ help='Maximum number of worker threads/processes (default: 30).')
22
+
23
+ args = parser.parse_args()
24
+
25
+ # Initialize the SlideProcessor with the parsed arguments
26
+ processor = SlideProcessor(
27
+ tile_size=args.tile_size,
28
+ overlap=args.overlap,
29
+ tissue_threshold=args.tissue_threshold,
30
+ max_workers=args.max_workers
31
+ )
32
+
33
+ # Start the processing
34
+ processor.parallel_process(base_dir=args.directory, output_dir=args.output_dir)
35
+
36
+ if __name__ == '__main__':
37
+ main()
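
An illustrative invocation with placeholder paths: `python extract_tiles_from_wsi.py -d /path/to/wsi_slides -o ./tiles/ -t 1024 -th 0.65 -w 8`. The defaults assume OpenSlide-readable whole slide images, which is why the Dockerfile installs `openslide-tools`.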
make_dataset_for_benchmark_models.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ import random
3
+ import shutil
4
+ import pandas as pd
5
+ from sklearn.model_selection import train_test_split
6
+ import argparse
7
+
8
+ def load_and_preprocess_data(metadata_file, data_dir):
9
+ df = pd.read_csv(metadata_file, skiprows=0)
10
+ if 'Unnamed: 0' in df.columns:
11
+ del df['Unnamed: 0']
12
+
13
+ # Filter and map classes to 0 and 1
14
+ classified_df = df[df['Class'].isin([1, 3])]
15
+ classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
16
+ df = classified_df.set_index('PatientID')
17
+
18
+ # Filter for patients that have corresponding WSI data
19
+ available_patients = set(os.listdir(data_dir))
20
+ df = df.loc[df.index.intersection(available_patients)]
21
+ df = df.sample(frac=1)
22
+
23
+ return df
24
+
25
+ def create_data_splits(df):
26
+ class1 = list(df[df['Class'] == 1].index)
27
+ class0 = list(df[df['Class'] == 0].index)
28
+
29
+ C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
30
+ C0_X_train, C0_X_test = train_test_split(class0, test_size=0.2)
31
+ C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.6)
32
+ C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.5)
33
+
34
+ X_train = []; X_train.extend(C1_X_train); X_train.extend(C0_X_train)
35
+ X_test = []; X_test.extend(C1_X_test); X_test.extend(C0_X_test)
36
+ X_validate = []; X_validate.extend(C1_X_validate); X_validate.extend(C0_X_validate)
37
+
38
+ random.shuffle(X_train)
39
+ random.shuffle(X_test)
40
+ random.shuffle(X_validate)
41
+
42
+ data_info = {'train': X_train, 'test': X_test, 'validate': X_validate}
43
+
44
+ print(" C0 - Train : {} , Validate : {} , Test : {} ".format(len(C0_X_train), len(C0_X_test), len(C0_X_validate)))
45
+ print(" C1 - Train : {} , Validate : {} , Test : {} ".format(len(C1_X_train), len(C1_X_test), len(C1_X_validate)))
46
+
47
+ return data_info
48
+
49
+ def copy_tiles(patient_ids, dest_folder, source_dir, num_tiles_per_patient):
50
+ for pID in patient_ids:
51
+ flp = os.path.join(source_dir, pID)
52
+ if os.path.exists(flp):
53
+ tiles = os.listdir(flp)
54
+ selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
55
+ for tile in selected_tiles:
56
+ tile_p = os.path.join(flp, tile)
57
+ new_p = os.path.join(dest_folder, tile)
58
+ shutil.copy(tile_p, new_p)
59
+ else:
60
+ print(f"Folder not found for patient {pID}")
61
+
62
+ def process_cohorts(primary_metadata, secondary_metadata, source_dir, dataset_dir, num_tiles_per_patient):
63
+ # Create necessary directories if they don't exist
64
+ os.makedirs(os.path.join(dataset_dir, 'train/class1/'), exist_ok=True)
65
+ os.makedirs(os.path.join(dataset_dir, 'train/class0/'), exist_ok=True)
66
+ os.makedirs(os.path.join(dataset_dir, 'test/class1/'), exist_ok=True)
67
+ os.makedirs(os.path.join(dataset_dir, 'test/class0/'), exist_ok=True)
68
+ os.makedirs(os.path.join(dataset_dir, 'validate/class1/'), exist_ok=True)
69
+ os.makedirs(os.path.join(dataset_dir, 'validate/class0/'), exist_ok=True)
70
+ os.makedirs(os.path.join(dataset_dir, 'test2/class1/'), exist_ok=True)
71
+ os.makedirs(os.path.join(dataset_dir, 'test2/class0/'), exist_ok=True)
72
+
73
+ # Load and preprocess primary cohort
74
+ primary_df = load_and_preprocess_data(primary_metadata, source_dir)
75
+ primary_data_info = create_data_splits(primary_df)
76
+
77
+ # Load and preprocess secondary cohort
78
+ secondary_df = load_and_preprocess_data(secondary_metadata, source_dir)
79
+ secondary_data_info = {'test2': secondary_df.index.tolist()}
80
+
81
+ # Copy tiles for the primary cohort
82
+ copy_tiles(primary_data_info['train'], os.path.join(dataset_dir, 'train/class1/'), source_dir, num_tiles_per_patient)
83
+ copy_tiles(primary_data_info['test'], os.path.join(dataset_dir, 'test/class1/'), source_dir, num_tiles_per_patient)
84
+ copy_tiles(primary_data_info['validate'], os.path.join(dataset_dir, 'validate/class1/'), source_dir, num_tiles_per_patient)
85
+
86
+ # Copy tiles for the secondary cohort
87
+ copy_tiles(secondary_data_info['test2'], os.path.join(dataset_dir, 'test2/class1/'), source_dir, num_tiles_per_patient)
88
+
89
+ print("Tiles copying completed for both cohorts.")
90
+
91
+ if __name__ == "__main__":
92
+ parser = argparse.ArgumentParser(description='Create dataset for benchmark models from primary and secondary cohorts.')
93
+ parser.add_argument('--primary_metadata', type=str, required=True, help='Path to the primary cohort metadata CSV file.')
94
+ parser.add_argument('--secondary_metadata', type=str, required=True, help='Path to the secondary cohort metadata CSV file.')
95
+ parser.add_argument('--source_dir', type=str, required=True, help='Directory containing raw tissue tiles.')
96
+ parser.add_argument('--dataset_dir', type=str, required=True, help='Directory to save the processed dataset.')
97
+ parser.add_argument('--num_tiles_per_patient', type=int, default=595, help='Number of tiles to select per patient.')
98
+
99
+ args = parser.parse_args()
100
+
101
+ process_cohorts(args.primary_metadata, args.secondary_metadata, args.source_dir, args.dataset_dir, args.num_tiles_per_patient)
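
One thing to be aware of: `process_cohorts` creates `class0/` and `class1/` folders for every split but then copies all selected patients, both labels, into the `class1/` folders only, so the `class0/` directories stay empty and `flow_from_directory` in the benchmark scripts would see a single class. A minimal class-aware variant is sketched below under the assumption that the patient label lives in the `Class` column of the frame returned by `load_and_preprocess_data`; `route_tiles_by_class` is an illustrative helper, not part of this upload:

def route_tiles_by_class(patient_ids, split_name, df, source_dir, dataset_dir, num_tiles_per_patient):
    # Send each patient's tiles to class1/ or class0/ according to the 0/1 label in df
    # (df must be indexed by PatientID, as in load_and_preprocess_data).
    for pID in patient_ids:
        label = int(df.loc[pID, 'Class'])
        dest = os.path.join(dataset_dir, split_name, f'class{label}/')
        copy_tiles([pID], dest, source_dir, num_tiles_per_patient)

# e.g. route_tiles_by_class(primary_data_info['train'], 'train', primary_df,
#                           source_dir, dataset_dir, num_tiles_per_patient)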
make_train_data_for_omics_plip.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ import random
3
+ import shutil
4
+ from sklearn.model_selection import train_test_split
5
+ import argparse
+ import json
6
+
7
+ def main(num_tiles_per_patient, source_dir, dataset_dir, test_val_size, val_size):
8
+ # Create necessary directories if they don't exist
9
+ os.makedirs(os.path.join(dataset_dir, 'train'), exist_ok=True)
10
+ os.makedirs(os.path.join(dataset_dir, 'test'), exist_ok=True)
11
+ os.makedirs(os.path.join(dataset_dir, 'validate'), exist_ok=True)
12
+
13
+
14
+ with open('./data/tcga-hnscc-patients.json', 'r') as f:
15
+ patient_data = json.load(f)
16
+
17
+ # Separate the patients into study and restricted groups
18
+ study_patients = patient_data['study']
19
+ restricted_patients = patient_data['restricted']
20
+
21
+ # List all patient directories in the source directory
22
+ files = os.listdir(source_dir)
23
+
24
+ # Filter files based on study patients
25
+ study_files = [file for file in files if file in study_patients]
26
+
27
+ # Split the data into train, test, and validation sets
28
+ train, test_val = train_test_split(study_files, test_size=test_val_size)  # split only the study-cohort patients
29
+ test, val = train_test_split(test_val, test_size=val_size)
30
+
31
+ # Function to process and copy files
32
+ def process_and_copy(file_list, type):
33
+ for file in file_list:
34
+ fol_p = os.path.join(source_dir, file)
35
+ tiles = os.listdir(fol_p)
36
+ selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
37
+ for tile in selected_tiles:
38
+ tile_p = os.path.join(fol_p, tile)
39
+ new_p = os.path.join(dataset_dir, type, tile)
40
+ shutil.copy(tile_p, new_p)
41
+
42
+ # Process and copy files for each dataset
43
+ process_and_copy(train, 'train')
44
+ process_and_copy(test, 'test')
45
+ process_and_copy(val, 'validate')
46
+
47
+ if __name__ == "__main__":
48
+ parser = argparse.ArgumentParser(description='Split data into train, test, and validation sets.')
49
+ parser.add_argument('--num_tiles_per_patient', type=int, default=595,
50
+ help='Number of tiles to select per patient.')
51
+ parser.add_argument('--source_dir', type=str, default='plip_preprocess',
52
+ help='Directory containing patient folders.')
53
+ parser.add_argument('--dataset_dir', type=str, default='Datasets/train_03',
54
+ help='Root directory for the train, test, and validate directories.')
55
+ parser.add_argument('--test_val_size', type=float, default=0.4,
56
+ help='Size of the test and validation sets combined.')
57
+ parser.add_argument('--val_size', type=float, default=0.5,
58
+ help='Proportion of validation set in the test-validation split.')
59
+
60
+ args = parser.parse_args()
61
+
62
+ main(args.num_tiles_per_patient, args.source_dir, args.dataset_dir, args.test_val_size, args.val_size)
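
This script reads `./data/tcga-hnscc-patients.json`, which is expected to map the keys `study` and `restricted` to lists of patient IDs; only the study-cohort folders found under `--source_dir` are split and copied into `train/`, `test/` and `validate/`.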
make_train_data_for_risk_classification.py ADDED
@@ -0,0 +1,107 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import random
6
+ from sklearn.model_selection import train_test_split
7
+ import argparse
8
+
9
+ def prepare_data_with_padding(data, max_length=None):
10
+ if max_length is None:
11
+ max_length = max(len(bag) for bag in data)
12
+ padded_data = []
13
+ for bag in data:
14
+ if len(bag) < max_length:
15
+ padding = np.zeros((max_length - len(bag), bag.shape[1]))
16
+ padded_bag = np.vstack((bag, padding))
17
+ else:
18
+ padded_bag = bag
19
+ padded_data.append(padded_bag)
20
+ return np.array(padded_data), max_length
21
+
22
+ def create_bags(data_info, df13, data_dir):
23
+ data = {'train': {'X': [], 'Y': []}, 'test': {'X': [], 'Y': []}, 'validate': {'X': [], 'Y': []}}
24
+ for split in ['train', 'test', 'validate']:
25
+ for pID in data_info[split]:
26
+ fol_p = os.path.join(data_dir, pID)
27
+ tiles = os.listdir(fol_p)
28
+ tile_data = []
29
+ for tile in tiles:
30
+ tile_p = os.path.join(fol_p, tile)
31
+ np1 = torch.load(tile_p).numpy()
32
+ tile_data.append(np1)
33
+ patient_label = df13.loc[pID, 'Class']
34
+ bag = np.squeeze(tile_data, axis=1)
35
+ bag_label = 1 if patient_label == 1 else 0
36
+ data[split]['X'].append(bag)
37
+ data[split]['Y'].append(np.array([bag_label]))
38
+ data[split]['X'] = np.array(data[split]['X'])
39
+ data[split]['Y'] = np.array(data[split]['Y'])
40
+ print(f"Data[{split}]['X'] shape: {data[split]['X'].shape}, dtype: {data[split]['X'].dtype}")
41
+ return data
42
+
43
+ def process_and_save(data_dir, metadata_file, save_dir):
44
+ # Load and preprocess metadata
45
+ dff = pd.read_csv(metadata_file, skiprows=0)
46
+ del dff['Unnamed: 0']
47
+
48
+ classified_df = dff[dff['Class'].isin([1, 3])]
49
+ classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
50
+ df13 = classified_df.set_index('PatientID')
51
+
52
+ there = set(list(df13.index))
53
+ wsi_there = os.listdir(data_dir)
54
+ use = list(there.intersection(wsi_there))
55
+ df13 = df13.loc[use]
56
+ df13 = df13.sample(frac=1)
57
+
58
+ class1 = list(df13[df13['Class'] == 1].index)
59
+ class0 = list(df13[df13['Class'] == 0].index)
60
+
61
+ C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
62
+ C0_X_train, C0_X_test = train_test_split(class0, test_size=0.2)
63
+ C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.6)
64
+ C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.5)
65
+
66
+ X_train = C1_X_train + C0_X_train
67
+ X_test = C1_X_test + C0_X_test
68
+ X_validate = C1_X_validate + C0_X_validate
69
+
70
+ random.shuffle(X_train)
71
+ random.shuffle(X_test)
72
+ random.shuffle(X_validate)
73
+
74
+ data_info = {'train': X_train, 'test': X_test, 'validate': X_validate}
75
+
76
+ print(" C0 - Train : {} , Validate : {} , Test : {} ".format(len(C0_X_train), len(C0_X_test), len(C0_X_validate)))
77
+ print(" C1 - Train : {} , Validate : {} , Test : {} ".format(len(C1_X_train), len(C1_X_test), len(C1_X_validate)))
78
+
79
+ # Create bags and prepare data with padding
80
+ data = create_bags(data_info, df13, data_dir)
81
+ train_X, _ = prepare_data_with_padding(data['train']['X'], 2000)
82
+ train_Y = np.array(data['train']['Y']).flatten()
83
+ validate_X, _ = prepare_data_with_padding(data['validate']['X'], 2000)
84
+ validate_Y = np.array(data['validate']['Y']).flatten()
85
+ test_X, _ = prepare_data_with_padding(data['test']['X'], 2000)
86
+ test_Y = np.array(data['test']['Y']).flatten()
87
+
88
+ # Save the processed arrays to a single file
89
+ np.savez_compressed(os.path.join(save_dir, 'training_risk_classifier_data.npz'),
90
+ train_X=train_X, train_Y=train_Y,
91
+ validate_X=validate_X, validate_Y=validate_Y,
92
+ test_X=test_X, test_Y=test_Y)
93
+
94
+ print("Data saved successfully in:", os.path.join(save_dir, 'training_risk_classifier_data.npz'))
95
+
96
+ if __name__ == "__main__":
97
+ parser = argparse.ArgumentParser(description='Process, split, and save the data with padding.')
98
+ parser.add_argument('--data_dir', type=str, help='Directory containing the extracted features.')
99
+ parser.add_argument('--metadata_file', type=str, default='data/data1.hnsc.p3.csv', help='CSV file containing the metadata for the samples.')
100
+ parser.add_argument('--save_dir', type=str, default='Datasets', help='Directory to save the processed data.')
101
+
102
+ args = parser.parse_args()
103
+
104
+ if not os.path.exists(args.save_dir):
105
+ os.makedirs(args.save_dir)
106
+
107
+ process_and_save(args.data_dir, args.metadata_file, args.save_dir)
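
The archive saved here is the `--data_file` consumed by train_GWSIF_classifier.py and evaluate_risk_classifier.py. A quick, illustrative sanity check of its contents (the path is the default produced above):

import numpy as np

data = np.load("Datasets/training_risk_classifier_data.npz")
for split in ("train", "validate", "test"):
    X, Y = data[f"{split}_X"], data[f"{split}_Y"]
    # X: (n_patients, 2000, feature_dim) padded bags; Y: (n_patients,) 0/1 labels
    print(split, X.shape, Y.shape)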
pre_process_tiles.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ import pandas as pd
3
+ from scripts.PlipDataProcess import PlipDataProcess # Updated folder name
4
+ from transformers import CLIPImageProcessor
5
+ import argparse
6
+
7
+ def main(csv_file, root_dir, save_dir):
8
+ # Load the CSV file and set 'PatientID' as the index
9
+ df4 = pd.read_csv(csv_file).set_index('PatientID')
10
+
11
+ # List directories in the root directory (assuming each directory corresponds to a patient)
12
+ files = [file for file in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, file))]
13
+
14
+ # Initialize the image processor
15
+ img_processor = CLIPImageProcessor.from_pretrained("./plip/")
16
+
17
+ # Initialize the dataset processing object
18
+ dataset = PlipDataProcess(
19
+ root_dir=root_dir,
20
+ files=files,
21
+ df=df4,
22
+ img_processor=img_processor,
23
+ num_tiles_per_patient=2000,
24
+ max_workers=64,
25
+ save_dir=save_dir
26
+ )
27
+
28
+ # Process each item in the dataset
29
+ for i in range(len(dataset)):
30
+ _ = dataset[i] # Trigger processing of the i-th item
31
+
32
+ if __name__ == '__main__':
33
+ parser = argparse.ArgumentParser(description="Process WSI images and generate tiles")
34
+
35
+ # Define arguments
36
+ parser.add_argument('--csv_file', type=str, required=True, help='Path to the CSV file with patient scores')
37
+ parser.add_argument('--root_dir', type=str, required=True, help='Root directory for WSI tiles')
38
+ parser.add_argument('--save_dir', type=str, required=True, help='Directory to save the processed tile data')
39
+
40
+ # Parse arguments
41
+ args = parser.parse_args()
42
+
43
+ # Call the main function with the parsed arguments
44
+ main(csv_file=args.csv_file, root_dir=args.root_dir, save_dir=args.save_dir)
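
As with the feature-extraction step, this script expects a local `./plip/` directory holding the PLIP image-processor files for `CLIPImageProcessor.from_pretrained`. An illustrative call with placeholder file names: `python pre_process_tiles.py --csv_file data/patient_scores.csv --root_dir ./tiles/ --save_dir plip_preprocess/` (the `plip_preprocess/` output matches the default input of the downstream scripts).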
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ numpy==1.19.2
+ pandas==1.3.4
+ matplotlib==3.5.2
+ openslide-python==1.1.2
+ scikit-image==0.18.1
+ scikit-learn==1.2.1
+ tqdm==4.62.3
+ Pillow==9.4.0
+ transformers==4.33.2
+ torch==2.0.1
+ jupyterlab==3.2.1
+ tensorflow==2.6.1
requirementsT.txt ADDED
@@ -0,0 +1,217 @@
1
+ absl-py==2.1.0
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ anndata==0.9.1
5
+ anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
6
+ appdirs==1.4.4
7
+ argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
8
+ argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1649500328244/work
9
+ astor==0.8.1
10
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
11
+ astunparse==1.6.3
12
+ async-timeout==4.0.3
13
+ attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
14
+ autograd==1.5
15
+ autograd-gamma==0.5.0
16
+ Babel @ file:///home/conda/feedstock_root/build_artifacts/babel_1677767029043/work
17
+ backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
18
+ backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
19
+ beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1680888073205/work
20
+ bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
21
+ Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work
22
+ brotlipy==0.7.0
23
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1720457958366/work/certifi
24
+ cffi @ file:///croot/cffi_1670423208954/work
25
+ charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
26
+ cmake==3.27.2
27
+ contourpy @ file:///opt/conda/conda-bld/contourpy_1663827406301/work
28
+ cryptography @ file:///croot/cryptography_1677533068310/work
29
+ cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
30
+ datasets==2.19.1
31
+ debugpy @ file:///home/builder/ci_310/debugpy_1640789504635/work
32
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
33
+ defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
34
+ dill==0.3.8
35
+ entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
36
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
37
+ fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1677336799617/work/dist
38
+ filelock @ file:///croot/filelock_1672387128942/work
39
+ flatbuffers==24.3.25
40
+ flit_core @ file:///croot/flit-core_1679397103445/work/source/flit_core
41
+ fonttools==4.25.0
42
+ formulaic==0.5.2
43
+ frozenlist==1.4.1
44
+ fsspec==2023.6.0
45
+ future==0.18.3
46
+ gast==0.5.4
47
+ git-filter-repo==2.38.0
48
+ gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
49
+ google-pasta==0.2.0
50
+ graphviz==0.20.1
51
+ grpcio==1.64.0
52
+ h5py==3.11.0
53
+ huggingface-hub==0.23.2
54
+ idna @ file:///croot/idna_1666125576474/work
55
+ imageio==2.34.2
56
+ imbalanced-learn==0.11.0
57
+ importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1680895625127/work
58
+ importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1676919000169/work
59
+ interface-meta==1.3.0
60
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1655369107642/work
61
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work
62
+ ipython-genutils==0.2.0
63
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
64
+ Jinja2 @ file:///croot/jinja2_1666908132255/work
65
+ joblib==1.4.2
66
+ json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1600692310011/work
67
+ jsonpickle==3.0.1
68
+ jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
69
+ jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
70
+ jupyter-server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1676473377907/work
71
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1669775088561/work
72
+ jupyterlab @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_1674494302491/work
73
+ jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
74
+ jupyterlab_server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1680275157923/work
75
+ keras==3.3.3
76
+ keras-tuner==1.4.7
77
+ kiwisolver @ file:///croot/kiwisolver_1672387140495/work
78
+ kt-legacy==1.0.5
79
+ lazy_loader==0.4
80
+ libclang==18.1.1
81
+ lifelines==0.27.4
82
+ lit==16.0.6
83
+ llvmlite==0.40.0
84
+ Markdown==3.6
85
+ markdown-it-py==3.0.0
86
+ MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
87
+ matplotlib @ file:///croot/matplotlib-suite_1679593461707/work
88
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
89
+ matplotlib-venn==0.11.9
90
+ mdurl==0.1.2
91
+ mil==1.0.5
92
+ mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
93
+ mkl-fft==1.3.1
94
+ mkl-random @ file:///home/builder/ci_310/mkl_random_1641843545607/work
95
+ mkl-service==2.4.0
96
+ ml-dtypes==0.3.2
97
+ mpmath==1.2.1
98
+ multidict==6.0.5
99
+ multiprocess==0.70.16
100
+ munkres==1.1.4
101
+ namex==0.0.8
102
+ natsort==8.3.1
103
+ nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1680699279518/work
104
+ nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1680676954923/work
105
+ nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1680629662454/work
106
+ nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1679336765223/work
107
+ nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
108
+ networkx @ file:///croot/networkx_1678964333703/work
109
+ notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1680870634737/work
110
+ notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
111
+ numba==0.57.0
112
+ numexpr @ file:///croot/numexpr_1668713893690/work
113
+ numpy @ file:///croot/numpy_and_numpy_base_1672336185480/work
114
+ nvidia-cublas-cu11==11.10.3.66
115
+ nvidia-cuda-cupti-cu11==11.7.101
116
+ nvidia-cuda-nvrtc-cu11==11.7.99
117
+ nvidia-cuda-runtime-cu11==11.7.99
118
+ nvidia-cudnn-cu11==8.5.0.96
119
+ nvidia-cufft-cu11==10.9.0.58
120
+ nvidia-curand-cu11==10.2.10.91
121
+ nvidia-cusolver-cu11==11.4.0.1
122
+ nvidia-cusparse-cu11==11.7.4.91
123
+ nvidia-nccl-cu11==2.14.3
124
+ nvidia-nvtx-cu11==11.7.91
125
+ opencv-python==4.8.0.76
126
+ openslide-python==1.1.2
127
+ opt-einsum==3.3.0
128
+ optree==0.11.0
129
+ packaging @ file:///croot/packaging_1678965309396/work
130
+ pandas @ file:///croot/pandas_1692289311655/work
131
+ pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
132
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
133
+ patsy==0.5.3
134
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
135
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
136
+ Pillow==9.4.0
137
+ pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
138
+ plotly==5.14.1
139
+ ply==3.11
140
+ pooch @ file:///tmp/build/80754af9/pooch_1623324770023/work
141
+ prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
142
+ prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1677600924538/work
143
+ protobuf==4.25.3
144
+ psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
145
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
146
+ pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
147
+ pyarrow==16.1.0
148
+ pyarrow-hotfix==0.6
149
+ pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
150
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work
151
+ pykan==0.0.2
152
+ pynndescent==0.5.10
153
+ pyOpenSSL @ file:///croot/pyopenssl_1677607685877/work
154
+ pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
155
+ PyQt5-sip==12.11.0
156
+ pyrsistent @ file:///home/builder/ci_310/pyrsistent_1640807196327/work
157
+ PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
158
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
159
+ pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1680088766131/work
160
+ pyvis==0.3.2
161
+ PyYAML==6.0.1
162
+ pyzmq @ file:///opt/conda/conda-bld/pyzmq_1657724186960/work
163
+ regex==2023.8.8
164
+ requests @ file:///croot/requests_1678709721434/work
165
+ rich==13.7.1
166
+ rpy2==3.5.11
167
+ safetensors==0.3.3
168
+ scanpy==1.9.3
169
+ scikit-image==0.24.0
170
+ scikit-learn==1.5.1
171
+ scipy==1.14.0
172
+ seaborn @ file:///croot/seaborn_1673479180098/work
173
+ Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
174
+ session-info==1.0.0
175
+ sip @ file:///tmp/abs_44cd77b_pu/croots/recipe/sip_1659012365470/work
176
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
177
+ sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
178
+ soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
179
+ stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
180
+ statsmodels==0.14.0
181
+ stdlib-list==0.8.0
182
+ sympy @ file:///croot/sympy_1668202399572/work
183
+ tenacity==8.2.2
184
+ tensorboard==2.16.2
185
+ tensorboard-data-server==0.7.2
186
+ tensorflow==2.16.1
187
+ tensorflow-io-gcs-filesystem==0.37.0
188
+ termcolor==2.4.0
189
+ terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
190
+ threadpoolctl==3.5.0
191
+ tifffile==2024.7.24
192
+ tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
193
+ tokenizers==0.13.3
194
+ toml @ file:///tmp/build/80754af9/toml_1616166611790/work
195
+ tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
196
+ torch==2.0.0
197
+ torch-geometric @ file:///usr/share/miniconda/envs/test/conda-bld/pyg_1679554663466/work
198
+ torchvision==0.15.2
199
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827254365/work
200
+ tqdm @ file:///croot/tqdm_1679561862951/work
201
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
202
+ transformers==4.32.1
203
+ triton==2.0.0
204
+ typing_extensions @ file:///croot/typing_extensions_1669924550328/work
205
+ tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1707747584337/work
206
+ tzlocal==5.0.1
207
+ umap-learn==0.5.3
208
+ urllib3 @ file:///croot/urllib3_1680254681959/work
209
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
210
+ webencodings==0.5.1
211
+ websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
212
+ Werkzeug==3.0.3
213
+ wrapt==1.15.0
214
+ xxhash==3.4.1
215
+ yarl==1.9.4
216
+ yellowbrick==1.5
217
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1677313463193/work
train_GWSIF_classifier.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers, Model
5
+ from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
6
+ import argparse
7
+ import json
8
+
9
+ # Define the function to create the multiple instance learning (MIL) model
10
+ def create_simple_model(instance_shape, max_length):
11
+ inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
12
+ flatten = layers.TimeDistributed(layers.Flatten())(inputs)
13
+ dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
14
+ dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
15
+ dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
16
+ dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
17
+ aggregated = layers.GlobalAveragePooling1D()(dropout_2)
18
+ norm_1 = layers.LayerNormalization()(aggregated)
19
+ output = layers.Dense(1, activation="sigmoid")(norm_1)
20
+ return Model(inputs, output)
21
+
22
+ # Function to compute class weights
23
+ def compute_class_weights(labels):
24
+ negative_count = len(np.where(labels == 0)[0])
25
+ positive_count = len(np.where(labels == 1)[0])
26
+ total_count = negative_count + positive_count
27
+ return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
28
+
29
+ # Function to generate batches of data
30
+ def data_generator(data, labels, batch_size=1):
31
+ class_weights = compute_class_weights(labels)
32
+ while True:
33
+ for i in range(0, len(data), batch_size):
34
+ batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
35
+ batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
36
+ batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
37
+ yield batch_data, batch_labels, batch_weights
38
+
39
+ # Learning rate scheduler
40
+ def lr_scheduler(epoch, lr):
41
+ decay_rate = 0.1
42
+ decay_step = 10
43
+ if epoch % decay_step == 0 and epoch:
44
+ return lr * decay_rate
45
+ return lr
46
+
47
+ # Function to train the model
48
+ def train(train_data, train_labels, val_data, val_labels, model):
49
+ file_path = "/tmp/best_model.weights.h5"
50
+ model_checkpoint = tf.keras.callbacks.ModelCheckpoint(file_path, monitor="val_loss", verbose=0, mode="min", save_best_only=True, save_weights_only=True)
51
+ early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
52
+ lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
53
+ model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
54
+ train_gen = data_generator(train_data, train_labels)
55
+ val_gen = data_generator(val_data, val_labels)
56
+ model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=50, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
57
+ model.load_weights(file_path)
58
+ return model
59
+
60
+ # Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
61
+ def compute_additional_metrics(X, Y, model):
62
+ predictions = model.predict(X).flatten()
63
+ predictions_binary = (predictions > 0.5).astype(int) # Convert probabilities to class labels (0 or 1)
64
+ auc = roc_auc_score(Y, predictions)
65
+ precision = precision_score(Y, predictions_binary)
66
+ recall = recall_score(Y, predictions_binary)
67
+ f1 = f1_score(Y, predictions_binary)
68
+ return auc, precision, recall, f1, predictions
69
+
70
+ # Function to evaluate the model on a given dataset
71
+ def evaluate_dataset(model, X, Y, dataset_name, save_dir):
72
+ eval_metrics = model.evaluate(X, Y, verbose=0)
73
+ auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
74
+ metrics = {
75
+ 'loss': eval_metrics[0],
76
+ 'accuracy': eval_metrics[1],
77
+ 'auc': auc,
78
+ 'precision': precision,
79
+ 'recall': recall,
80
+ 'f1_score': f1
81
+ }
82
+
83
+ # Save the predictions for each sample
84
+ np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)
85
+
86
+ return metrics
87
+
88
+ # Function to evaluate the model on train, validate, and test datasets
89
+ def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
90
+ train_metrics = evaluate_dataset(model, train_X, train_Y, "train", save_dir)
91
+ validate_metrics = evaluate_dataset(model, validate_X, validate_Y, "validate", save_dir)
92
+ test_metrics = evaluate_dataset(model, test_X, test_Y, "test", save_dir)
93
+
94
+ metrics = {
95
+ 'train': train_metrics,
96
+ 'validate': validate_metrics,
97
+ 'test': test_metrics
98
+ }
99
+
100
+ with open(os.path.join(save_dir, 'evaluation_metrics.json'), 'w') as f:
101
+ json.dump(metrics, f, indent=4)
102
+
103
+ print("Evaluation metrics saved to evaluation_metrics.json")
104
+
105
+ return metrics
106
+
107
+ if __name__ == "__main__":
108
+ # Command line arguments
109
+ parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
110
+ parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
111
+ parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model and evaluation metrics.')
112
+ parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
113
+
114
+ args = parser.parse_args()
115
+
116
+ if not os.path.exists(args.save_dir):
117
+ os.makedirs(args.save_dir)
118
+
119
+ # Load the preprocessed data
120
+ data = np.load(args.data_file)
121
+ train_X, train_Y = data['train_X'], data['train_Y']
122
+ validate_X, validate_Y = data['validate_X'], data['validate_Y']
123
+ test_X, test_Y = data['test_X'], data['test_Y']
124
+
125
+ # Create the model
126
+ instance_shape = (train_X.shape[-1],)
127
+ max_length = train_X.shape[1]
128
+ model = create_simple_model(instance_shape, max_length)
129
+
130
+ # Train the model
131
+ trained_model = train(train_X, train_Y, validate_X, validate_Y, model, epochs=args.epochs)
132
+
133
+ # Evaluate the model
134
+ metrics = evaluate_all_datasets(trained_model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)
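A quick way to sanity-check the artifacts this script writes (a minimal sketch; the paths assume the default --save_dir of ./model_save/ and may need adjusting):

import json
import numpy as np

# Per-split metrics written by evaluate_all_datasets()
with open('./model_save/evaluation_metrics.json') as f:
    print(json.dumps(json.load(f), indent=2))

# Sigmoid scores and ground-truth labels saved by evaluate_dataset() for the test split
test = np.load('./model_save/test_predictions.npz')
print(test['predictions'][:5], test['labels'][:5])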
train_and_evaluate_risk_classifier.py ADDED
@@ -0,0 +1,144 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers, Model
5
+ from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
6
+ import argparse
7
+ import json
8
+ import pandas as pd
9
+
10
+ # Define the function to create the multiple instance learning (MIL) model
11
+ def create_simple_model(instance_shape, max_length):
12
+ inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
13
+ flatten = layers.TimeDistributed(layers.Flatten())(inputs)
14
+ dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
15
+ dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
16
+ dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
17
+ dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
18
+ aggregated = layers.GlobalAveragePooling1D()(dropout_2)
19
+ norm_1 = layers.LayerNormalization()(aggregated)
20
+ output = layers.Dense(1, activation="sigmoid")(norm_1)
21
+ return Model(inputs, output)
22
+
23
+ # Function to compute class weights
24
+ def compute_class_weights(labels):
25
+ negative_count = len(np.where(labels == 0)[0])
26
+ positive_count = len(np.where(labels == 1)[0])
27
+ total_count = negative_count + positive_count
28
+ return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
29
+
30
+ # Function to generate batches of data
31
+ def data_generator(data, labels, batch_size=1):
32
+ class_weights = compute_class_weights(labels)
33
+ while True:
34
+ for i in range(0, len(data), batch_size):
35
+ batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
36
+ batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
37
+ batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
38
+ yield batch_data, batch_labels, batch_weights
39
+
40
+ # Learning rate scheduler
41
+ def lr_scheduler(epoch, lr):
42
+ decay_rate = 0.1
43
+ decay_step = 10
44
+ if epoch % decay_step == 0 and epoch:
45
+ return lr * decay_rate
46
+ return lr
47
+
48
+ # Function to train the model
49
+ def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
50
+ model_path = os.path.join(save_dir, "best_model.h5")
51
+ model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
52
+ early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
53
+ lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
54
+ model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
55
+ train_gen = data_generator(train_data, train_labels)
56
+ val_gen = data_generator(val_data, val_labels)
57
+ model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=epochs, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
58
+ return model
59
+
60
+ # Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
61
+ def compute_additional_metrics(X, Y, model):
62
+ predictions = model.predict(X).flatten()
63
+ predictions_binary = (predictions > 0.5).astype(int) # Convert probabilities to class labels (0 or 1)
64
+ auc = roc_auc_score(Y, predictions)
65
+ precision = precision_score(Y, predictions_binary)
66
+ recall = recall_score(Y, predictions_binary)
67
+ f1 = f1_score(Y, predictions_binary)
68
+ return auc, precision, recall, f1, predictions
69
+
70
+ # Function to evaluate the model on a given dataset
71
+ def evaluate_dataset(model, X, Y, dataset_name, save_dir):
72
+ eval_metrics = model.evaluate(X, Y, verbose=0)
73
+ auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
74
+ metrics = {
75
+ 'loss': eval_metrics[0],
76
+ 'accuracy': eval_metrics[1],
77
+ 'auc': auc,
78
+ 'precision': precision,
79
+ 'recall': recall,
80
+ 'f1_score': f1
81
+ }
82
+
83
+ # Save the predictions for each sample
84
+ np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)
85
+
86
+ return metrics
87
+
88
+ # Function to evaluate the model on train, validate, and test datasets
89
+ def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
90
+ train_metrics = evaluate_dataset(model, train_X, train_Y, "train", save_dir)
91
+ validate_metrics = evaluate_dataset(model, validate_X, validate_Y, "validate", save_dir)
92
+ test_metrics = evaluate_dataset(model, test_X, test_Y, "test", save_dir)
93
+
94
+ metrics = {
95
+ 'train': train_metrics,
96
+ 'validate': validate_metrics,
97
+ 'test': test_metrics
98
+ }
99
+
100
+ # Display the metrics in a tabular format
101
+ metrics_df = pd.DataFrame(metrics).T
102
+ print(metrics_df.to_string())
103
+
104
+ # Save metrics to a JSON file
105
+ with open(os.path.join(save_dir, 'evaluation_metrics.json'), 'w') as f:
106
+ json.dump(metrics, f, indent=4)
107
+
108
+ print("Evaluation metrics saved to evaluation_metrics.json")
109
+
110
+ return metrics
111
+
112
+ if __name__ == "__main__":
113
+ # Command line arguments
114
+ parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
115
+ parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
116
+ parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model and evaluation metrics.')
117
+ parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
118
+
119
+ args = parser.parse_args()
120
+
121
+ if not os.path.exists(args.save_dir):
122
+ os.makedirs(args.save_dir)
123
+
124
+ # Load the preprocessed data
125
+ data = np.load(args.data_file)
126
+ train_X, train_Y = data['train_X'], data['train_Y']
127
+ validate_X, validate_Y = data['validate_X'], data['validate_Y']
128
+ test_X, test_Y = data['test_X'], data['test_Y']
129
+
130
+ # Create the model
131
+ instance_shape = (train_X.shape[-1],)
132
+ max_length = train_X.shape[1]
133
+ model = create_simple_model(instance_shape, max_length)
134
+
135
+ # Train the model
136
+ trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir, epochs=args.epochs)
137
+
138
+ # Save the final model after training
139
+ final_model_path = os.path.join(args.save_dir, "risk_classifier_model.h5")
140
+ trained_model.save(final_model_path)
141
+ print(f"Model saved successfully to {final_model_path}")
142
+
143
+ # Evaluate the model
144
+ metrics = evaluate_all_datasets(trained_model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)
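Example invocation (the .npz file name is illustrative; the archive must provide the train_X, train_Y, validate_X, validate_Y, test_X and test_Y arrays loaded in __main__):

python train_and_evaluate_risk_classifier.py --data_file risk_bags.npz --save_dir ./model_save/

The final model written to ./model_save/risk_classifier_model.h5 can later be reloaded with tf.keras.models.load_model.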
train_omics_plip_model.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ import argparse
3
+ from torch import optim
4
+ from torch.utils.data import DataLoader
5
+ from scripts.genomic_plip_model import GenomicPLIPModel
6
+ from scripts.tile_file_dataloader import FlatTileDataset
7
+ from transformers import CLIPVisionModel
8
+
9
+ def train_model(data_dir, model_save_path, pretrained_model_path, lr, num_epochs, train_batch_size, validation_batch_size, num_workers):
10
+
11
+ # Load datasets
12
+ train_dataset = FlatTileDataset(data_dir=f'{data_dir}/train')
13
+ train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
14
+
15
+ validation_dataset = FlatTileDataset(data_dir=f'{data_dir}/validate')
16
+ validation_data_loader = DataLoader(validation_dataset, batch_size=validation_batch_size, shuffle=False, num_workers=num_workers)
17
+
18
+ # Initialize the model
19
+ base_model = CLIPVisionModel.from_pretrained(pretrained_model_path)
20
+ custom_model = GenomicPLIPModel(base_model)
21
+
22
+ criterion = torch.nn.CosineSimilarity(dim=1)
23
+ optimizer = optim.Adam(custom_model.parameters(), lr=lr)
24
+
25
+
26
+ for epoch in range(num_epochs):
27
+ # Training loop
28
+ custom_model.train()
29
+ train_loss = 0.0
30
+
31
+ for batch_images, batch_scores in train_data_loader:
32
+ optimizer.zero_grad()
33
+
34
+ batch_loss = 0
35
+ for img, score in zip(batch_images, batch_scores):
36
+ vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))
37
+ cos_sim = criterion(score_features, vision_features)
38
+ loss = -cos_sim.mean()
39
+
40
+ batch_loss += loss.item()
41
+ loss.backward()
42
+
43
+ optimizer.step()
44
+ train_loss += batch_loss
45
+ print(f"Batch loss (negative cosine similarity): {batch_loss:.4f}")
46
+
47
+ avg_train_loss = train_loss / len(train_data_loader)
48
+ print(f"Epoch [{epoch+1}/{num_epochs}], Average Training Loss (negative cosine similarity): {avg_train_loss:.4f}")
49
+
50
+ # Validation loop
51
+ custom_model.eval()
52
+ validation_loss = 0.0
53
+
54
+ with torch.no_grad():
55
+ for batch_images, batch_scores in validation_data_loader:
56
+ batch_loss = 0
57
+ for img, score in zip(batch_images, batch_scores):
58
+ vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))
59
+ cos_sim = criterion(score_features, vision_features)
60
+ loss = -cos_sim.mean()
61
+
62
+ batch_loss += loss.item()
63
+
64
+ validation_loss += batch_loss
65
+ print(f"Validation batch loss (negative cosine similarity): {batch_loss:.4f}")
66
+
67
+ avg_validation_loss = validation_loss / len(validation_data_loader)
68
+ print(f"Epoch [{epoch+1}/{num_epochs}], Average Validation Loss (negative cosine similarity): {avg_validation_loss:.4f}")
69
+
70
+ # Save the trained model
71
+ torch.save(custom_model.state_dict(), model_save_path)
72
+
73
+ if __name__ == "__main__":
74
+ parser = argparse.ArgumentParser(description='Train the Genomic PLIP Model')
75
+ parser.add_argument('--data_dir', type=str, default='Datasets/train_03', help='Directory containing the train, validate, and test datasets.')
76
+ parser.add_argument('--model_save_path', type=str, default='genomic_plip.pth', help='Path to save the trained model.')
77
+ parser.add_argument('--pretrained_model_path', type=str, default='./plip', help='Path to the pretrained CLIP model.')
78
+
79
+ parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate for the optimizer.')
80
+ parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs to train for.')
81
+ parser.add_argument('--train_batch_size', type=int, default=128, help='Batch size for the training data loader.')
82
+ parser.add_argument('--validation_batch_size', type=int, default=128, help='Batch size for the validation data loader.')
83
+ parser.add_argument('--num_workers', type=int, default=32, help='Number of worker threads for data loading.')
84
+
85
+
86
+ args = parser.parse_args()
87
+
88
+ train_model(args.data_dir, args.model_save_path, args.pretrained_model_path, args.lr, args.num_epochs, args.train_batch_size, args.validation_batch_size, args.num_workers)
89
+
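Example run using the script's own defaults for the dataset and pretrained PLIP paths (adjust these to your local layout):

python train_omics_plip_model.py --data_dir Datasets/train_03 --pretrained_model_path ./plip --model_save_path genomic_plip.pth --num_epochs 1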
train_risk_classifier.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers, Model
5
+ import argparse
6
+ from datetime import datetime
7
+
8
+ # Baseline MIL model with mean pooling; create_simple_model below adds multi-head self-attention and is the variant used in __main__
9
+ def create_simple_model2(instance_shape, max_length):
10
+ inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
11
+ flatten = layers.TimeDistributed(layers.Flatten())(inputs)
12
+ dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
13
+ dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
14
+ dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
15
+ dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
16
+ aggregated = layers.GlobalAveragePooling1D()(dropout_2)
17
+ norm_1 = layers.LayerNormalization()(aggregated)
18
+ output = layers.Dense(1, activation="sigmoid")(norm_1)
19
+ return Model(inputs, output)
20
+
21
+ def create_simple_model(instance_shape, max_length, num_heads=4, key_dim=64):
22
+ inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
23
+ flatten = layers.TimeDistributed(layers.Flatten())(inputs)
24
+ dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
25
+ dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
26
+ dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
27
+ dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
28
+ attention_output, attention_scores = layers.MultiHeadAttention(
29
+ num_heads=num_heads,
30
+ key_dim=key_dim,
31
+ value_dim=64,
32
+ dropout=0.1,
33
+ use_bias=True
34
+ )(query=dropout_2, value=dropout_2, key=dropout_2, return_attention_scores=True)
35
+ aggregated = layers.GlobalAveragePooling1D()(attention_output)
36
+ norm_1 = layers.LayerNormalization()(aggregated)
37
+ output = layers.Dense(1, activation="sigmoid")(norm_1)
38
+ return Model(inputs, output)
39
+
40
+ # Function to compute class weights
41
+ def compute_class_weights(labels):
42
+ negative_count = len(np.where(labels == 0)[0])
43
+ positive_count = len(np.where(labels == 1)[0])
44
+ total_count = negative_count + positive_count
45
+ return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
46
+
47
+ # Function to generate batches of data
48
+ def data_generator(data, labels, batch_size=1):
49
+ class_weights = compute_class_weights(labels)
50
+ while True:
51
+ for i in range(0, len(data), batch_size):
52
+ batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
53
+ batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
54
+ batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
55
+ yield batch_data, batch_labels, batch_weights
56
+
57
+ # Learning rate scheduler
58
+ def lr_scheduler(epoch, lr):
59
+ decay_rate = 0.1
60
+ decay_step = 10
61
+ if epoch % decay_step == 0 and epoch:
62
+ return lr * decay_rate
63
+ return lr
64
+
65
+ # Function to train the model
66
+ def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
67
+ model_path = os.path.join(save_dir, "risk_classifier_model.h5")
68
+ model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
69
+ early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
70
+ lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
71
+ model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
72
+ train_gen = data_generator(train_data, train_labels)
73
+ val_gen = data_generator(val_data, val_labels)
74
+ model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=epochs, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
75
+ return model
76
+
77
+ if __name__ == "__main__":
78
+ # Command line arguments
79
+ parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
80
+ parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
81
+ parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model.')
82
+ parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
83
+
84
+ args = parser.parse_args()
85
+
86
+ if not os.path.exists(args.save_dir):
87
+ os.makedirs(args.save_dir)
88
+
89
+ # Load the preprocessed data
90
+ data = np.load(args.data_file)
91
+ train_X, train_Y = data['train_X'], data['train_Y']
92
+ validate_X, validate_Y = data['validate_X'], data['validate_Y']
93
+
94
+ # Create the model
95
+ instance_shape = (train_X.shape[-1],)
96
+ max_length = train_X.shape[1]
97
+ model = create_simple_model(instance_shape, max_length)
98
+
99
+ # Train the model
100
+ trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir, epochs=args.epochs)
101
+
102
+ # Final message after training and saving the model
103
+ print(f"Model saved successfully to {os.path.join(args.save_dir, 'risk_classifier_model.h5')}")
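Example invocation (the .npz file name is illustrative; it must contain the train_X, train_Y, validate_X and validate_Y arrays loaded in __main__):

python train_risk_classifier.py --data_file risk_bags.npz --save_dir ./model_save/ --epochs 50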
train_risk_classifier_optional.py ADDED
@@ -0,0 +1,109 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers, Model
5
+ import argparse
6
+ from datetime import datetime
7
+
8
+ # Define the function to create the first model
9
+ def create_simple_model(instance_shape, max_length):
10
+ inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
11
+ flatten = layers.TimeDistributed(layers.Flatten())(inputs)
12
+ dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
13
+ dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
14
+ dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
15
+ dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
16
+ aggregated = layers.GlobalAveragePooling1D()(dropout_2)
17
+ norm_1 = layers.LayerNormalization()(aggregated)
18
+ output = layers.Dense(1, activation="sigmoid")(norm_1)
19
+ return Model(inputs, output)
20
+
21
+ # Define the function to create the second model with attention
22
+ def create_simple_model2(instance_shape, max_length, num_heads=4, key_dim=64):
23
+ inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
24
+ flatten = layers.TimeDistributed(layers.Flatten())(inputs)
25
+ dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
26
+ dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
27
+ dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
28
+ dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
29
+ attention_output, attention_scores = layers.MultiHeadAttention(
30
+ num_heads=num_heads,
31
+ key_dim=key_dim,
32
+ value_dim=64,
33
+ dropout=0.1,
34
+ use_bias=True
35
+ )(query=dropout_2, value=dropout_2, key=dropout_2, return_attention_scores=True)
36
+ aggregated = layers.GlobalAveragePooling1D()(attention_output)
37
+ norm_1 = layers.LayerNormalization()(aggregated)
38
+ output = layers.Dense(1, activation="sigmoid")(norm_1)
39
+ return Model(inputs, output)
40
+
41
+ # Function to compute class weights
42
+ def compute_class_weights(labels):
43
+ negative_count = len(np.where(labels == 0)[0])
44
+ positive_count = len(np.where(labels == 1)[0])
45
+ total_count = negative_count + positive_count
46
+ return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
47
+
48
+ # Function to generate batches of data
49
+ def data_generator(data, labels, batch_size=1):
50
+ class_weights = compute_class_weights(labels)
51
+ while True:
52
+ for i in range(0, len(data), batch_size):
53
+ batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
54
+ batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
55
+ batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
56
+ yield batch_data, batch_labels, batch_weights
57
+
58
+ # Learning rate scheduler
59
+ def lr_scheduler(epoch, lr):
60
+ decay_rate = 0.1
61
+ decay_step = 10
62
+ if epoch % decay_step == 0 and epoch:
63
+ return lr * decay_rate
64
+ return lr
65
+
66
+ # Function to train the model
67
+ def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
68
+ model_path = os.path.join(save_dir, "risk_classifier_model.h5")
69
+ model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
70
+ early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
71
+ lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
72
+ model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
73
+ train_gen = data_generator(train_data, train_labels)
74
+ val_gen = data_generator(val_data, val_labels)
75
+ model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=epochs, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
76
+ return model
77
+
78
+ if __name__ == "__main__":
79
+ # Command line arguments
80
+ parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
81
+ parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
82
+ parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model.')
83
+ parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
84
+ parser.add_argument('--model_type', type=str, default='model1', choices=['model1', 'model2'], help='Type of model to use: model1 (default) or model2.')
85
+
86
+ args = parser.parse_args()
87
+
88
+ if not os.path.exists(args.save_dir):
89
+ os.makedirs(args.save_dir)
90
+
91
+ # Load the preprocessed data
92
+ data = np.load(args.data_file)
93
+ train_X, train_Y = data['train_X'], data['train_Y']
94
+ validate_X, validate_Y = data['validate_X'], data['validate_Y']
95
+
96
+ # Create the model based on the selected type
97
+ instance_shape = (train_X.shape[-1],)
98
+ max_length = train_X.shape[1]
99
+
100
+ if args.model_type == 'model2':
101
+ model = create_simple_model2(instance_shape, max_length)
102
+ else:
103
+ model = create_simple_model(instance_shape, max_length)
104
+
105
+ # Train the model
106
+ trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir, epochs=args.epochs)
107
+
108
+ # Final message after training and saving the model
109
+ print(f"Model saved successfully to {os.path.join(args.save_dir, 'risk_classifier_model.h5')}")
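Example invocation selecting the attention-based variant (model2); the .npz file name is illustrative and must contain the train_X, train_Y, validate_X and validate_Y arrays loaded in __main__:

python train_risk_classifier_optional.py --data_file risk_bags.npz --save_dir ./model_save/ --model_type model2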