| """Perform CV (with explainability) on different feature sets and log to mlflow. |
| |
| Includes functionality to nest runs under parent run (e.g. different feature sets |
| under a main run) and set a decision threshold for model scores. Logs the following |
| artifacts as well as metrics and parameters: |
| 1. List of model features |
| 2. Feature correlation matrix |
| 3. Global explainability (averaged over K folds) |
| 4. Cumulative gains curve |
| 5. Lift curve |
| 6. Probability distributions with KDE |
| """ |
| from imblearn.ensemble import BalancedRandomForestClassifier |
| from lenusml import splits, crossvalidation, plots |
| import numpy as np |
| import os |
| import pandas as pd |
|
|
| import mlflow |
| from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID |
|
|
|
|
def get_crossvalidation_importance(*, feature_names, crossval):
    """
    Create a dataframe of mean global feature importance across all CV estimators.

    Args:
        feature_names (list): list of model feature names, in the same order as
            the columns the estimators were fitted on.
        crossval (dict): output of cross_validation_return_estimator_and_scores;
            must contain an 'estimator' key holding the K fitted estimators,
            each exposing a ``feature_importances_`` attribute.

    Returns:
        pd.DataFrame: columns 'Feature', 'Score_0' .. 'Score_{K-1}' (one per
        fold), 'Mean' (row-wise mean of the per-fold scores; rows sorted by it
        descending) and 'Mean_scaled' (mean importance relative to the most
        important feature, so the top feature is +/-1.0).

    Raises:
        ValueError: if ``crossval['estimator']`` is empty.
    """
    estimators = crossval['estimator']
    if len(estimators) == 0:
        raise ValueError("crossval['estimator'] is empty; nothing to aggregate")

    # Build one numeric column per fold in a single pass. (The previous
    # implementation re-indexed the estimator list, merged fold-by-fold in
    # O(K^2) and produced object-dtype columns via DataFrame([...]).T.)
    columns = {'Feature': list(feature_names)}
    for i, est in enumerate(estimators):
        columns['Score_{}'.format(i)] = np.asarray(est.feature_importances_,
                                                   dtype=float)

    explanations_all = pd.DataFrame(columns)
    score_cols = [c for c in explanations_all.columns if c != 'Feature']
    explanations_all['Mean'] = explanations_all[score_cols].mean(axis=1)
    explanations_all = explanations_all.sort_values('Mean', ascending=False)
    # Scale by the largest absolute mean so every value is a fraction of the
    # most important feature's importance.
    explanations_all['Mean_scaled'] = (explanations_all['Mean'] /
                                       explanations_all['Mean'].abs().max())
    return explanations_all
|
|
|
|
# --- Paths to model inputs/outputs ------------------------------------------
data_dir = '../data/models/model1/'        # per-model training data
cohort_info_dir = '../data/cohort_info/'   # patient/fold assignment info
output_dir = '../data/models/model1/output'

# Patient IDs per CV fold (object array, hence allow_pickle) and the
# pre-built CV training frame.
fold_patients = np.load(os.path.join(cohort_info_dir, 'fold_patients.npy'),
                        allow_pickle=True)
train_data = pd.read_pickle(os.path.join(data_dir, 'train_data_cv.pkl'))

# Translate the patient-level fold assignment into row-index folds keyed on
# StudyId, so each patient's rows stay within a single fold.
cross_validation_fold_indices = splits.custom_cv_fold_indices(fold_patients=fold_patients,
                                                              id_column='StudyId',
                                                              train_data=train_data)

# Local sqlite-backed mlflow tracking store; all runs below land in the
# 'model_drop2' experiment.
mlflow.set_tracking_uri("sqlite:///mlruns.sqlite")
mlflow.set_experiment('model_drop2')

# Metrics to average over the CV folds and log on each nested run.
scoring = ['f1', 'balanced_accuracy', 'accuracy', 'precision', 'recall', 'roc_auc',
           'average_precision']

# Pipe-delimited comorbidity table; drop bookkeeping columns so only the
# comorbidity indicator columns remain in comorbidity_list.
comorbidities = pd.read_csv('<YOUR_DATA_PATH>/copd-dataset/CopdDatasetCoMorbidityDetails.txt',
                            delimiter='|')
comorbidity_list = list(comorbidities.columns)
comorbidity_list.remove('Id')
comorbidity_list.remove('PatientId')
comorbidity_list.remove('Created')

# Attach StudyId (the join key used by train_data) to the comorbidity table.
patient_details = pd.read_pickle(os.path.join('<YOUR_DATA_PATH>/copd-dataset',
                                              'patient_details.pkl'))
comorbidities = comorbidities.merge(patient_details[['PatientId', 'StudyId']],
                                    on='PatientId', how='left')

