""" Geneformer classifier. **Input data:** | Cell state classifier: | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py) | Gene classifier: | Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py) **Usage:** .. code-block :: python >>> from geneformer import Classifier >>> cc = Classifier(classifier="cell", # example of cell state classifier ... cell_state_dict={"state_key": "disease", "states": "all"}, ... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}, ... training_args=training_args, ... freeze_layers = 2, ... num_crossval_splits = 1, ... forward_batch_size=200, ... nproc=16) >>> cc.prepare_data(input_data_file="path/to/input_data", ... output_directory="path/to/output_directory", ... output_prefix="output_prefix") >>> all_metrics = cc.validate(model_directory="path/to/model", ... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset", ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl", ... output_directory="path/to/output_directory", ... output_prefix="output_prefix", ... predict_eval=True) >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]}, ... output_directory="path/to/output_directory", ... output_prefix="output_prefix", ... custom_class_order=["healthy","disease1","disease2"]) >>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl", ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl", ... title="disease", ... output_directory="path/to/output_directory", ... output_prefix="output_prefix", ... custom_class_order=["healthy","disease1","disease2"]) """ import datetime import logging import os import pickle import subprocess from pathlib import Path import numpy as np import pandas as pd import seaborn as sns from sklearn.model_selection import StratifiedKFold from tqdm.auto import tqdm, trange from transformers import Trainer from transformers.training_args import TrainingArguments from . import DataCollatorForCellClassification, DataCollatorForGeneClassification from . import classifier_utils as cu from . import evaluation_utils as eu from . import perturber_utils as pu from .tokenizer import TOKEN_DICTIONARY_FILE sns.set() logger = logging.getLogger(__name__) class Classifier: valid_option_dict = { "classifier": {"cell", "gene"}, "cell_state_dict": {None, dict}, "gene_class_dict": {None, dict}, "filter_data": {None, dict}, "rare_threshold": {int, float}, "max_ncells": {None, int}, "max_ncells_per_class": {None, int}, "training_args": {None, dict}, "freeze_layers": {int}, "num_crossval_splits": {0, 1, 5}, "split_sizes": {None, dict}, "no_eval": {bool}, "stratify_splits_col": {None, str}, "forward_batch_size": {int}, "nproc": {int}, "ngpu": {int}, } def __init__( self, classifier=None, cell_state_dict=None, gene_class_dict=None, filter_data=None, rare_threshold=0, max_ncells=None, max_ncells_per_class=None, training_args=None, ray_config=None, freeze_layers=0, num_crossval_splits=1, split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1}, stratify_splits_col=None, no_eval=False, forward_batch_size=100, nproc=4, ngpu=1, ): """ Initialize Geneformer classifier. 
        **Parameters:**

        classifier : {"cell", "gene"}
            | Whether to fine-tune a cell state or gene classifier.
        cell_state_dict : None, dict
            | Cell states to fine-tune model to distinguish.
            | Two-item dictionary with keys: state_key and states
            | state_key: key specifying name of column in .dataset that defines the states to model
            | states: list of values in the state_key column that specifies the states to model
            | Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
            | Of note, if using "all", states will be defined after data is filtered.
            | Must have at least 2 states to model.
            | For example: {"state_key": "disease",
            |               "states": ["nf", "hcm", "dcm"]}
            | or
            | {"state_key": "disease",
            |  "states": "all"}
        gene_class_dict : None, dict
            | Gene classes to fine-tune model to distinguish.
            | Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
            |                        Gene_label_B: list(geneB1, geneB2, ...)}
            | Gene values should be Ensembl IDs.
        filter_data : None, dict
            | Default is to fine-tune with all input data.
            | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
        rare_threshold : float
            | Threshold below which rare cell states should be removed.
            | For example, setting to 0.05 will remove cell states representing
            | < 5% of the total cells from the cell state classifier's possible classes.
        max_ncells : None, int
            | Maximum number of cells to use for fine-tuning.
            | Default is to fine-tune with all input data.
        max_ncells_per_class : None, int
            | Maximum number of cells per cell class to use for fine-tuning.
            | Of note, will be applied after max_ncells above.
            | (Only valid for cell classification.)
        training_args : None, dict
            | Training arguments for fine-tuning.
            | If None, defaults will be inferred for 6 layer Geneformer.
            | Otherwise, provided arguments will override the Hugging Face defaults:
            | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
            | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
        ray_config : None, dict
            | Training argument ranges for tuning hyperparameters with Ray.
        freeze_layers : int
            | Number of layers to freeze from fine-tuning.
            | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
        num_crossval_splits : {0, 1, 5}
            | 0: train on all data without splitting
            | 1: split data into train and eval sets by designated split_sizes["valid"]
            | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
        split_sizes : None, dict
            | Dictionary of proportions of data for the train, validation, and test sets
            | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
        stratify_splits_col : None, str
            | Name of column in .dataset to be used for stratified splitting.
            | Proportion of each class in this column will be the same in the splits as in the original dataset.
        no_eval : bool
            | If True, will skip eval step and use all data for training.
            | Otherwise, will perform eval during training.
        forward_batch_size : int
            | Batch size for forward pass (for evaluation, not training).
        nproc : int
            | Number of CPU processes to use.
        ngpu : int
            | Number of GPUs available.
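
        **Example (gene classifier):**

        | A minimal, hypothetical setup for a gene classifier; the class labels
        | and Ensembl IDs below are placeholders, not a recommended gene set.
        | (Of note, for cross-validation each class needs at least 5 genes.)

        .. code-block :: python

            >>> gc = Classifier(classifier="gene",
            ...                 gene_class_dict={"class_A": ["ENSG00000000001", "ENSG00000000002"],
            ...                                  "class_B": ["ENSG00000000003", "ENSG00000000004"]},
            ...                 max_ncells=10_000,
            ...                 freeze_layers=2,
            ...                 num_crossval_splits=1,
            ...                 forward_batch_size=200,
            ...                 nproc=16)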
""" self.classifier = classifier self.cell_state_dict = cell_state_dict self.gene_class_dict = gene_class_dict self.filter_data = filter_data self.rare_threshold = rare_threshold self.max_ncells = max_ncells self.max_ncells_per_class = max_ncells_per_class self.training_args = training_args self.ray_config = ray_config self.freeze_layers = freeze_layers self.num_crossval_splits = num_crossval_splits self.split_sizes = split_sizes self.train_size = self.split_sizes["train"] self.valid_size = self.split_sizes["valid"] self.oos_test_size = self.split_sizes["test"] self.eval_size = self.valid_size / (self.train_size + self.valid_size) self.stratify_splits_col = stratify_splits_col self.no_eval = no_eval self.forward_batch_size = forward_batch_size self.nproc = nproc self.ngpu = ngpu if self.training_args is None: logger.warning( "Hyperparameter tuning is highly recommended for optimal results. " "No training_args provided; using default hyperparameters." ) self.validate_options() if self.filter_data is None: self.filter_data = dict() if self.classifier == "cell": if self.cell_state_dict["states"] != "all": self.filter_data[ self.cell_state_dict["state_key"] ] = self.cell_state_dict["states"] # load token dictionary (Ensembl IDs:token) with open(TOKEN_DICTIONARY_FILE, "rb") as f: self.gene_token_dict = pickle.load(f) self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()} # filter genes for gene classification for those in token dictionary if self.classifier == "gene": all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values())) missing_genes = [ gene for gene in all_gene_class_values if gene not in self.gene_token_dict.keys() ] if len(missing_genes) == len(all_gene_class_values): logger.error( "None of the provided genes to classify are in token dictionary." ) raise elif len(missing_genes) > 0: logger.warning( f"Genes to classify {missing_genes} are not in token dictionary." ) self.gene_class_dict = { k: set([self.gene_token_dict.get(gene) for gene in v]) for k, v in self.gene_class_dict.items() } empty_classes = [] for k, v in self.gene_class_dict.items(): if len(v) == 0: empty_classes += [k] if len(empty_classes) > 0: logger.error( f"Class(es) {empty_classes} did not contain any genes in the token dictionary." ) raise def validate_options(self): # confirm arguments are within valid options and compatible with each other for attr_name, valid_options in self.valid_option_dict.items(): attr_value = self.__dict__[attr_name] if not isinstance(attr_value, (list, dict)): if attr_value in valid_options: continue valid_type = False for option in valid_options: if (option in [int, float, list, dict, bool]) and isinstance( attr_value, option ): valid_type = True break if valid_type: continue logger.error( f"Invalid option for {attr_name}. " f"Valid options for {attr_name}: {valid_options}" ) raise if self.filter_data is not None: for key, value in self.filter_data.items(): if not isinstance(value, list): self.filter_data[key] = [value] logger.warning( "Values in filter_data dict must be lists. " f"Changing {key} value to list ([{value}])." ) if self.classifier == "cell": if set(self.cell_state_dict.keys()) != set(["state_key", "states"]): logger.error( "Invalid keys for cell_state_dict. " "The cell_state_dict should have only 2 keys: state_key and states" ) raise if self.cell_state_dict["states"] != "all": if not isinstance(self.cell_state_dict["states"], list): logger.error( "States in cell_state_dict should be list of states to model." 
                if len(self.cell_state_dict["states"]) < 2:
                    msg = "States in cell_state_dict should contain at least 2 states to classify."
                    logger.error(msg)
                    raise ValueError(msg)

        if self.classifier == "gene":
            if len(self.gene_class_dict.keys()) < 2:
                msg = "Gene_class_dict should contain at least 2 gene classes to classify."
                logger.error(msg)
                raise ValueError(msg)

        if sum(self.split_sizes.values()) != 1:
            msg = "Train, validation, and test proportions should sum to 1."
            logger.error(msg)
            raise ValueError(msg)

    def prepare_data(
        self,
        input_data_file,
        output_directory,
        output_prefix,
        split_id_dict=None,
        test_size=None,
        attr_to_split=None,
        attr_to_balance=None,
        max_trials=100,
        pval_threshold=0.1,
    ):
        """
        Prepare data for cell state or gene classification.

        **Parameters**

        input_data_file : Path
            | Path to directory containing .dataset input
        output_directory : Path
            | Path to directory where prepared data will be saved
        output_prefix : str
            | Prefix for output file
        split_id_dict : None, dict
            | Dictionary of IDs for train and test splits
            | Three-item dictionary with keys: attr_key, train, test
            | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
            | train: list of IDs in the attr_key column to include in the train split
            | test: list of IDs in the attr_key column to include in the test split
            | For example: {"attr_key": "individual",
            |               "train": ["patient1", "patient2", "patient3", "patient4"],
            |               "test": ["patient5", "patient6"]}
        test_size : None, float
            | Proportion of data to be saved separately and held out for test set
            | (e.g. 0.2 if intending to hold out 20%)
            | If None, will inherit from split_sizes["test"] from Classifier
            | The training set will be further split to train / validation in self.validate
            | Note: only available for CellClassifiers
        attr_to_split : None, str
            | Key for attribute on which to split data while balancing potential confounders
            | e.g. "patient_id" for splitting by patient while balancing other characteristics
            | Note: only available for CellClassifiers
        attr_to_balance : None, list
            | List of attribute keys on which to balance data while splitting on attr_to_split
            | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
            | Note: only available for CellClassifiers
        max_trials : None, int
            | Maximum number of trials of random splitting to try to achieve balanced other attributes
            | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
            | Note: only available for CellClassifiers
        pval_threshold : None, float
            | P-value threshold to use for attribute balancing across splits
            | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
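
        **Example:**

        | A minimal sketch of preparing data with a patient-level train/test split;
        | the paths and patient IDs are placeholders.

        .. code-block :: python

            >>> cc.prepare_data(input_data_file="path/to/input_data",
            ...                 output_directory="path/to/output_directory",
            ...                 output_prefix="output_prefix",
            ...                 split_id_dict={"attr_key": "individual",
            ...                                "train": ["patient1", "patient2"],
            ...                                "test": ["patient3"]})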
        """
        if test_size is None:
            test_size = self.oos_test_size

        # prepare data and labels for classification
        data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)

        if self.classifier == "cell":
            if "label" in data.features:
                msg = "Column name 'label' must be reserved for class IDs. Please rename column."
                logger.error(msg)
                raise ValueError(msg)
        elif self.classifier == "gene":
            if "labels" in data.features:
                msg = "Column name 'labels' must be reserved for class IDs. Please rename column."
                logger.error(msg)
                raise ValueError(msg)

        if self.classifier == "cell":
            # remove cell states representing < rare_threshold of cells
            data = cu.remove_rare(
                data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
            )
            # downsample max cells and max per class
            data = cu.downsample_and_shuffle(
                data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
            )
            # rename cell state column to "label"
            data = cu.rename_cols(data, self.cell_state_dict["state_key"])

        # convert classes to numerical labels and save as id_class_dict
        # of note, will label all genes in gene_class_dict
        # if (cross-)validating, genes will be relabeled in column "labels" for each split
        # at the time of training with Classifier.validate
        data, id_class_dict = cu.label_classes(
            self.classifier, data, self.gene_class_dict, self.nproc
        )

        # save id_class_dict for future reference
        id_class_output_path = (
            Path(output_directory) / f"{output_prefix}_id_class_dict"
        ).with_suffix(".pkl")
        with open(id_class_output_path, "wb") as f:
            pickle.dump(id_class_dict, f)

        if split_id_dict is not None:
            data_dict = dict()
            data_dict["train"] = pu.filter_by_dict(
                data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
            )
            data_dict["test"] = pu.filter_by_dict(
                data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
            )
            train_data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled_train"
            ).with_suffix(".dataset")
            test_data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled_test"
            ).with_suffix(".dataset")
            data_dict["train"].save_to_disk(train_data_output_path)
            data_dict["test"].save_to_disk(test_data_output_path)
        elif (test_size is not None) and (self.classifier == "cell"):
            if 0 < test_size < 1:
                if attr_to_split is None:
                    data_dict = data.train_test_split(
                        test_size=test_size,
                        stratify_by_column=self.stratify_splits_col,
                        seed=42,
                    )
                    train_data_output_path = (
                        Path(output_directory) / f"{output_prefix}_labeled_train"
                    ).with_suffix(".dataset")
                    test_data_output_path = (
                        Path(output_directory) / f"{output_prefix}_labeled_test"
                    ).with_suffix(".dataset")
                    data_dict["train"].save_to_disk(train_data_output_path)
                    data_dict["test"].save_to_disk(test_data_output_path)
                else:
                    data_dict, balance_df = cu.balance_attr_splits(
                        data,
                        attr_to_split,
                        attr_to_balance,
                        test_size,
                        max_trials,
                        pval_threshold,
                        self.cell_state_dict["state_key"],
                        self.nproc,
                    )
                    balance_df.to_csv(
                        f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
                    )
                    train_data_output_path = (
                        Path(output_directory) / f"{output_prefix}_labeled_train"
                    ).with_suffix(".dataset")
                    test_data_output_path = (
                        Path(output_directory) / f"{output_prefix}_labeled_test"
                    ).with_suffix(".dataset")
                    data_dict["train"].save_to_disk(train_data_output_path)
                    data_dict["test"].save_to_disk(test_data_output_path)
            else:
                data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled"
                ).with_suffix(".dataset")
                data.save_to_disk(data_output_path)
                print(data_output_path)
        else:
            data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled"
            ).with_suffix(".dataset")
            data.save_to_disk(data_output_path)

    def train_all_data(
        self,
        model_directory,
        prepared_input_data_file,
        id_class_dict_file,
        output_directory,
        output_prefix,
        save_eval_output=True,
    ):
        """
        Train cell state or gene classifier using all data.
        **Parameters**

        model_directory : Path
            | Path to directory containing model
        prepared_input_data_file : Path
            | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
        id_class_dict_file : Path
            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        output_directory : Path
            | Path to directory where model and eval data will be saved
        output_prefix : str
            | Prefix for output files
        save_eval_output : bool
            | Whether to save cross-fold eval output
            | Saves as pickle file of dictionary of eval metrics

        **Output**

        Returns trainer after fine-tuning with all data.
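
        **Example:**

        | A minimal sketch of training on all data once hyperparameters are
        | chosen; the paths are placeholders.

        .. code-block :: python

            >>> trainer = cc.train_all_data(model_directory="path/to/model",
            ...                             prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
            ...                             id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
            ...                             output_directory="path/to/output_directory",
            ...                             output_prefix="output_prefix")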
        """
        ##### Load data and prepare output directory #####
        # load numerical id to class dictionary (id:class)
        with open(id_class_dict_file, "rb") as f:
            id_class_dict = pickle.load(f)
        class_id_dict = {v: k for k, v in id_class_dict.items()}

        # load previously filtered and prepared data
        data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
        data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

        # define output directory path
        current_date = datetime.datetime.now()
        datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
        if output_directory[-1:] != "/":  # add slash for dir if not present
            output_directory = output_directory + "/"
        output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
        subprocess.call(f"mkdir {output_dir}", shell=True)

        # get number of classes for classifier
        num_classes = cu.get_num_classes(id_class_dict)

        if self.classifier == "gene":
            targets = pu.flatten_list(self.gene_class_dict.values())
            labels = pu.flatten_list(
                [
                    [class_id_dict[label]] * len(targets)
                    for label, targets in self.gene_class_dict.items()
                ]
            )
            assert len(targets) == len(labels)
            data = cu.prep_gene_classifier_all_data(
                data, targets, labels, self.max_ncells, self.nproc
            )

        trainer = self.train_classifier(
            model_directory, num_classes, data, None, output_dir
        )

        return trainer

    def validate(
        self,
        model_directory,
        prepared_input_data_file,
        id_class_dict_file,
        output_directory,
        output_prefix,
        split_id_dict=None,
        attr_to_split=None,
        attr_to_balance=None,
        max_trials=100,
        pval_threshold=0.1,
        save_eval_output=True,
        predict_eval=True,
        predict_trainer=False,
        n_hyperopt_trials=0,
    ):
        """
        (Cross-)validate cell state or gene classifier.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        prepared_input_data_file : Path
            | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
        id_class_dict_file : Path
            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        output_directory : Path
            | Path to directory where model and eval data will be saved
        output_prefix : str
            | Prefix for output files
        split_id_dict : None, dict
            | Dictionary of IDs for train and eval splits
            | Three-item dictionary with keys: attr_key, train, eval
            | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
            | train: list of IDs in the attr_key column to include in the train split
            | eval: list of IDs in the attr_key column to include in the eval split
            | For example: {"attr_key": "individual",
            |               "train": ["patient1", "patient2", "patient3", "patient4"],
            |               "eval": ["patient5", "patient6"]}
            | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
        attr_to_split : None, str
            | Key for attribute on which to split data while balancing potential confounders
            | e.g. "patient_id" for splitting by patient while balancing other characteristics
            | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
        attr_to_balance : None, list
            | List of attribute keys on which to balance data while splitting on attr_to_split
            | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
        max_trials : None, int
            | Maximum number of trials of random splitting to try to achieve balanced other attributes
            | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
        pval_threshold : None, float
            | P-value threshold to use for attribute balancing across splits
            | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
        save_eval_output : bool
            | Whether to save cross-fold eval output
            | Saves as pickle file of dictionary of eval metrics
        predict_eval : bool
            | Whether or not to save eval predictions
            | Saves as a pickle file of self.evaluate predictions
        predict_trainer : bool
            | Whether or not to save eval predictions from trainer
            | Saves as a pickle file of trainer predictions
        n_hyperopt_trials : int
            | Number of trials to run for hyperparameter optimization
            | If 0, will not optimize hyperparameters
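
        **Example:**

        | A minimal sketch of running validation with hyperparameter tuning;
        | the paths and trial count are placeholders. Of note, when
        | n_hyperopt_trials > 0, the method returns after tuning without
        | aggregating eval metrics.

        .. code-block :: python

            >>> cc.validate(model_directory="path/to/model",
            ...             prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled_train.dataset",
            ...             id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
            ...             output_directory="path/to/output_directory",
            ...             output_prefix="output_prefix",
            ...             n_hyperopt_trials=10)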
        """
        if self.num_crossval_splits == 0:
            msg = "num_crossval_splits must be 1 or 5 to validate."
            logger.error(msg)
            raise ValueError(msg)

        # ensure number of genes in each class is > 5 if validating model
        if self.classifier == "gene":
            insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
            if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
                msg = f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
                logger.error(msg)
                raise ValueError(msg)

        ##### Load data and prepare output directory #####
        # load numerical id to class dictionary (id:class)
        with open(id_class_dict_file, "rb") as f:
            id_class_dict = pickle.load(f)
        class_id_dict = {v: k for k, v in id_class_dict.items()}

        # load previously filtered and prepared data
        data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
        data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

        # define output directory path
        current_date = datetime.datetime.now()
        datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
        if output_directory[-1:] != "/":  # add slash for dir if not present
            output_directory = output_directory + "/"
        output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
        subprocess.call(f"mkdir {output_dir}", shell=True)

        # get number of classes for classifier
        num_classes = cu.get_num_classes(id_class_dict)

        ##### (Cross-)validate the model #####
        results = []
        all_conf_mat = np.zeros((num_classes, num_classes))
        iteration_num = 1
        if self.classifier == "cell":
            for i in trange(self.num_crossval_splits):
                print(
                    f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
                )
                ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
                if self.num_crossval_splits == 1:
                    # single 1-eval_size:eval_size split
                    if split_id_dict is not None:
                        data_dict = dict()
                        data_dict["train"] = pu.filter_by_dict(
                            data,
                            {split_id_dict["attr_key"]: split_id_dict["train"]},
                            self.nproc,
                        )
                        data_dict["test"] = pu.filter_by_dict(
                            data,
                            {split_id_dict["attr_key"]: split_id_dict["eval"]},
                            self.nproc,
                        )
                    elif attr_to_split is not None:
                        data_dict, balance_df = cu.balance_attr_splits(
                            data,
                            attr_to_split,
                            attr_to_balance,
                            self.eval_size,
                            max_trials,
                            pval_threshold,
                            self.cell_state_dict["state_key"],
                            self.nproc,
                        )
                        balance_df.to_csv(
                            f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
                        )
                    else:
                        data_dict = data.train_test_split(
                            test_size=self.eval_size,
                            stratify_by_column=self.stratify_splits_col,
                            seed=42,
                        )
                    train_data = data_dict["train"]
                    eval_data = data_dict["test"]
                else:
                    # 5-fold cross-validate
                    num_cells = len(data)
                    fifth_cells = int(np.floor(num_cells * 0.2))
                    num_eval = int(min((self.eval_size * num_cells), fifth_cells))
                    start = i * fifth_cells
                    end = start + num_eval
                    eval_indices = [j for j in range(start, end)]
                    train_indices = [
                        j for j in range(num_cells) if j not in eval_indices
                    ]
                    eval_data = data.select(eval_indices)
                    train_data = data.select(train_indices)
                if n_hyperopt_trials == 0:
                    trainer = self.train_classifier(
                        model_directory,
                        num_classes,
                        train_data,
                        eval_data,
                        ksplit_output_dir,
                        predict_trainer,
                    )
                else:
                    trainer = self.hyperopt_classifier(
                        model_directory,
                        num_classes,
                        train_data,
                        eval_data,
                        ksplit_output_dir,
                        n_trials=n_hyperopt_trials,
                    )
                    # skip evaluation when tuning hyperparameters
                    if iteration_num == self.num_crossval_splits:
                        return
                    else:
                        iteration_num = iteration_num + 1
                        continue
                result = self.evaluate_model(
                    trainer.model,
                    num_classes,
                    id_class_dict,
                    eval_data,
                    predict_eval,
                    ksplit_output_dir,
                    output_prefix,
                )
                results += [result]
                all_conf_mat = all_conf_mat + result["conf_mat"]
                iteration_num = iteration_num + 1
        elif self.classifier == "gene":
            # set up (cross-)validation splits
            targets = pu.flatten_list(self.gene_class_dict.values())
            labels = pu.flatten_list(
                [
                    [class_id_dict[label]] * len(targets)
                    for label, targets in self.gene_class_dict.items()
                ]
            )
            assert len(targets) == len(labels)
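            # with the default split_sizes (train 0.8, valid 0.1), eval_size is
            # 0.1 / 0.9 ≈ 0.11, giving n_splits = 9 folds below; only the first
            # num_crossval_splits folds are actually run (see the break below)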
            n_splits = int(1 / self.eval_size)
            skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
            # (Cross-)validate
            for train_index, eval_index in tqdm(skf.split(targets, labels)):
                print(
                    f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
                )
                ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
                # filter data for examples containing classes for this split
                # subsample to max_ncells and relabel data in column "labels"
                train_data, eval_data = cu.prep_gene_classifier_split(
                    data,
                    targets,
                    labels,
                    train_index,
                    eval_index,
                    self.max_ncells,
                    iteration_num,
                    self.nproc,
                )
                if n_hyperopt_trials == 0:
                    trainer = self.train_classifier(
                        model_directory,
                        num_classes,
                        train_data,
                        eval_data,
                        ksplit_output_dir,
                        predict_trainer,
                    )
                else:
                    trainer = self.hyperopt_classifier(
                        model_directory,
                        num_classes,
                        train_data,
                        eval_data,
                        ksplit_output_dir,
                        n_trials=n_hyperopt_trials,
                    )
                    # skip evaluation when tuning hyperparameters
                    if iteration_num == self.num_crossval_splits:
                        return
                    else:
                        iteration_num = iteration_num + 1
                        continue
                result = self.evaluate_model(
                    trainer.model,
                    num_classes,
                    id_class_dict,
                    eval_data,
                    predict_eval,
                    ksplit_output_dir,
                    output_prefix,
                )
                results += [result]
                all_conf_mat = all_conf_mat + result["conf_mat"]
                # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
                if iteration_num == self.num_crossval_splits:
                    break
                iteration_num = iteration_num + 1

        all_conf_mat_df = pd.DataFrame(
            all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
        )
        all_metrics = {
            "conf_matrix": all_conf_mat_df,
            "macro_f1": [result["macro_f1"] for result in results],
            "acc": [result["acc"] for result in results],
        }
        all_roc_metrics = None  # roc metrics not reported for multiclass
        if num_classes == 2:
            mean_fpr = np.linspace(0, 1, 100)
            all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
            all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
            all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
            mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
                all_tpr, all_roc_auc, all_tpr_wt
            )
            all_roc_metrics = {
                "mean_tpr": mean_tpr,
                "mean_fpr": mean_fpr,
                "all_roc_auc": all_roc_auc,
                "roc_auc": roc_auc,
                "roc_auc_sd": roc_auc_sd,
            }
        all_metrics["all_roc_metrics"] = all_roc_metrics
        if save_eval_output is True:
            eval_metrics_output_path = (
                Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
            ).with_suffix(".pkl")
            with open(eval_metrics_output_path, "wb") as f:
                pickle.dump(all_metrics, f)

        return all_metrics

    def hyperopt_classifier(
        self,
        model_directory,
        num_classes,
        train_data,
        eval_data,
        output_directory,
        n_trials=100,
    ):
        """
        Fine-tune model for cell state or gene classification.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        num_classes : int
            | Number of classes for classifier
        train_data : Dataset
            | Loaded training .dataset input
            | For cell classifier, labels in column "label".
            | For gene classifier, labels in column "labels".
        eval_data : None, Dataset
            | (Optional) Loaded evaluation .dataset input
            | For cell classifier, labels in column "label".
            | For gene classifier, labels in column "labels".
        output_directory : Path
            | Path to directory where fine-tuned model will be saved
        n_trials : int
            | Number of trials to run for hyperparameter optimization
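
        **Example:**

        | A sketch of a custom ray_config search space, mirroring the default
        | space defined below; the ranges are illustrative and may need
        | adjustment depending on the model.

        .. code-block :: python

            >>> from ray import tune
            >>> ray_config = {
            ...     "num_train_epochs": tune.choice([1]),
            ...     "learning_rate": tune.loguniform(1e-6, 1e-3),
            ...     "weight_decay": tune.uniform(0.0, 0.3),
            ...     "warmup_steps": tune.uniform(100, 2000),
            ...     "seed": tune.uniform(0, 100),
            ... }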
        """
        # initiate runtime environment for raytune
        import ray
        from ray import tune
        from ray.tune.search.hyperopt import HyperOptSearch

        ray.shutdown()  # engage new ray session
        ray.init()

        ##### Validate and prepare data #####
        train_data, eval_data = cu.validate_and_clean_cols(
            train_data, eval_data, self.classifier
        )

        if (self.no_eval is True) and (eval_data is not None):
            logger.warning(
                "no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
            )

        # ensure not overwriting previously saved model
        saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
        if os.path.isfile(saved_model_test) is True:
            msg = "Model already saved to this designated output directory."
            logger.error(msg)
            raise FileExistsError(msg)

        # make output directory
        subprocess.call(f"mkdir {output_directory}", shell=True)

        ##### Load model and training args #####
        if self.classifier == "cell":
            model_type = "CellClassifier"
        elif self.classifier == "gene":
            model_type = "GeneClassifier"
        model = pu.load_model(model_type, num_classes, model_directory, "train")
        def_training_args, def_freeze_layers = cu.get_default_train_args(
            model, self.classifier, train_data, output_directory
        )
        del model

        if self.training_args is not None:
            def_training_args.update(self.training_args)
        logging_steps = round(
            len(train_data) / def_training_args["per_device_train_batch_size"] / 10
        )
        def_training_args["logging_steps"] = logging_steps
        def_training_args["output_dir"] = output_directory
        if eval_data is None:
            def_training_args["evaluation_strategy"] = "no"
            def_training_args["load_best_model_at_end"] = False
        training_args_init = TrainingArguments(**def_training_args)

        ##### Fine-tune the model #####
        # define the data collator
        if self.classifier == "cell":
            data_collator = DataCollatorForCellClassification()
        elif self.classifier == "gene":
            data_collator = DataCollatorForGeneClassification()

        # define function to initiate model
        def model_init():
            model = pu.load_model(model_type, num_classes, model_directory, "train")

            if self.freeze_layers is not None:
                def_freeze_layers = self.freeze_layers

            if def_freeze_layers > 0:
                modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
                for module in modules_to_freeze:
                    for param in module.parameters():
                        param.requires_grad = False

            model = model.to("cuda:0")
            return model

        # create the trainer
        trainer = Trainer(
            model_init=model_init,
            args=training_args_init,
            data_collator=data_collator,
            train_dataset=train_data,
            eval_dataset=eval_data,
            compute_metrics=cu.compute_metrics,
        )

        # specify raytune hyperparameter search space
        if self.ray_config is None:
            logger.warning(
                "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
            )
            def_ray_config = {
                "num_train_epochs": tune.choice([1]),
                "learning_rate": tune.loguniform(1e-6, 1e-3),
                "weight_decay": tune.uniform(0.0, 0.3),
                "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
                "warmup_steps": tune.uniform(100, 2000),
                "seed": tune.uniform(0, 100),
                "per_device_train_batch_size": tune.choice(
                    [def_training_args["per_device_train_batch_size"]]
                ),
            }

        hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")

        # optimize hyperparameters
        trainer.hyperparameter_search(
            direction="maximize",
            backend="ray",
            resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
            hp_space=lambda _: def_ray_config
            if self.ray_config is None
            else self.ray_config,
            search_alg=hyperopt_search,
            n_trials=n_trials,  # number of trials
            progress_reporter=tune.CLIReporter(
                max_report_frequency=600,
                sort_by_metric=True,
                max_progress_rows=n_trials,
                mode="max",
                metric="eval_macro_f1",
                metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
            ),
        )

        return trainer

    def train_classifier(
        self,
        model_directory,
        num_classes,
        train_data,
        eval_data,
        output_directory,
        predict=False,
    ):
        """
        Fine-tune model for cell state or gene classification.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        num_classes : int
            | Number of classes for classifier
        train_data : Dataset
            | Loaded training .dataset input
            | For cell classifier, labels in column "label".
            | For gene classifier, labels in column "labels".
        eval_data : None, Dataset
            | (Optional) Loaded evaluation .dataset input
            | For cell classifier, labels in column "label".
            | For gene classifier, labels in column "labels".
        output_directory : Path
            | Path to directory where fine-tuned model will be saved
        predict : bool
            | Whether or not to save eval predictions from trainer
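
        **Example:**

        | A minimal sketch of calling train_classifier directly on a split
        | previously saved by prepare_data; the paths and num_classes value
        | are placeholders.

        .. code-block :: python

            >>> from datasets import load_from_disk
            >>> train_data = load_from_disk(f"path/to/output_directory/{output_prefix}_labeled_train.dataset")
            >>> trainer = cc.train_classifier(model_directory="path/to/model",
            ...                               num_classes=3,
            ...                               train_data=train_data,
            ...                               eval_data=None,
            ...                               output_directory="path/to/output_directory")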
        """
        ##### Validate and prepare data #####
        train_data, eval_data = cu.validate_and_clean_cols(
            train_data, eval_data, self.classifier
        )

        if (self.no_eval is True) and (eval_data is not None):
            logger.warning(
                "no_eval set to True; model will be trained without evaluation."
            )
            eval_data = None

        if (self.classifier == "gene") and (predict is True):
            logger.warning(
                "Predictions during training not currently available for gene classifiers; setting predict to False."
            )
            predict = False

        # ensure not overwriting previously saved model
        saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
        if os.path.isfile(saved_model_test) is True:
            msg = "Model already saved to this designated output directory."
            logger.error(msg)
            raise FileExistsError(msg)

        # make output directory
        subprocess.call(f"mkdir {output_directory}", shell=True)

        ##### Load model and training args #####
        if self.classifier == "cell":
            model_type = "CellClassifier"
        elif self.classifier == "gene":
            model_type = "GeneClassifier"
        model = pu.load_model(model_type, num_classes, model_directory, "train")

        def_training_args, def_freeze_layers = cu.get_default_train_args(
            model, self.classifier, train_data, output_directory
        )

        if self.training_args is not None:
            def_training_args.update(self.training_args)
        logging_steps = round(
            len(train_data) / def_training_args["per_device_train_batch_size"] / 10
        )
        def_training_args["logging_steps"] = logging_steps
        def_training_args["output_dir"] = output_directory
        if eval_data is None:
            def_training_args["evaluation_strategy"] = "no"
            def_training_args["load_best_model_at_end"] = False
        training_args_init = TrainingArguments(**def_training_args)

        if self.freeze_layers is not None:
            def_freeze_layers = self.freeze_layers

        if def_freeze_layers > 0:
            modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False

        ##### Fine-tune the model #####
        # define the data collator
        if self.classifier == "cell":
            data_collator = DataCollatorForCellClassification()
        elif self.classifier == "gene":
            data_collator = DataCollatorForGeneClassification()

        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=data_collator,
            train_dataset=train_data,
            eval_dataset=eval_data,
            compute_metrics=cu.compute_metrics,
        )

        # train the classifier
        trainer.train()
        trainer.save_model(output_directory)
        if predict is True:
            # make eval predictions and save predictions and metrics
            predictions = trainer.predict(eval_data)
            prediction_output_path = f"{output_directory}/predictions.pkl"
            with open(prediction_output_path, "wb") as f:
                pickle.dump(predictions, f)
            trainer.save_metrics("eval", predictions.metrics)
        return trainer

    def evaluate_model(
        self,
        model,
        num_classes,
        id_class_dict,
        eval_data,
        predict=False,
        output_directory=None,
        output_prefix=None,
    ):
        """
        Evaluate the fine-tuned model.

        **Parameters**

        model : nn.Module
            | Loaded fine-tuned model (e.g. trainer.model)
        num_classes : int
            | Number of classes for classifier
        id_class_dict : dict
            | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        eval_data : Dataset
            | Loaded evaluation .dataset input
        predict : bool
            | Whether or not to save eval predictions
        output_directory : Path
            | Path to directory where eval data will be saved
        output_prefix : str
            | Prefix for output files
        """
        ##### Evaluate the model #####
        labels = id_class_dict.keys()
        y_pred, y_true, logits_list = eu.classifier_predict(
            model, self.classifier, eval_data, self.forward_batch_size
        )
        conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
            y_pred, y_true, logits_list, num_classes, labels
        )
        if predict is True:
            pred_dict = {
                "pred_ids": y_pred,
                "label_ids": y_true,
                "predictions": logits_list,
            }
            pred_dict_output_path = (
                Path(output_directory) / f"{output_prefix}_pred_dict"
            ).with_suffix(".pkl")
            with open(pred_dict_output_path, "wb") as f:
                pickle.dump(pred_dict, f)
        return {
            "conf_mat": conf_mat,
            "macro_f1": macro_f1,
            "acc": acc,
            "roc_metrics": roc_metrics,
        }

    def evaluate_saved_model(
        self,
        model_directory,
        id_class_dict_file,
        test_data_file,
        output_directory,
        output_prefix,
        predict=True,
    ):
        """
        Evaluate a previously saved fine-tuned model.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        id_class_dict_file : Path
            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        test_data_file : Path
            | Path to directory containing test .dataset
        output_directory : Path
            | Path to directory where eval data will be saved
        output_prefix : str
            | Prefix for output files
        predict : bool
            | Whether or not to save eval predictions
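
        **Example:**

        | A minimal sketch of evaluating a held-out test split saved by
        | prepare_data; the paths are placeholders.

        .. code-block :: python

            >>> all_metrics_test = cc.evaluate_saved_model(
            ...     model_directory=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/",
            ...     id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
            ...     test_data_file=f"path/to/output_directory/{output_prefix}_labeled_test.dataset",
            ...     output_directory="path/to/output_directory",
            ...     output_prefix="output_prefix")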
        """
        # load numerical id to class dictionary (id:class)
        with open(id_class_dict_file, "rb") as f:
            id_class_dict = pickle.load(f)

        # get number of classes for classifier
        num_classes = cu.get_num_classes(id_class_dict)

        # load previously filtered and prepared data
        test_data = pu.load_and_filter(None, self.nproc, test_data_file)

        # load previously fine-tuned model
        if self.classifier == "cell":
            model_type = "CellClassifier"
        elif self.classifier == "gene":
            model_type = "GeneClassifier"
        model = pu.load_model(model_type, num_classes, model_directory, "eval")

        # evaluate the model
        result = self.evaluate_model(
            model,
            num_classes,
            id_class_dict,
            test_data,
            predict=predict,
            output_directory=output_directory,
            output_prefix=output_prefix,
        )
        all_conf_mat_df = pd.DataFrame(
            result["conf_mat"],
            columns=id_class_dict.values(),
            index=id_class_dict.values(),
        )
        all_metrics = {
            "conf_matrix": all_conf_mat_df,
            "macro_f1": result["macro_f1"],
            "acc": result["acc"],
        }
        all_roc_metrics = None  # roc metrics not reported for multiclass
        if num_classes == 2:
            mean_fpr = np.linspace(0, 1, 100)
            mean_tpr = result["roc_metrics"]["interp_tpr"]
            all_roc_auc = result["roc_metrics"]["auc"]
            all_roc_metrics = {
                "mean_tpr": mean_tpr,
                "mean_fpr": mean_fpr,
                "all_roc_auc": all_roc_auc,
            }
        all_metrics["all_roc_metrics"] = all_roc_metrics
        test_metrics_output_path = (
            Path(output_directory) / f"{output_prefix}_test_metrics_dict"
        ).with_suffix(".pkl")
        with open(test_metrics_output_path, "wb") as f:
            pickle.dump(all_metrics, f)

        return all_metrics

    def plot_conf_mat(
        self,
        conf_mat_dict,
        output_directory,
        output_prefix,
        custom_class_order=None,
    ):
        """
        Plot confusion matrix results of evaluating the fine-tuned model.

        **Parameters**

        conf_mat_dict : dict
            | Dictionary of model_name : confusion_matrix_DataFrame
            | (all_metrics["conf_matrix"] from self.validate)
        output_directory : Path
            | Path to directory where plots will be saved
        output_prefix : str
            | Prefix for output file
        custom_class_order : None, list
            | List of classes in custom order for plots.
            | Same order will be used for all models.
        """
        for model_name in conf_mat_dict.keys():
            eu.plot_confusion_matrix(
                conf_mat_dict[model_name],
                model_name,
                output_directory,
                output_prefix,
                custom_class_order,
            )

    def plot_roc(
        self,
        roc_metric_dict,
        model_style_dict,
        title,
        output_directory,
        output_prefix,
    ):
        """
        Plot ROC curve results of evaluating the fine-tuned model.

        **Parameters**

        roc_metric_dict : dict
            | Dictionary of model_name : roc_metrics
            | (all_metrics["all_roc_metrics"] from self.validate)
        model_style_dict : dict[dict]
            | Dictionary of model_name : dictionary of style_attribute : style
            | where style includes color and linestyle
            | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
        title : str
            | Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
        output_directory : Path
            | Path to directory where plots will be saved
        output_prefix : str
            | Prefix for output file
        """
        eu.plot_ROC(
            roc_metric_dict, model_style_dict, title, output_directory, output_prefix
        )

    def plot_predictions(
        self,
        predictions_file,
        id_class_dict_file,
        title,
        output_directory,
        output_prefix,
        custom_class_order=None,
        kwargs_dict=None,
    ):
        """
        Plot prediction results of evaluating the fine-tuned model.

        **Parameters**

        predictions_file : Path
            | Path of model predictions output to plot
            | (saved output from self.validate if predict_eval=True)
            | (or saved output from self.evaluate_saved_model)
        id_class_dict_file : Path
            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        title : str
            | Title for legend containing class labels.
        output_directory : Path
            | Path to directory where plots will be saved
        output_prefix : str
            | Prefix for output file
        custom_class_order : None, list
            | List of classes in custom order for plots.
            | Same order will be used for all models.
        kwargs_dict : None, dict
            | Dictionary of kwargs to pass to plotting function.
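
        **Example:**

        | A minimal sketch of plotting predictions saved by self.evaluate_saved_model
        | (which writes {output_prefix}_pred_dict.pkl); the paths are placeholders.

        .. code-block :: python

            >>> cc.plot_predictions(predictions_file=f"path/to/output_directory/{output_prefix}_pred_dict.pkl",
            ...                     id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
            ...                     title="disease",
            ...                     output_directory="path/to/output_directory",
            ...                     output_prefix="output_prefix")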
""" # load predictions with open(predictions_file, "rb") as f: predictions = pickle.load(f) # load numerical id to class dictionary (id:class) with open(id_class_dict_file, "rb") as f: id_class_dict = pickle.load(f) if isinstance(predictions, dict): if all( [ key in predictions.keys() for key in ["pred_ids", "label_ids", "predictions"] ] ): # format is output from self.evaluate_saved_model predictions_logits = np.array(predictions["predictions"]) true_ids = predictions["label_ids"] else: # format is output from self.validate if predict_eval=True predictions_logits = predictions.predictions true_ids = predictions.label_ids num_classes = len(id_class_dict.keys()) num_predict_classes = predictions_logits.shape[1] assert num_classes == num_predict_classes classes = id_class_dict.values() true_labels = [id_class_dict[idx] for idx in true_ids] predictions_df = pd.DataFrame(predictions_logits, columns=classes) if custom_class_order is not None: predictions_df = predictions_df.reindex(columns=custom_class_order) predictions_df["true"] = true_labels custom_dict = dict(zip(classes, [i for i in range(len(classes))])) if custom_class_order is not None: custom_dict = dict( zip(custom_class_order, [i for i in range(len(custom_class_order))]) ) predictions_df = predictions_df.sort_values( by=["true"], key=lambda x: x.map(custom_dict) ) eu.plot_predictions( predictions_df, title, output_directory, output_prefix, kwargs_dict )