# omics-plip-1 / benchmark_train_resnet50.py
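"""Benchmark training script: fine-tune an ImageNet-pretrained ResNet50 on an
image dataset split into train/validate/test/test2 folders, then save the
trained model, the training history, and evaluation metrics (AUC, precision,
recall, F1, accuracy) together with per-image prediction probabilities."""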
import os
import numpy as np
import tensorflow as tf
import json
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(generator, model):
    # The generator must be created with shuffle=False so that generator.classes
    # stays aligned with the order of model.predict() outputs.
    y_true = generator.classes
    y_pred_prob = model.predict(generator)
    y_pred = np.argmax(y_pred_prob, axis=1)
    auc = roc_auc_score(y_true, y_pred_prob[:, 1])  # probability of the positive class (binary setup)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    accuracy = accuracy_score(y_true, y_pred)
    return auc, precision, recall, f1, accuracy, y_pred_prob
# Function to save evaluation metrics
def save_evaluation_metrics(generator, model, dataset_name, save_dir):
    auc, precision, recall, f1, accuracy, y_pred_prob = compute_additional_metrics(generator, model)
    # Cast NumPy scalars to plain floats so the dict is JSON-serializable
    metrics = {
        'auc': float(auc),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'accuracy': float(accuracy)
    }
    # Save predictions (class probabilities) together with the ground-truth labels
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'),
                        predictions=y_pred_prob, labels=generator.classes)
    return metrics
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train and evaluate ResNet50 on benchmark datasets.')
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='Directory containing train, validate, test, and test2 directories.')
    parser.add_argument('--save_dir', type=str, default='./results/',
                        help='Directory to save the model and evaluation results.')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.')
    args = parser.parse_args()

    train_dir = os.path.join(args.dataset_dir, 'train')
    validate_dir = os.path.join(args.dataset_dir, 'validate')
    test_dir = os.path.join(args.dataset_dir, 'test')
    test2_dir = os.path.join(args.dataset_dir, 'test2')
    os.makedirs(args.save_dir, exist_ok=True)
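    # flow_from_directory expects one subfolder per class inside each split,
    # e.g. (class names are illustrative):
    #   <dataset_dir>/train/class_0/*.png
    #   <dataset_dir>/validate/class_0/*.png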
    # Set up ResNet50 model: ImageNet-pretrained backbone with a new two-class softmax head
    with tf.device('GPU:0'):  # assumes a GPU is available
        resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
        last_layer = resnet.get_layer('conv5_block3_out')
        last_output = last_layer.output
        x = tf.keras.layers.GlobalAveragePooling2D()(last_output)
        x = tf.keras.layers.Dense(2, activation='softmax')(x)  # assuming binary classification
        model = tf.keras.Model(inputs=resnet.input, outputs=x)
        model.compile(loss='categorical_crossentropy', optimizer='adam',
                      metrics=['accuracy', 'Recall', 'Precision'])
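    # Design note: a two-unit softmax with categorical_crossentropy matches the one-hot
    # labels produced by class_mode='categorical' below; for two classes it is equivalent
    # to a single sigmoid output trained with binary cross-entropy.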
    # Image data generators: rescale pixels to [0, 1]; horizontal flips only as training augmentation
    train_datagen = IDG(rescale=1/255.0, horizontal_flip=True)
    validate_datagen = IDG(rescale=1/255.0)
    test_datagen = IDG(rescale=1/255.0)
    batch_size = 64

    train_generator = train_datagen.flow_from_directory(train_dir, target_size=(224, 224),
                                                        class_mode='categorical', batch_size=batch_size)
    # Evaluation generators use shuffle=False so prediction order matches generator.classes
    validate_generator = validate_datagen.flow_from_directory(validate_dir, target_size=(224, 224),
                                                              class_mode='categorical', batch_size=batch_size,
                                                              shuffle=False)
    test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224),
                                                      class_mode='categorical', batch_size=batch_size,
                                                      shuffle=False)
    test2_generator = test_datagen.flow_from_directory(test2_dir, target_size=(224, 224),
                                                       class_mode='categorical', batch_size=batch_size,
                                                       shuffle=False)
    # Non-shuffled, non-augmented view of the training set for post-hoc evaluation
    train_eval_generator = validate_datagen.flow_from_directory(train_dir, target_size=(224, 224),
                                                                class_mode='categorical', batch_size=batch_size,
                                                                shuffle=False)
    # Train the model
    hist = model.fit(train_generator, epochs=args.epochs, validation_data=validate_generator, verbose=1, shuffle=True)

    # Save the trained model (legacy HDF5 format; requires h5py)
    model.save(os.path.join(args.save_dir, 'risk_classifier_resnet_model.hdf5'))
    # Save training history separately (metric keys follow Keras's default names)
    training_log = {
        'loss': hist.history['loss'],
        'val_loss': hist.history['val_loss'],
        'accuracy': hist.history['accuracy'],
        'val_accuracy': hist.history['val_accuracy'],
        'recall': hist.history['recall'],
        'val_recall': hist.history['val_recall'],
        'precision': hist.history['precision'],
        'val_precision': hist.history['val_precision']
    }
    with open(os.path.join(args.save_dir, 'resnet_training_log.json'), 'w') as f:
        json.dump(training_log, f, default=float)  # default=float handles any NumPy scalars
    # Evaluate the model on each dataset and save metrics
    # (uses the non-shuffled generators so predictions line up with the true labels)
    train_metrics = save_evaluation_metrics(train_eval_generator, model, "train", args.save_dir)
    validate_metrics = save_evaluation_metrics(validate_generator, model, "validate", args.save_dir)
    test_metrics = save_evaluation_metrics(test_generator, model, "test", args.save_dir)
    test2_metrics = save_evaluation_metrics(test2_generator, model, "test2", args.save_dir)

    # Save the evaluation metrics in a JSON file
    evaluation_metrics = {
        'train_metrics': train_metrics,
        'validate_metrics': validate_metrics,
        'test_metrics': test_metrics,
        'test2_metrics': test2_metrics
    }
    with open(os.path.join(args.save_dir, 'resnet_evaluation_metrics.json'), 'w') as f:
        json.dump(evaluation_metrics, f)
    print("Training and evaluation metrics saved.")