# Model features must be numeric: map boolean comorbidity flags to 0/1.
bool_mapping = {True: 1, False: 0}
comorbidities[comorbidity_list] = comorbidities[comorbidity_list].replace(
    bool_mapping)
|
|
# Parent run with one nested child run per comorbidity: each child evaluates
# the base feature set plus that single comorbidity flag.
with mlflow.start_run(run_name='individual_comorbidities_no_binned'):
    runid = mlflow.active_run().info.run_id

    for comorbidity in comorbidity_list:
        print(comorbidity)

        # Temporarily add this comorbidity column to the training frame;
        # patients with no comorbidity record get 0 (flag absent).
        train_data = train_data.merge(comorbidities[['StudyId', comorbidity]],
                                      on='StudyId', how='left')
        train_data[comorbidity] = train_data[comorbidity].fillna(0)

        # Child run, explicitly tagged with the parent run id so it nests
        # under the parent run in the mlflow UI.
        with mlflow.start_run(run_name=comorbidity, nested=True,
                              tags={MLFLOW_PARENT_RUN_ID: runid}):

            # Everything except identifiers/target/excluded encoding is a
            # model feature (includes the comorbidity merged in above).
            cols_to_drop = ['StudyId', 'IsExac', 'Comorbidities_te']
            features_list = [col for col in train_data.columns if col not in cols_to_drop]

            features = train_data[features_list].astype('float')
            target = train_data.IsExac.astype('float')

            # Scratch dir for artifacts; emptied each iteration so stale
            # files from the previous comorbidity are not re-logged.
            artifact_dir = './tmp'
            os.makedirs(artifact_dir, exist_ok=True)
            for f in os.listdir(artifact_dir):
                os.remove(os.path.join(artifact_dir, f))

            # Artifact 1: plain-text list of model features.
            np.savetxt(os.path.join(artifact_dir, 'features.txt'), features_list,
                       delimiter=",", fmt='%s')

            # Artifact 2: feature correlation matrix (figure size scales with
            # the feature count so labels stay readable).
            plots.plot_feature_correlations(
                features=features, figsize=(len(features_list) // 2,
                                            len(features_list) // 2),
                savefig=True, output_dir=artifact_dir,
                figname="feature_correlations.png")

            # Fixed seed for reproducibility across child runs.
            model = BalancedRandomForestClassifier(random_state=0)

            # CV with the patient-grouped folds built above; returns the
            # fitted per-fold estimators plus held-out scores.
            crossval, model_scores =\
                crossvalidation.cross_validation_return_estimator_and_scores(
                    model=model, features=features,
                    target=target,
                    fold_indices=cross_validation_fold_indices)

            # Log each metric as its mean over the CV folds.
            for score in scoring:
                mlflow.log_metric(score, np.mean(crossval['test_' + score]))

            # Log every model hyperparameter as a run parameter.
            params = model.get_params()
            for param in params:
                mlflow.log_param(param, params[param])

            # Artifact 3: global feature importances averaged over folds,
            # as CSV plus plot.
            explainability = get_crossvalidation_importance(feature_names=features_list,
                                                            crossval=crossval)
            explainability.to_csv(os.path.join(artifact_dir,
                                               'global_feature_importances.csv'), index=False)
            plots.plot_global_explainability_cv(importances=explainability,
                                                scaled=True,
                                                figsize=(
                                                    len(features_list) // 2.5,
                                                    len(features_list) // 6),
                                                savefig=True, output_dir=artifact_dir)

            # Artifacts 4-5: lift and cumulative-gains curves.
            plots.plot_lift_curve(scores=model_scores, savefig=True,
                                  output_dir=artifact_dir, figname='lift_curve.png')
            plots.plot_cumulative_gains_curve(scores=model_scores, savefig=True,
                                              output_dir=artifact_dir,
                                              figname='cumulative_gains_curve.png')

            # Artifact 6: per-class score distributions. NOTE(review): the
            # 'postive_class_name' spelling presumably matches the plots API's
            # own (misspelled) keyword -- confirm before renaming it here.
            plots.plot_score_distribution(scores=model_scores, postive_class_name='Exac',
                                          negative_class_name='No exac', savefig=True,
                                          output_dir=artifact_dir,
                                          figname='model_score_distribution.png')

            # Upload everything written to the scratch dir to this child run.
            mlflow.log_artifacts(artifact_dir)
            # NOTE(review): end_run() looks redundant inside the `with` block
            # (its __exit__ also ends a run), and the extra pop may close the
            # parent run prematurely -- confirm against the mlflow run stack.
            mlflow.end_run()

        # Remove this comorbidity before merging the next one so the feature
        # sets of the child runs stay independent.
        train_data = train_data.drop(columns=[comorbidity])
| |
|
|