Upload 19 files
- Dockerfile +26 -0
- ReadME.pdf +0 -0
- benchmark_train_inception.py +112 -0
- benchmark_train_resnet50.py +114 -0
- evaluate_on_test_cohort2.py +122 -0
- evaluate_risk_classifier.py +83 -0
- extract_omics_aligned_tiles_features.py +55 -0
- extract_tiles_from_wsi.py +37 -0
- make_dataset_for_benchmark_models.py +101 -0
- make_train_data_for_omics_plip.py +62 -0
- make_train_data_for_risk_classification.py +107 -0
- pre_process_tiles.py +44 -0
- requirements.txt +12 -0
- requirementsT.txt +217 -0
- train_GWSIF_classifier.py +134 -0
- train_and_evaluate_risk_classifier.py +144 -0
- train_omics_plip_model.py +89 -0
- train_risk_classifier.py +103 -0
- train_risk_classifier_optional.py +109 -0
Dockerfile
ADDED
@@ -0,0 +1,26 @@
# Use an official Python runtime as a parent image
FROM python:3.8-slim-buster

# Set the working directory in the container
WORKDIR /app

# Install system dependencies
RUN apt-get update && \
    apt-get install -y build-essential openslide-tools libgl1-mesa-glx && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Copy the entire genomic_plip_model directory contents into the container at /app
COPY ./ /app/

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Create a non-root user and switch to it for security
RUN adduser --disabled-password --gecos '' myuser
USER myuser

EXPOSE 8888

# Set the entrypoint to a shell command
ENTRYPOINT ["/bin/bash"]
ReadME.pdf
ADDED
Binary file (68.2 kB).
benchmark_train_inception.py
ADDED
@@ -0,0 +1,112 @@
import os
import numpy as np
import tensorflow as tf
import json
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
import pandas as pd

# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(generator, model):
    y_true = generator.classes
    y_pred_prob = model.predict(generator)
    y_pred = np.argmax(y_pred_prob, axis=1)
    auc = roc_auc_score(y_true, y_pred_prob[:, 1])
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    accuracy = accuracy_score(y_true, y_pred)
    return auc, precision, recall, f1, accuracy, y_pred_prob

# Function to save evaluation metrics
def save_evaluation_metrics(generator, model, dataset_name, save_dir):
    auc, precision, recall, f1, accuracy, y_pred_prob = compute_additional_metrics(generator, model)
    metrics = {
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'accuracy': accuracy
    }
    # Save predictions
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=y_pred_prob, labels=generator.classes)
    return metrics

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train and evaluate InceptionV3 on benchmark datasets.')
    parser.add_argument('--dataset_dir', type=str, required=True, help='Directory containing train, validate, test, and test2 directories.')
    parser.add_argument('--save_dir', type=str, default='./results/', help='Directory to save the model and evaluation results.')
    parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')

    args = parser.parse_args()

    train_dir = os.path.join(args.dataset_dir, 'train')
    validate_dir = os.path.join(args.dataset_dir, 'validate')
    test_dir = os.path.join(args.dataset_dir, 'test')
    test2_dir = os.path.join(args.dataset_dir, 'test2')

    os.makedirs(args.save_dir, exist_ok=True)

    # Set up InceptionV3 model
    with tf.device('GPU:0'):
        inception = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet', input_shape=(299, 299, 3))
        last_layer = inception.get_layer('mixed10')
        last_output = last_layer.output
        x = tf.keras.layers.GlobalAveragePooling2D()(last_output)
        x = tf.keras.layers.Dense(2, activation='softmax')(x)  # Assuming binary classification
        model = tf.keras.Model(inputs=inception.input, outputs=x)
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', 'Recall', 'Precision'])

    # Image data generators
    train_datagen = IDG(rescale=1/255.0, horizontal_flip=True)
    validate_datagen = IDG(rescale=1/255.0)
    test_datagen = IDG(rescale=1/255.0)

    train_generator = train_datagen.flow_from_directory(train_dir, target_size=(299, 299),
                                                        class_mode='categorical', batch_size=64)
    validate_generator = validate_datagen.flow_from_directory(validate_dir, target_size=(299, 299),
                                                              class_mode='categorical', batch_size=64)
    test_generator = test_datagen.flow_from_directory(test_dir, target_size=(299, 299),
                                                      class_mode='categorical', batch_size=64)
    test2_generator = test_datagen.flow_from_directory(test2_dir, target_size=(299, 299),
                                                       class_mode='categorical', batch_size=64)

    # Training the model
    hist = model.fit(train_generator, epochs=args.epochs, validation_data=validate_generator, verbose=1, shuffle=True)

    # Save the trained model
    model.save(os.path.join(args.save_dir, 'inception_model.hdf5'))

    # Save training history separately
    training_log = {
        'loss': hist.history['loss'],
        'val_loss': hist.history['val_loss'],
        'accuracy': hist.history['accuracy'],
        'val_accuracy': hist.history['val_accuracy'],
        'recall': hist.history['recall'],
        'val_recall': hist.history['val_recall'],
        'precision': hist.history['precision'],
        'val_precision': hist.history['val_precision']
    }
    with open(os.path.join(args.save_dir, 'training_log.json'), 'w') as f:
        json.dump(training_log, f)

    # Evaluate the model on each dataset and save metrics
    train_metrics = save_evaluation_metrics(train_generator, model, "train", args.save_dir)
    validate_metrics = save_evaluation_metrics(validate_generator, model, "validate", args.save_dir)
    test_metrics = save_evaluation_metrics(test_generator, model, "test", args.save_dir)
    test2_metrics = save_evaluation_metrics(test2_generator, model, "test2", args.save_dir)

    # Save the evaluation metrics in a JSON file
    evaluation_metrics = {
        'train_metrics': train_metrics,
        'validate_metrics': validate_metrics,
        'test_metrics': test_metrics,
        'test2_metrics': test2_metrics
    }

    with open(os.path.join(args.save_dir, 'evaluation_metrics.json'), 'w') as f:
        json.dump(evaluation_metrics, f)

    print("Training and evaluation metrics saved.")
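Note on evaluation: flow_from_directory shuffles samples by default, so generator.classes (which follows the unshuffled file order) may not line up with the order in which model.predict(generator) emits predictions, which would distort the AUC and the other scores. A minimal sketch of an unshuffled evaluation generator, assuming that is the intent (eval_test_generator is a hypothetical name, not part of the upload):

eval_test_generator = test_datagen.flow_from_directory(test_dir, target_size=(299, 299),
                                                       class_mode='categorical', batch_size=64,
                                                       shuffle=False)  # keep predictions aligned with generator.classes
test_metrics = save_evaluation_metrics(eval_test_generator, model, "test", args.save_dir)

The same caveat applies to the ResNet50 benchmark script below.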
benchmark_train_resnet50.py
ADDED
@@ -0,0 +1,114 @@
import os
import numpy as np
import tensorflow as tf
import json
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
import pandas as pd

# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(generator, model):
    y_true = generator.classes
    y_pred_prob = model.predict(generator)
    y_pred = np.argmax(y_pred_prob, axis=1)
    auc = roc_auc_score(y_true, y_pred_prob[:, 1])
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    accuracy = accuracy_score(y_true, y_pred)
    return auc, precision, recall, f1, accuracy, y_pred_prob

# Function to save evaluation metrics
def save_evaluation_metrics(generator, model, dataset_name, save_dir):
    auc, precision, recall, f1, accuracy, y_pred_prob = compute_additional_metrics(generator, model)
    metrics = {
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'accuracy': accuracy
    }
    # Save predictions
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=y_pred_prob, labels=generator.classes)
    return metrics

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train and evaluate ResNet50 on benchmark datasets.')
    parser.add_argument('--dataset_dir', type=str, required=True, help='Directory containing train, validate, test, and test2 directories.')
    parser.add_argument('--save_dir', type=str, default='./results/', help='Directory to save the model and evaluation results.')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.')

    args = parser.parse_args()

    train_dir = os.path.join(args.dataset_dir, 'train')
    validate_dir = os.path.join(args.dataset_dir, 'validate')
    test_dir = os.path.join(args.dataset_dir, 'test')
    test2_dir = os.path.join(args.dataset_dir, 'test2')

    os.makedirs(args.save_dir, exist_ok=True)

    # Set up ResNet50 model
    with tf.device('GPU:0'):
        resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
        last_layer = resnet.get_layer('conv5_block3_out')
        last_output = last_layer.output
        x = tf.keras.layers.GlobalAveragePooling2D()(last_output)
        x = tf.keras.layers.Dense(2, activation='softmax')(x)  # Assuming binary classification
        model = tf.keras.Model(inputs=resnet.input, outputs=x)
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', 'Recall', 'Precision'])

    # Image data generators
    train_datagen = IDG(rescale=1/255.0, horizontal_flip=True)
    validate_datagen = IDG(rescale=1/255.0)
    test_datagen = IDG(rescale=1/255.0)

    batch_size = 64

    train_generator = train_datagen.flow_from_directory(train_dir, target_size=(224, 224),
                                                        class_mode='categorical', batch_size=batch_size)
    validate_generator = validate_datagen.flow_from_directory(validate_dir, target_size=(224, 224),
                                                              class_mode='categorical', batch_size=batch_size)
    test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224),
                                                      class_mode='categorical', batch_size=batch_size)
    test2_generator = test_datagen.flow_from_directory(test2_dir, target_size=(224, 224),
                                                       class_mode='categorical', batch_size=batch_size)

    # Training the model
    hist = model.fit(train_generator, epochs=args.epochs, validation_data=validate_generator, verbose=1, shuffle=True)

    # Save the trained model
    model.save(os.path.join(args.save_dir, 'risk_classifier_resnet_model.hdf5'))

    # Save training history separately
    training_log = {
        'loss': hist.history['loss'],
        'val_loss': hist.history['val_loss'],
        'accuracy': hist.history['accuracy'],
        'val_accuracy': hist.history['val_accuracy'],
        'recall': hist.history['recall'],
        'val_recall': hist.history['val_recall'],
        'precision': hist.history['precision'],
        'val_precision': hist.history['val_precision']
    }
    with open(os.path.join(args.save_dir, 'resnet_training_log.json'), 'w') as f:
        json.dump(training_log, f)

    # Evaluate the model on each dataset and save metrics
    train_metrics = save_evaluation_metrics(train_generator, model, "train", args.save_dir)
    validate_metrics = save_evaluation_metrics(validate_generator, model, "validate", args.save_dir)
    test_metrics = save_evaluation_metrics(test_generator, model, "test", args.save_dir)
    test2_metrics = save_evaluation_metrics(test2_generator, model, "test2", args.save_dir)

    # Save the evaluation metrics in a JSON file
    evaluation_metrics = {
        'train_metrics': train_metrics,
        'validate_metrics': validate_metrics,
        'test_metrics': test_metrics,
        'test2_metrics': test2_metrics
    }

    with open(os.path.join(args.save_dir, 'resnet_evaluation_metrics.json'), 'w') as f:
        json.dump(evaluation_metrics, f)

    print("Training and evaluation metrics saved.")
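If several of these benchmark runs share one GPU, an optional safeguard is to let TensorFlow allocate GPU memory on demand rather than reserving the whole device; this is a generic TF 2.x sketch, not something the uploaded script does:

# Optional: enable on-demand GPU memory allocation before building the model.
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)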
evaluate_on_test_cohort2.py
ADDED
@@ -0,0 +1,122 @@
import os
import numpy as np
import pandas as pd
import torch
import json
import tensorflow as tf
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse

# Function to load and preprocess the dataset
def load_and_preprocess_data(metadata_file, data_dir):
    dff = pd.read_csv(metadata_file, skiprows=0)
    if 'Unnamed: 0' in dff.columns:
        del dff['Unnamed: 0']

    # Filter and map classes to 0 and 1
    classified_df = dff[dff['Class'].isin([1, 3])]
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df = classified_df.set_index('PatientID')

    # Filter for patients that have corresponding WSI data
    available_patients = set(os.listdir(data_dir))
    df = df.loc[df.index.intersection(available_patients)]
    df = df.sample(frac=1)

    return df

# Function to create bags of tiles
def create_bags(df, data_dir):
    data = {'test2': {'X': [], 'Y': []}}
    for pID, row in df.iterrows():
        fol_p = os.path.join(data_dir, pID)
        tiles = os.listdir(fol_p)
        tile_data = []
        for tile in tiles:
            tile_p = os.path.join(fol_p, tile)
            np1 = torch.load(tile_p).numpy()
            tile_data.append(np1)
        bag = np.squeeze(tile_data, axis=1)
        bag_label = row['Class']
        data['test2']['X'].append(bag)
        data['test2']['Y'].append(np.array([bag_label]))
    data['test2']['X'] = np.array(data['test2']['X'])
    data['test2']['Y'] = np.array(data['test2']['Y'])
    print(f"Data[test2]['X'] shape: {data['test2']['X'].shape}, dtype: {data['test2']['X'].dtype}")
    return data

# Function to pad the data to ensure uniform bag length
def prepare_data_with_padding(data, max_length=2000):
    padded_data = []
    for bag in data:
        if len(bag) < max_length:
            padding = np.zeros((max_length - len(bag), bag.shape[1]))
            padded_bag = np.vstack((bag, padding))
        else:
            padded_bag = bag
        padded_data.append(padded_bag)
    return np.array(padded_data)

# Function to compute additional metrics using sklearn
def compute_additional_metrics(X, Y, model):
    predictions = model.predict(X).flatten()
    predictions_binary = (predictions > 0.5).astype(int)  # Convert probabilities to class labels (0 or 1)
    auc = roc_auc_score(Y, predictions)
    precision = precision_score(Y, predictions_binary)
    recall = recall_score(Y, predictions_binary)
    f1 = f1_score(Y, predictions_binary)
    return auc, precision, recall, f1, predictions

# Function to evaluate the model on a given dataset using sklearn metrics
def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    # Evaluate using TensorFlow's model.evaluate() for loss and accuracy
    eval_metrics = model.evaluate(X, Y, verbose=0)

    # Compute additional metrics using sklearn
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
    metrics = {
        'loss': eval_metrics[0],
        'accuracy': eval_metrics[1],
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

    # Save the predictions for each sample
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)

    return metrics

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Evaluate a trained model on a secondary test dataset (test2).')
    parser.add_argument('--metadata_file', type=str, required=True, help='Path to the metadata CSV file for test2.')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the extracted tissue features.')
    parser.add_argument('--model_path', type=str, required=True, help='Path to the saved model file.')
    parser.add_argument('--save_dir', type=str, default='./evaluation_results_test2/', help='Directory to save evaluation results.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Load and preprocess the test2 data
    df_test2 = load_and_preprocess_data(args.metadata_file, args.data_dir)
    data_test2 = create_bags(df_test2, args.data_dir)

    # Prepare the test2 data with padding
    test2_X = prepare_data_with_padding(data_test2['test2']['X'], max_length=2000)
    test2_Y = np.array(data_test2['test2']['Y']).flatten()

    # Load the saved model
    model = tf.keras.models.load_model(args.model_path)

    # Evaluate the model on the test2 dataset
    test2_metrics = evaluate_dataset(model, test2_X, test2_Y, "test2", args.save_dir)

    # Save the metrics to a JSON file
    with open(os.path.join(args.save_dir, 'evaluation_metrics_test2.json'), 'w') as f:
        json.dump(test2_metrics, f, indent=4)

    print("Evaluation metrics saved to evaluation_metrics_test2.json")
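prepare_data_with_padding pads short bags up to max_length but leaves longer bags untouched, so a patient with more than 2000 tiles would yield a ragged array that np.array cannot stack cleanly. A hedged variant that also truncates, assuming a fixed bag length is the intent:

def prepare_data_with_padding(data, max_length=2000):
    padded_data = []
    for bag in data:
        bag = bag[:max_length]  # truncate bags that exceed the fixed length
        if len(bag) < max_length:
            padding = np.zeros((max_length - len(bag), bag.shape[1]))
            bag = np.vstack((bag, padding))
        padded_data.append(bag)
    return np.array(padded_data)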
evaluate_risk_classifier.py
ADDED
@@ -0,0 +1,83 @@
import os
import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
import json
import pandas as pd

# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(X, Y, model):
    predictions = model.predict(X).flatten()
    predictions_binary = (predictions > 0.5).astype(int)  # Convert probabilities to class labels (0 or 1)
    auc = roc_auc_score(Y, predictions)
    precision = precision_score(Y, predictions_binary)
    recall = recall_score(Y, predictions_binary)
    f1 = f1_score(Y, predictions_binary)
    return auc, precision, recall, f1, predictions

# Function to evaluate the model on a given dataset
def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    eval_metrics = model.evaluate(X, Y, verbose=0)
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
    metrics = {
        'loss': eval_metrics[0],
        'accuracy': eval_metrics[1],
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

    # Save the predictions for each sample
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)

    return metrics

# Function to evaluate the model on train, validate, and test datasets
def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
    train_metrics = evaluate_dataset(model, train_X, train_Y, "train", save_dir)
    validate_metrics = evaluate_dataset(model, validate_X, validate_Y, "validate", save_dir)
    test_metrics = evaluate_dataset(model, test_X, test_Y, "test", save_dir)

    metrics = {
        'train': train_metrics,
        'validate': validate_metrics,
        'test': test_metrics
    }

    # Display the metrics in a tabular format
    metrics_df = pd.DataFrame(metrics).T
    print(metrics_df.to_string())

    # Save metrics to a JSON file
    with open(os.path.join(save_dir, 'evaluation_metrics.json'), 'w') as f:
        json.dump(metrics, f, indent=4)

    print("Evaluation metrics saved to evaluation_metrics.json")

    return metrics

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Evaluate a trained multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training, validation, and test data.')
    parser.add_argument('--model_path', type=str, required=True, help='Path to the saved model file.')
    parser.add_argument('--save_dir', type=str, default='./evaluation_results/', help='Directory to save the evaluation results.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']
    test_X, test_Y = data['test_X'], data['test_Y']

    # Load the saved model
    model = tf.keras.models.load_model(args.model_path)

    # Evaluate the model
    metrics = evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)
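evaluate_dataset assumes model.evaluate returns loss and accuracy in positions 0 and 1; the ordering actually follows the metrics the model was compiled with. A more defensive sketch, assuming the model was compiled with at least an accuracy metric, keys the values by name:

eval_values = model.evaluate(X, Y, verbose=0)
named_metrics = dict(zip(model.metrics_names, eval_values))  # e.g. {'loss': ..., 'accuracy': ...}
loss, accuracy = named_metrics['loss'], named_metrics.get('accuracy')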
extract_omics_aligned_tiles_features.py
ADDED
@@ -0,0 +1,55 @@
import os
import torch
from torch.utils.data import Dataset
from pathlib import Path
import argparse
from scripts.genomic_plip_model import GenomicPLIPModel
from transformers import CLIPVisionModel

class PatientTileDataset(Dataset):
    def __init__(self, data_dir, model, save_dir):
        super().__init__()
        self.data_dir = data_dir
        self.model = model
        self.save_dir = Path(save_dir)
        self.files = []
        for patient_id in os.listdir(data_dir):
            patient_dir = os.path.join(data_dir, patient_id)
            if os.path.isdir(patient_dir):
                for f in os.listdir(patient_dir):
                    if f.endswith('.pt'):
                        self.files.append((os.path.join(patient_dir, f), patient_id))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path, patient_id = self.files[idx]
        data = torch.load(file_path)
        tile_data = torch.from_numpy(data['tile_data'][0]).unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            vision_features, _ = self.model(pixel_values=tile_data, score_vector=torch.zeros(1, 4))
        feature_path = self.save_dir / patient_id / os.path.basename(file_path)
        feature_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(vision_features, feature_path)
        return feature_path

def extract_features(data_dir, save_dir, model_path):
    original_model = CLIPVisionModel.from_pretrained("./plip/")
    custom_model = GenomicPLIPModel(original_model)
    custom_model.load_state_dict(torch.load(model_path))
    custom_model.eval()

    dataset = PatientTileDataset(data_dir=data_dir, model=custom_model, save_dir=save_dir)
    for _ in dataset:
        pass

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract features from genomic aligned tiles.")
    parser.add_argument('--data_dir', type=str, default='plip_preprocess/', help='Directory containing the pre-processed patient data.')
    parser.add_argument('--save_dir', type=str, default='omics_align_features/', help='Directory to save the extracted features.')
    parser.add_argument('--model_path', type=str, default='./save_model/omics_plip.pth', help='Path to the trained model file.')

    args = parser.parse_args()

    extract_features(args.data_dir, args.save_dir, args.model_path)
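extract_features iterates the dataset one tile at a time. If throughput matters, a torch DataLoader with worker processes can drive the same side-effecting __getitem__; note that each worker would get its own copy of the model, and collate_fn=list keeps the returned Path objects as-is. A sketch under those assumptions:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, num_workers=4, collate_fn=list)
for _ in loader:  # each batch is simply a list of saved feature paths
    pass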
extract_tiles_from_wsi.py
ADDED
@@ -0,0 +1,37 @@
import argparse
from scripts.slide_processor_parallel import SlideProcessor

def main():
    parser = argparse.ArgumentParser(description='Process whole slide images from a directory.')

    # Required arguments
    parser.add_argument('-d', '--directory', type=str, required=True,
                        help='Directory containing whole slide image files.')
    parser.add_argument('-o', '--output_dir', type=str, required=True,
                        help='Directory to save the processed tiles.')

    # Optional arguments with defaults
    parser.add_argument('-t', '--tile_size', type=int, default=1024,
                        help='Size of the tile in pixels (default: 1024).')
    parser.add_argument('-v', '--overlap', type=int, default=0,
                        help='Overlap of tiles in pixels (default: 0).')
    parser.add_argument('-th', '--tissue_threshold', type=float, default=0.65,
                        help='Threshold for tissue detection as a float (default: 0.65).')
    parser.add_argument('-w', '--max_workers', type=int, default=30,
                        help='Maximum number of worker threads/processes (default: 30).')

    args = parser.parse_args()

    # Initialize the SlideProcessor with the parsed arguments
    processor = SlideProcessor(
        tile_size=args.tile_size,
        overlap=args.overlap,
        tissue_threshold=args.tissue_threshold,
        max_workers=args.max_workers
    )

    # Start the processing
    processor.parallel_process(base_dir=args.directory, output_dir=args.output_dir)

if __name__ == '__main__':
    main()
make_dataset_for_benchmark_models.py
ADDED
@@ -0,0 +1,101 @@
import os
import random
import shutil
import pandas as pd
from sklearn.model_selection import train_test_split
import argparse

def load_and_preprocess_data(metadata_file, data_dir):
    df = pd.read_csv(metadata_file, skiprows=0)
    if 'Unnamed: 0' in df.columns:
        del df['Unnamed: 0']

    # Filter and map classes to 0 and 1
    classified_df = df[df['Class'].isin([1, 3])]
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df = classified_df.set_index('PatientID')

    # Filter for patients that have corresponding WSI data
    available_patients = set(os.listdir(data_dir))
    df = df.loc[df.index.intersection(available_patients)]
    df = df.sample(frac=1)

    return df

def create_data_splits(df):
    class1 = list(df[df['Class'] == 1].index)
    class0 = list(df[df['Class'] == 0].index)

    C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
    C0_X_train, C0_X_test = train_test_split(class0, test_size=0.2)
    C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.6)
    C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.5)

    X_train = []; X_train.extend(C1_X_train); X_train.extend(C0_X_train)
    X_test = []; X_test.extend(C1_X_test); X_test.extend(C0_X_test)
    X_validate = []; X_validate.extend(C1_X_validate); X_validate.extend(C0_X_validate)

    random.shuffle(X_train)
    random.shuffle(X_test)
    random.shuffle(X_validate)

    data_info = {'train': X_train, 'test': X_test, 'validate': X_validate}

    print(" C0 - Train : {} , Validate : {} , Test : {} ".format(len(C0_X_train), len(C0_X_validate), len(C0_X_test)))
    print(" C1 - Train : {} , Validate : {} , Test : {} ".format(len(C1_X_train), len(C1_X_validate), len(C1_X_test)))

    return data_info

def copy_tiles(patient_ids, dest_folder, source_dir, num_tiles_per_patient):
    for pID in patient_ids:
        flp = os.path.join(source_dir, pID)
        if os.path.exists(flp):
            tiles = os.listdir(flp)
            selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
            for tile in selected_tiles:
                tile_p = os.path.join(flp, tile)
                new_p = os.path.join(dest_folder, tile)
                shutil.copy(tile_p, new_p)
        else:
            print(f"Folder not found for patient {pID}")

def process_cohorts(primary_metadata, secondary_metadata, source_dir, dataset_dir, num_tiles_per_patient):
    # Create necessary directories if they don't exist
    os.makedirs(os.path.join(dataset_dir, 'train/class1/'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'train/class0/'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'test/class1/'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'test/class0/'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'validate/class1/'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'validate/class0/'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'test2/class1/'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'test2/class0/'), exist_ok=True)

    # Load and preprocess primary cohort
    primary_df = load_and_preprocess_data(primary_metadata, source_dir)
    primary_data_info = create_data_splits(primary_df)

    # Load and preprocess secondary cohort
    secondary_df = load_and_preprocess_data(secondary_metadata, source_dir)
    secondary_data_info = {'test2': secondary_df.index.tolist()}

    # Copy tiles for the primary cohort
    copy_tiles(primary_data_info['train'], os.path.join(dataset_dir, 'train/class1/'), source_dir, num_tiles_per_patient)
    copy_tiles(primary_data_info['test'], os.path.join(dataset_dir, 'test/class1/'), source_dir, num_tiles_per_patient)
    copy_tiles(primary_data_info['validate'], os.path.join(dataset_dir, 'validate/class1/'), source_dir, num_tiles_per_patient)

    # Copy tiles for the secondary cohort
    copy_tiles(secondary_data_info['test2'], os.path.join(dataset_dir, 'test2/class1/'), source_dir, num_tiles_per_patient)

    print("Tiles copying completed for both cohorts.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Create dataset for benchmark models from primary and secondary cohorts.')
    parser.add_argument('--primary_metadata', type=str, required=True, help='Path to the primary cohort metadata CSV file.')
    parser.add_argument('--secondary_metadata', type=str, required=True, help='Path to the secondary cohort metadata CSV file.')
    parser.add_argument('--source_dir', type=str, required=True, help='Directory containing raw tissue tiles.')
    parser.add_argument('--dataset_dir', type=str, required=True, help='Directory to save the processed dataset.')
    parser.add_argument('--num_tiles_per_patient', type=int, default=595, help='Number of tiles to select per patient.')

    args = parser.parse_args()

    process_cohorts(args.primary_metadata, args.secondary_metadata, args.source_dir, args.dataset_dir, args.num_tiles_per_patient)
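As written, process_cohorts creates class0/ and class1/ subfolders but copies every patient's tiles into the class1/ folder regardless of label, so the class0/ folders stay empty. A hedged fix, assuming the intent is to route tiles by each patient's Class label (the secondary cohort would need the same treatment):

for split in ['train', 'validate', 'test']:
    for pID in primary_data_info[split]:
        label = primary_df.loc[pID, 'Class']  # 0 or 1
        dest = os.path.join(dataset_dir, split, f'class{label}/')
        copy_tiles([pID], dest, source_dir, num_tiles_per_patient)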
make_train_data_for_omics_plip.py
ADDED
@@ -0,0 +1,62 @@
import os
import json
import random
import shutil
from sklearn.model_selection import train_test_split
import argparse

def main(num_tiles_per_patient, source_dir, dataset_dir, test_val_size, val_size):
    # Create necessary directories if they don't exist
    os.makedirs(os.path.join(dataset_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'test'), exist_ok=True)
    os.makedirs(os.path.join(dataset_dir, 'validate'), exist_ok=True)

    with open('./data/tcga-hnscc-patients.json', 'r') as f:
        patient_data = json.load(f)

    # Separate the patients into study and restricted groups
    study_patients = patient_data['study']
    restricted_patients = patient_data['restricted']

    # List all patient directories in the source directory
    files = os.listdir(source_dir)

    # Filter files based on study patients
    study_files = [file for file in files if file in study_patients]

    # Split the data into train, test, and validation sets
    train, test_val = train_test_split(study_files, test_size=test_val_size)
    test, val = train_test_split(test_val, test_size=val_size)

    # Function to process and copy files
    def process_and_copy(file_list, type):
        for file in file_list:
            fol_p = os.path.join(source_dir, file)
            tiles = os.listdir(fol_p)
            selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
            for tile in selected_tiles:
                tile_p = os.path.join(fol_p, tile)
                new_p = os.path.join(dataset_dir, type, tile)
                shutil.copy(tile_p, new_p)

    # Process and copy files for each dataset
    process_and_copy(train, 'train')
    process_and_copy(test, 'test')
    process_and_copy(val, 'validate')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Split data into train, test, and validation sets.')
    parser.add_argument('--num_tiles_per_patient', type=int, default=595,
                        help='Number of tiles to select per patient.')
    parser.add_argument('--source_dir', type=str, default='plip_preprocess',
                        help='Directory containing patient folders.')
    parser.add_argument('--dataset_dir', type=str, default='Datasets/train_03',
                        help='Root directory for the train, test, and validate directories.')
    parser.add_argument('--test_val_size', type=float, default=0.4,
                        help='Size of the test and validation sets combined.')
    parser.add_argument('--val_size', type=float, default=0.5,
                        help='Proportion of validation set in the test-validation split.')

    args = parser.parse_args()

    main(args.num_tiles_per_patient, args.source_dir, args.dataset_dir, args.test_val_size, args.val_size)
make_train_data_for_risk_classification.py
ADDED
@@ -0,0 +1,107 @@
import os
import numpy as np
import pandas as pd
import torch
import random
from sklearn.model_selection import train_test_split
import argparse

def prepare_data_with_padding(data, max_length=None):
    if max_length is None:
        max_length = max(len(bag) for bag in data)
    padded_data = []
    for bag in data:
        if len(bag) < max_length:
            padding = np.zeros((max_length - len(bag), bag.shape[1]))
            padded_bag = np.vstack((bag, padding))
        else:
            padded_bag = bag
        padded_data.append(padded_bag)
    return np.array(padded_data), max_length

def create_bags(data_info, df13, data_dir):
    data = {'train': {'X': [], 'Y': []}, 'test': {'X': [], 'Y': []}, 'validate': {'X': [], 'Y': []}}
    for split in ['train', 'test', 'validate']:
        for pID in data_info[split]:
            fol_p = os.path.join(data_dir, pID)
            tiles = os.listdir(fol_p)
            tile_data = []
            for tile in tiles:
                tile_p = os.path.join(fol_p, tile)
                np1 = torch.load(tile_p).numpy()
                tile_data.append(np1)
            patient_label = df13.loc[pID, 'Class']
            bag = np.squeeze(tile_data, axis=1)
            bag_label = 1 if patient_label == 1 else 0
            data[split]['X'].append(bag)
            data[split]['Y'].append(np.array([bag_label]))
        data[split]['X'] = np.array(data[split]['X'])
        data[split]['Y'] = np.array(data[split]['Y'])
        print(f"Data[{split}]['X'] shape: {data[split]['X'].shape}, dtype: {data[split]['X'].dtype}")
    return data

def process_and_save(data_dir, metadata_file, save_dir):
    # Load and preprocess metadata
    dff = pd.read_csv(metadata_file, skiprows=0)
    del dff['Unnamed: 0']

    classified_df = dff[dff['Class'].isin([1, 3])]
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df13 = classified_df.set_index('PatientID')

    there = set(list(df13.index))
    wsi_there = os.listdir(data_dir)
    use = list(there.intersection(wsi_there))
    df13 = df13.loc[use]
    df13 = df13.sample(frac=1)

    class1 = list(df13[df13['Class'] == 1].index)
    class0 = list(df13[df13['Class'] == 0].index)

    C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
    C0_X_train, C0_X_test = train_test_split(class0, test_size=0.2)
    C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.6)
    C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.5)

    X_train = C1_X_train + C0_X_train
    X_test = C1_X_test + C0_X_test
    X_validate = C1_X_validate + C0_X_validate

    random.shuffle(X_train)
    random.shuffle(X_test)
    random.shuffle(X_validate)

    data_info = {'train': X_train, 'test': X_test, 'validate': X_validate}

    print(" C0 - Train : {} , Validate : {} , Test : {} ".format(len(C0_X_train), len(C0_X_validate), len(C0_X_test)))
    print(" C1 - Train : {} , Validate : {} , Test : {} ".format(len(C1_X_train), len(C1_X_validate), len(C1_X_test)))

    # Create bags and prepare data with padding
    data = create_bags(data_info, df13, data_dir)
    train_X, _ = prepare_data_with_padding(data['train']['X'], 2000)
    train_Y = np.array(data['train']['Y']).flatten()
    validate_X, _ = prepare_data_with_padding(data['validate']['X'], 2000)
    validate_Y = np.array(data['validate']['Y']).flatten()
    test_X, _ = prepare_data_with_padding(data['test']['X'], 2000)
    test_Y = np.array(data['test']['Y']).flatten()

    # Save the processed arrays to a single file
    np.savez_compressed(os.path.join(save_dir, 'training_risk_classifier_data.npz'),
                        train_X=train_X, train_Y=train_Y,
                        validate_X=validate_X, validate_Y=validate_Y,
                        test_X=test_X, test_Y=test_Y)

    print("Data saved successfully in:", os.path.join(save_dir, 'training_risk_classifier_data.npz'))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process, split, and save the data with padding.')
    parser.add_argument('--data_dir', type=str, help='Directory containing the extracted features.')
    parser.add_argument('--metadata_file', type=str, default='data/data1.hnsc.p3.csv', help='CSV file containing the metadata for the samples.')
    parser.add_argument('--save_dir', type=str, default='Datasets', help='Directory to save the processed data.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    process_and_save(args.data_dir, args.metadata_file, args.save_dir)
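The Class remapping here (and in the other scripts that reuse this pattern) assigns into a slice of the original DataFrame, which triggers pandas' SettingWithCopyWarning. A defensive variant with the same behaviour, assuming an explicit copy is acceptable:

classified_df = dff[dff['Class'].isin([1, 3])].copy()  # work on a copy, not a view
classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
df13 = classified_df.set_index('PatientID')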
pre_process_tiles.py
ADDED
@@ -0,0 +1,44 @@
import os
import pandas as pd
from scripts.PlipDataProcess import PlipDataProcess  # Updated folder name
from transformers import CLIPImageProcessor
import argparse

def main(csv_file, root_dir, save_dir):
    # Load the CSV file and set 'PatientID' as the index
    df4 = pd.read_csv(csv_file).set_index('PatientID')

    # List directories in the root directory (assuming each directory corresponds to a patient)
    files = [file for file in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, file))]

    # Initialize the image processor
    img_processor = CLIPImageProcessor.from_pretrained("./plip/")

    # Initialize the dataset processing object
    dataset = PlipDataProcess(
        root_dir=root_dir,
        files=files,
        df=df4,
        img_processor=img_processor,
        num_tiles_per_patient=2000,
        max_workers=64,
        save_dir=save_dir
    )

    # Process each item in the dataset
    for i in range(len(dataset)):
        _ = dataset[i]  # Trigger processing of the i-th item

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Process WSI images and generate tiles")

    # Define arguments
    parser.add_argument('--csv_file', type=str, required=True, help='Path to the CSV file with patient scores')
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for WSI tiles')
    parser.add_argument('--save_dir', type=str, required=True, help='Directory to save the processed tile data')

    # Parse arguments
    args = parser.parse_args()

    # Call the main function with the parsed arguments
    main(csv_file=args.csv_file, root_dir=args.root_dir, save_dir=args.save_dir)
requirements.txt
ADDED
@@ -0,0 +1,12 @@
numpy==1.19.2
pandas==1.3.4
matplotlib==3.5.2
openslide-python==1.1.2
scikit-image==0.18.1
scikit-learn==1.2.1
tqdm==4.62.3
Pillow==9.4.0
transformers==4.33.2
torch==2.0.1
jupyterlab==3.2.1
tensorflow==2.6.1
requirementsT.txt
ADDED
@@ -0,0 +1,217 @@
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
anndata==0.9.1
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
appdirs==1.4.4
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1649500328244/work
astor==0.8.1
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
astunparse==1.6.3
async-timeout==4.0.3
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
autograd==1.5
autograd-gamma==0.5.0
Babel @ file:///home/conda/feedstock_root/build_artifacts/babel_1677767029043/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1680888073205/work
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work
brotlipy==0.7.0
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1720457958366/work/certifi
cffi @ file:///croot/cffi_1670423208954/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
cmake==3.27.2
contourpy @ file:///opt/conda/conda-bld/contourpy_1663827406301/work
cryptography @ file:///croot/cryptography_1677533068310/work
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
datasets==2.19.1
debugpy @ file:///home/builder/ci_310/debugpy_1640789504635/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dill==0.3.8
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1677336799617/work/dist
filelock @ file:///croot/filelock_1672387128942/work
flatbuffers==24.3.25
flit_core @ file:///croot/flit-core_1679397103445/work/source/flit_core
fonttools==4.25.0
formulaic==0.5.2
frozenlist==1.4.1
fsspec==2023.6.0
future==0.18.3
gast==0.5.4
git-filter-repo==2.38.0
gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
google-pasta==0.2.0
graphviz==0.20.1
grpcio==1.64.0
h5py==3.11.0
huggingface-hub==0.23.2
idna @ file:///croot/idna_1666125576474/work
imageio==2.34.2
imbalanced-learn==0.11.0
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1680895625127/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1676919000169/work
interface-meta==1.3.0
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1655369107642/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work
ipython-genutils==0.2.0
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
Jinja2 @ file:///croot/jinja2_1666908132255/work
joblib==1.4.2
json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1600692310011/work
jsonpickle==3.0.1
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
jupyter-server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1676473377907/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1669775088561/work
jupyterlab @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_1674494302491/work
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab_server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1680275157923/work
keras==3.3.3
keras-tuner==1.4.7
kiwisolver @ file:///croot/kiwisolver_1672387140495/work
kt-legacy==1.0.5
lazy_loader==0.4
libclang==18.1.1
lifelines==0.27.4
lit==16.0.6
llvmlite==0.40.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
matplotlib @ file:///croot/matplotlib-suite_1679593461707/work
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
matplotlib-venn==0.11.9
mdurl==0.1.2
mil==1.0.5
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
mkl-fft==1.3.1
mkl-random @ file:///home/builder/ci_310/mkl_random_1641843545607/work
mkl-service==2.4.0
ml-dtypes==0.3.2
mpmath==1.2.1
multidict==6.0.5
multiprocess==0.70.16
munkres==1.1.4
namex==0.0.8
natsort==8.3.1
nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1680699279518/work
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1680676954923/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1680629662454/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1679336765223/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
networkx @ file:///croot/networkx_1678964333703/work
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1680870634737/work
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
numba==0.57.0
numexpr @ file:///croot/numexpr_1668713893690/work
numpy @ file:///croot/numpy_and_numpy_base_1672336185480/work
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
opencv-python==4.8.0.76
openslide-python==1.1.2
opt-einsum==3.3.0
optree==0.11.0
packaging @ file:///croot/packaging_1678965309396/work
pandas @ file:///croot/pandas_1692289311655/work
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
patsy==0.5.3
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow==9.4.0
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
plotly==5.14.1
ply==3.11
pooch @ file:///tmp/build/80754af9/pooch_1623324770023/work
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1677600924538/work
protobuf==4.25.3
psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
pyarrow==16.1.0
pyarrow-hotfix==0.6
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work
pykan==0.0.2
pynndescent==0.5.10
pyOpenSSL @ file:///croot/pyopenssl_1677607685877/work
pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
PyQt5-sip==12.11.0
pyrsistent @ file:///home/builder/ci_310/pyrsistent_1640807196327/work
PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1680088766131/work
pyvis==0.3.2
PyYAML==6.0.1
pyzmq @ file:///opt/conda/conda-bld/pyzmq_1657724186960/work
regex==2023.8.8
requests @ file:///croot/requests_1678709721434/work
rich==13.7.1
rpy2==3.5.11
safetensors==0.3.3
scanpy==1.9.3
scikit-image==0.24.0
scikit-learn==1.5.1
scipy==1.14.0
seaborn @ file:///croot/seaborn_1673479180098/work
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
session-info==1.0.0
sip @ file:///tmp/abs_44cd77b_pu/croots/recipe/sip_1659012365470/work
six @ file:///tmp/build/80754af9/six_1644875935023/work
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
statsmodels==0.14.0
stdlib-list==0.8.0
sympy @ file:///croot/sympy_1668202399572/work
tenacity==8.2.2
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-io-gcs-filesystem==0.37.0
termcolor==2.4.0
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
threadpoolctl==3.5.0
tifffile==2024.7.24
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tokenizers==0.13.3
toml @ file:///tmp/build/80754af9/toml_1616166611790/work
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
torch==2.0.0
torch-geometric @ file:///usr/share/miniconda/envs/test/conda-bld/pyg_1679554663466/work
torchvision==0.15.2
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827254365/work
tqdm @ file:///croot/tqdm_1679561862951/work
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
transformers==4.32.1
triton==2.0.0
typing_extensions @ file:///croot/typing_extensions_1669924550328/work
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1707747584337/work
tzlocal==5.0.1
umap-learn==0.5.3
urllib3 @ file:///croot/urllib3_1680254681959/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
Werkzeug==3.0.3
wrapt==1.15.0
xxhash==3.4.1
yarl==1.9.4
yellowbrick==1.5
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1677313463193/work
train_GWSIF_classifier.py
ADDED
@@ -0,0 +1,134 @@
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
import json

# Define the function to create the multiple instance learning (MIL) model
def create_simple_model(instance_shape, max_length):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

# Function to compute class weights
def compute_class_weights(labels):
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}

# Function to generate batches of data
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights

# Learning rate scheduler
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr

# Function to train the model
def train(train_data, train_labels, val_data, val_labels, model):
    file_path = "/tmp/best_model.weights.h5"
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(file_path, monitor="val_loss", verbose=0, mode="min", save_best_only=True, save_weights_only=True)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=50, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
    model.load_weights(file_path)
    return model

# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(X, Y, model):
    predictions = model.predict(X).flatten()
    predictions_binary = (predictions > 0.5).astype(int)  # Convert probabilities to class labels (0 or 1)
    auc = roc_auc_score(Y, predictions)
    precision = precision_score(Y, predictions_binary)
    recall = recall_score(Y, predictions_binary)
    f1 = f1_score(Y, predictions_binary)
    return auc, precision, recall, f1, predictions

# Function to evaluate the model on a given dataset
def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    eval_metrics = model.evaluate(X, Y, verbose=0)
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
    metrics = {
        'loss': eval_metrics[0],
        'accuracy': eval_metrics[1],
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

    # Save the predictions for each sample
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)

    return metrics

# Function to evaluate the model on train, validate, and test datasets
def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
    train_metrics = evaluate_dataset(model, train_X, train_Y, "train", save_dir)
    validate_metrics = evaluate_dataset(model, validate_X, validate_Y, "validate", save_dir)
    test_metrics = evaluate_dataset(model, test_X, test_Y, "test", save_dir)

    metrics = {
        'train': train_metrics,
        'validate': validate_metrics,
        'test': test_metrics
    }

    with open(os.path.join(save_dir, 'evaluation_metrics.json'), 'w') as f:
        json.dump(metrics, f, indent=4)

    print("Evaluation metrics saved to evaluation_metrics.json")

    return metrics

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model and evaluation metrics.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']
    test_X, test_Y = data['test_X'], data['test_Y']

    # Create the model
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    model = create_simple_model(instance_shape, max_length)

    # Train the model
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model)

    # Evaluate the model
    metrics = evaluate_all_datasets(trained_model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)
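Usage note (editor's sketch, not part of the upload): train_GWSIF_classifier.py expects a single .npz file whose arrays are named train_X/train_Y, validate_X/validate_Y and test_X/test_Y, with each X holding one fixed-length bag of tile features per slide. The snippet below builds a dummy file with those keys and shows the matching invocation; the file name risk_bags.npz, the bag counts and the 512-dimensional features are illustrative assumptions only.

import numpy as np

rng = np.random.default_rng(0)

def make_split(n_bags, max_tiles=100, feat_dim=512):
    # Illustrative shapes: one (max_tiles, feat_dim) feature bag per slide, one binary label per bag.
    X = rng.normal(size=(n_bags, max_tiles, feat_dim)).astype(np.float32)
    Y = rng.integers(0, 2, size=n_bags).astype(np.float32)
    return X, Y

train_X, train_Y = make_split(16)
validate_X, validate_Y = make_split(4)
test_X, test_Y = make_split(4)

# Array names below are the ones the script loads; the file name is hypothetical.
np.savez_compressed("risk_bags.npz",
                    train_X=train_X, train_Y=train_Y,
                    validate_X=validate_X, validate_Y=validate_Y,
                    test_X=test_X, test_Y=test_Y)

# Then, from a shell:
#   python train_GWSIF_classifier.py --data_file risk_bags.npz --save_dir ./model_save/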
train_and_evaluate_risk_classifier.py
ADDED
@@ -0,0 +1,144 @@
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
import json
import pandas as pd

# Define the function to create the multiple instance learning (MIL) model
def create_simple_model(instance_shape, max_length):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

# Function to compute class weights
def compute_class_weights(labels):
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}

# Function to generate batches of data
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights

# Learning rate scheduler
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr

# Function to train the model
def train(train_data, train_labels, val_data, val_labels, model, save_dir):
    model_path = os.path.join(save_dir, "best_model.h5")
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=50, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
    return model

# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(X, Y, model):
    predictions = model.predict(X).flatten()
    predictions_binary = (predictions > 0.5).astype(int)  # Convert probabilities to class labels (0 or 1)
    auc = roc_auc_score(Y, predictions)
    precision = precision_score(Y, predictions_binary)
    recall = recall_score(Y, predictions_binary)
    f1 = f1_score(Y, predictions_binary)
    return auc, precision, recall, f1, predictions

# Function to evaluate the model on a given dataset
def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    eval_metrics = model.evaluate(X, Y, verbose=0)
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
    metrics = {
        'loss': eval_metrics[0],
        'accuracy': eval_metrics[1],
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

    # Save the predictions for each sample
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)

    return metrics

# Function to evaluate the model on train, validate, and test datasets
def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
    train_metrics = evaluate_dataset(model, train_X, train_Y, "train", save_dir)
    validate_metrics = evaluate_dataset(model, validate_X, validate_Y, "validate", save_dir)
    test_metrics = evaluate_dataset(model, test_X, test_Y, "test", save_dir)

    metrics = {
        'train': train_metrics,
        'validate': validate_metrics,
        'test': test_metrics
    }

    # Display the metrics in a tabular format
    metrics_df = pd.DataFrame(metrics).T
    print(metrics_df.to_string())

    # Save metrics to a JSON file
    with open(os.path.join(save_dir, 'evaluation_metrics.json'), 'w') as f:
        json.dump(metrics, f, indent=4)

    print("Evaluation metrics saved to evaluation_metrics.json")

    return metrics

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model and evaluation metrics.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']
    test_X, test_Y = data['test_X'], data['test_Y']

    # Create the model
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    model = create_simple_model(instance_shape, max_length)

    # Train the model
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir)

    # Save the final model after training
    final_model_path = os.path.join(args.save_dir, "risk_classifier_model.h5")
    trained_model.save(final_model_path)
    print(f"Model saved successfully to {final_model_path}")

    # Evaluate the model
    metrics = evaluate_all_datasets(trained_model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)
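Usage note (editor's sketch, not part of the upload): after a run, train_and_evaluate_risk_classifier.py leaves evaluation_metrics.json plus train/validate/test *_predictions.npz files in --save_dir. A minimal way to read them back, assuming the script's default ./model_save/ directory:

import json
import os
import numpy as np

save_dir = "./model_save/"  # the script's default --save_dir; adjust if another path was used

# Per-split metric dictionaries written by evaluate_all_datasets().
with open(os.path.join(save_dir, "evaluation_metrics.json")) as f:
    metrics = json.load(f)
print(metrics["test"]["auc"], metrics["test"]["f1_score"])

# Per-bag sigmoid scores and true labels written by evaluate_dataset().
preds = np.load(os.path.join(save_dir, "test_predictions.npz"))
print(preds["predictions"][:5], preds["labels"][:5])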
train_omics_plip_model.py
ADDED
@@ -0,0 +1,89 @@
import torch
import argparse
from torch import optim
from torch.utils.data import DataLoader
from scripts.genomic_plip_model import GenomicPLIPModel
from scripts.tile_file_dataloader import FlatTileDataset
from transformers import CLIPVisionModel

def train_model(data_dir, model_save_path, pretrained_model_path, lr, num_epochs, train_batch_size, validation_batch_size, num_workers):

    # Load datasets
    train_dataset = FlatTileDataset(data_dir=f'{data_dir}/train')
    train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)

    validation_dataset = FlatTileDataset(data_dir=f'{data_dir}/validate')
    validation_data_loader = DataLoader(validation_dataset, batch_size=validation_batch_size, shuffle=False, num_workers=num_workers)

    # Initialize the model
    base_model = CLIPVisionModel.from_pretrained(pretrained_model_path)
    custom_model = GenomicPLIPModel(base_model)

    criterion = torch.nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(custom_model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # Training loop
        custom_model.train()
        train_loss = 0.0

        for batch_images, batch_scores in train_data_loader:
            optimizer.zero_grad()

            batch_loss = 0
            for img, score in zip(batch_images, batch_scores):
                vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))
                cos_sim = criterion(score_features, vision_features)
                loss = -cos_sim.mean()

                batch_loss += loss.item()
                loss.backward()

            optimizer.step()
            train_loss += batch_loss
            print(f"Batch Cosine Similarity {batch_loss:.4f}")

        avg_train_loss = train_loss / len(train_data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Cosine Similarity: {avg_train_loss:.4f}")

        # Validation loop
        custom_model.eval()
        validation_loss = 0.0

        with torch.no_grad():
            for batch_images, batch_scores in validation_data_loader:
                batch_loss = 0
                for img, score in zip(batch_images, batch_scores):
                    vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))
                    cos_sim = criterion(score_features, vision_features)
                    loss = -cos_sim.mean()

                    batch_loss += loss.item()

                validation_loss += batch_loss
                print(f"Validation Batch Cosine Similarity {batch_loss:.4f}")

        avg_validation_loss = validation_loss / len(validation_data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Cosine Similarity: {avg_validation_loss:.4f}")

    # Save the trained model
    torch.save(custom_model.state_dict(), model_save_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train the Genomic PLIP Model')
    parser.add_argument('--data_dir', type=str, default='Datasets/train_03', help='Directory containing the train, validate, and test datasets.')
    parser.add_argument('--model_save_path', type=str, default='genomic_plip.pth', help='Path to save the trained model.')
    parser.add_argument('--pretrained_model_path', type=str, default='./plip', help='Path to the pretrained CLIP model.')

    parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate for the optimizer.')
    parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs to train for.')
    parser.add_argument('--train_batch_size', type=int, default=128, help='Batch size for the training data loader.')
    parser.add_argument('--validation_batch_size', type=int, default=128, help='Batch size for the validation data loader.')
    parser.add_argument('--num_workers', type=int, default=32, help='Number of worker threads for data loading.')

    args = parser.parse_args()

    train_model(args.data_dir, args.model_save_path, args.pretrained_model_path, args.lr, args.num_epochs, args.train_batch_size, args.validation_batch_size, args.num_workers)
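Editor's sketch (not part of the upload): the loop in train_omics_plip_model.py aligns tile embeddings with genomic scores by maximising their cosine similarity, i.e. minimising its negative. The fragment below reproduces that objective in isolation with two stand-in linear heads; the 768-, 4- and 512-dimensional shapes are illustrative assumptions, not the dimensions of the repository's GenomicPLIPModel.

import torch

vision_head = torch.nn.Linear(768, 512)   # stands in for the CLIP vision branch projection
score_head = torch.nn.Linear(4, 512)      # stands in for the genomic-score branch projection

img_feat = torch.randn(1, 768)            # pooled tile embedding (assumed shape)
score = torch.randn(1, 4)                 # genomic risk scores for the same tile (assumed shape)

vision_features = vision_head(img_feat)
score_features = score_head(score)

# Same objective as the training loop: negative cosine similarity between the two branches.
cos_sim = torch.nn.CosineSimilarity(dim=1)(score_features, vision_features)
loss = -cos_sim.mean()
loss.backward()
print(float(loss))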
train_risk_classifier.py
ADDED
@@ -0,0 +1,103 @@
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
import argparse
from datetime import datetime

# Define the function to create the multiple instance learning (MIL) model
def create_simple_model2(instance_shape, max_length):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

def create_simple_model(instance_shape, max_length, num_heads=4, key_dim=64):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    attention_output, attention_scores = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        value_dim=64,
        dropout=0.1,
        use_bias=True
    )(query=dropout_2, value=dropout_2, key=dropout_2, return_attention_scores=True)
    aggregated = layers.GlobalAveragePooling1D()(attention_output)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

# Function to compute class weights
def compute_class_weights(labels):
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}

# Function to generate batches of data
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights

# Learning rate scheduler
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr

# Function to train the model
def train(train_data, train_labels, val_data, val_labels, model, save_dir):
    model_path = os.path.join(save_dir, "risk_classifier_model.h5")
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=50, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
    return model

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']

    # Create the model
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    model = create_simple_model(instance_shape, max_length)

    # Train the model
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir)

    # Final message after training and saving the model
    print(f"Model saved successfully to {os.path.join(args.save_dir, 'risk_classifier_model.h5')}")
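Usage note (editor's sketch, not part of the upload): the best checkpoint is written to <save_dir>/risk_classifier_model.h5 and can be reloaded for inference with standard Keras tooling, since the model only uses built-in layers. The (1, 100, 512) bag shape below is an illustrative assumption and must match the shape the model was trained with.

import numpy as np
import tensorflow as tf

# Reload the checkpoint written by train_risk_classifier.py (default --save_dir assumed).
model = tf.keras.models.load_model("./model_save/risk_classifier_model.h5")

bag = np.random.rand(1, 100, 512).astype(np.float32)   # one bag of tile features (assumed shape)
risk_prob = model.predict(bag)[0, 0]                    # sigmoid output in [0, 1]
print(f"Predicted risk probability: {risk_prob:.3f}")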
train_risk_classifier_optional.py
ADDED
@@ -0,0 +1,109 @@
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
import argparse
from datetime import datetime

# Define the function to create the first model
def create_simple_model(instance_shape, max_length):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

# Define the function to create the second model with attention
def create_simple_model2(instance_shape, max_length, num_heads=4, key_dim=64):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    attention_output, attention_scores = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        value_dim=64,
        dropout=0.1,
        use_bias=True
    )(query=dropout_2, value=dropout_2, key=dropout_2, return_attention_scores=True)
    aggregated = layers.GlobalAveragePooling1D()(attention_output)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

# Function to compute class weights
def compute_class_weights(labels):
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}

# Function to generate batches of data
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights

# Learning rate scheduler
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr

# Function to train the model
def train(train_data, train_labels, val_data, val_labels, model, save_dir):
    model_path = os.path.join(save_dir, "risk_classifier_model.h5")
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=50, batch_size=1, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
    return model

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--model_type', type=str, default='model1', choices=['model1', 'model2'], help='Type of model to use: model1 (default) or model2.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']

    # Create the model based on the selected type
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]

    if args.model_type == 'model2':
        model = create_simple_model2(instance_shape, max_length)
    else:
        model = create_simple_model(instance_shape, max_length)

    # Train the model
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir)

    # Final message after training and saving the model
    print(f"Model saved successfully to {os.path.join(args.save_dir, 'risk_classifier_model.h5')}")
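Usage note (editor's sketch, not part of the upload): in this optional script, --model_type selects between the mean-pooling bag model (model1) and the multi-head-attention bag model (model2). A quick way to compare the two architectures, assuming an illustrative bag shape of 100 tiles with 512-dimensional features; importing the script is safe because its CLI code is guarded by __main__.

from train_risk_classifier_optional import create_simple_model, create_simple_model2

plain = create_simple_model(instance_shape=(512,), max_length=100)        # mean-pooled MIL bag
attentive = create_simple_model2(instance_shape=(512,), max_length=100)   # attention over tiles before pooling
print(plain.count_params(), attentive.count_params())

# Equivalent CLI choices (bags.npz is a hypothetical data file):
#   python train_risk_classifier_optional.py --data_file bags.npz --model_type model1
#   python train_risk_classifier_optional.py --data_file bags.npz --model_type model2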