| | |
| | |
| |
|
| | from pathlib import Path |
| | import numpy as np |
| | import pandas as pd |
| | from sklearn import metrics |
| | from sklearn.preprocessing import label_binarize |
| | import matplotlib.pyplot as plt |
| | from matplotlib import colors as mcolors |
| | from .logger import BaseLogger |
| | from typing import Dict, Union |
| |
|
| |
|
# Module-level logger; print_metrics below writes the per-group summary through it.
logger = BaseLogger.get_logger(__name__)
| |
|
| |
|
class MetricsData:
    """
    Bare attribute container holding the metrics of a single split.

    Nothing is defined up front: attributes are attached dynamically with
    setattr() (see LabelMetrics.set_label_metrics), and which ones exist
    depends on the task.

    Classification (ROC):
        fpr (np.ndarray): false positive rates
        tpr (np.ndarray): true positive rates
        auc (float): area under the ROC curve

    Regression:
        y_obs (np.ndarray): ground truth
        y_pred (np.ndarray): prediction
        r2 (float): coefficient of determination

    DeepSurv:
        c_index (float): concordance index
    """
    def __init__(self) -> None:
        """Create an empty container; attributes are added later."""
| |
|
| |
|
class LabelMetrics:
    """
    Holds the metrics of one label for both evaluation splits.

    The split name ('val' or 'test') is used directly as the attribute name of
    the per-split MetricsData container.
    """
    def __init__(self) -> None:
        """Create one empty MetricsData for 'val' and one for 'test'."""
        self.val, self.test = MetricsData(), MetricsData()

    def set_label_metrics(self, split: str, attr: str, value: Union[np.ndarray, float]) -> None:
        """
        Store a metric value on the container of the given split.

        Args:
            split (str): split name, i.e. 'val' or 'test'
            attr (str): metric name, e.g.
                classification: 'fpr', 'tpr', or 'auc',
                regression: 'y_obs' (ground truth), 'y_pred' (prediction) or 'r2',
                deepsurv: 'c_index'
            value (Union[np.ndarray, float]): metric value
        """
        split_metrics = getattr(self, split)
        setattr(split_metrics, attr, value)

    def get_label_metrics(self, split: str, attr: str) -> Union[np.ndarray, float]:
        """
        Fetch a metric value from the container of the given split.

        Args:
            split (str): split name, i.e. 'val' or 'test'
            attr (str): metric name

        Returns:
            Union[np.ndarray, float]: the stored value

        Raises:
            AttributeError: if the split or metric was never set
        """
        split_metrics = getattr(self, split)
        return getattr(split_metrics, attr)
| |
|
| |
|
class ROCMixin:
    """
    Class for calculating ROC and AUC.

    For each label the likelihood DataFrame is expected to contain the ground
    truth column ``<label_name>``, one score column per class named
    ``pred_<label_name>_<class>`` (``<class>`` an integer), and a ``split``
    column whose values include 'val' and 'test'.
    """
    @staticmethod
    def _pred_columns(label_name: str, columns: pd.Index) -> list:
        """
        Return the prediction columns that belong to exactly ``label_name``.

        Only columns of the form ``pred_<label_name>_<int>`` are returned. A
        plain substring/prefix match would let 'label_1' also pick up the
        columns of 'label_10' or 'label_1_x', corrupting the class list.

        Args:
            label_name (str): label name
            columns (pd.Index): column names of the likelihood DataFrame

        Returns:
            list: matching column names, in their original order
        """
        prefix = 'pred_' + label_name + '_'
        pred_columns = []
        for column_name in columns:
            if not (isinstance(column_name, str) and column_name.startswith(prefix)):
                continue
            suffix = column_name[len(prefix):]
            # Class suffix must be an integer (optionally negative); isdecimal()
            # rejects '_' so that e.g. 'pred_label_a_1_0' is not taken for 'label_a'.
            digits = suffix[1:] if suffix.startswith('-') else suffix
            if digits.isdecimal():
                pred_columns.append(column_name)
        return pred_columns

    def _set_roc(self, label_metrics: LabelMetrics, split: str, fpr: np.ndarray, tpr: np.ndarray) -> None:
        """
        Set fpr, tpr, and auc.

        Args:
            label_metrics (LabelMetrics): metrics of 'val' and 'test'
            split (str): 'val' or 'test'
            fpr (np.ndarray): FPR
            tpr (np.ndarray): TPR

        self.metrics_kind (= 'auc') is expected to be defined by the host class (ClsEval).
        """
        label_metrics.set_label_metrics(split, 'fpr', fpr)
        label_metrics.set_label_metrics(split, 'tpr', tpr)
        label_metrics.set_label_metrics(split, self.metrics_kind, metrics.auc(fpr, tpr))

    def _cal_label_roc_binary(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate ROC for binary class.

        The score of the positive class (1), column ``pred_<label_name>_1``,
        is used as the decision value.

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        POSITIVE = 1
        positive_pred_name = 'pred_' + label_name + '_' + str(POSITIVE)
        df_label = df_group[[label_name, positive_pred_name, 'split']]

        label_metrics = LabelMetrics()
        for split in ['val', 'test']:
            df_split = df_label.query('split == @split')
            y_true = df_split[label_name]
            y_score = df_split[positive_pred_name]
            _fpr, _tpr, _ = metrics.roc_curve(y_true, y_score)
            self._set_roc(label_metrics, split, _fpr, _tpr)
        return label_metrics

    def _cal_label_roc_multi(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate ROC for multi-class by macro average.

        One-vs-rest ROC curves are computed per class, interpolated onto the
        union of their FPR points, and their TPRs averaged.

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        pred_name_list = self._pred_columns(label_name, df_group.columns)
        class_list = [int(pred_name.rsplit('_', 1)[-1]) for pred_name in pred_name_list]
        num_classes = len(class_list)
        df_label = df_group[[label_name] + pred_name_list + ['split']]

        label_metrics = LabelMetrics()
        for split in ['val', 'test']:
            df_split = df_label.query('split == @split')
            # Column i of y_true_bin corresponds to class_list[i] / pred_name_list[i].
            y_true_bin = label_binarize(df_split[label_name], classes=class_list)

            fpr_list = []
            tpr_list = []
            for i, pred_name in enumerate(pred_name_list):
                _fpr, _tpr, _ = metrics.roc_curve(y_true_bin[:, i], df_split[pred_name])
                fpr_list.append(_fpr)
                tpr_list.append(_tpr)

            # Common FPR grid: union of every class's FPR points.
            all_fpr = np.unique(np.concatenate(fpr_list))

            # Macro average of the per-class TPRs interpolated onto that grid.
            mean_tpr = np.zeros_like(all_fpr)
            for _fpr, _tpr in zip(fpr_list, tpr_list):
                mean_tpr += np.interp(all_fpr, _fpr, _tpr)
            mean_tpr /= num_classes

            self._set_roc(label_metrics, split, all_fpr, mean_tpr)
        return label_metrics

    def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate ROC and AUC for label depending on binary or multi-class.

        A label with more than two ``pred_<label_name>_<int>`` columns is
        treated as multi-class.

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        is_multi_class = len(self._pred_columns(label_name, df_group.columns)) > 2
        if is_multi_class:
            return self._cal_label_roc_multi(label_name, df_group)
        return self._cal_label_roc_binary(label_name, df_group)
| |
|
| |
|
class YYMixin:
    """
    Class for calculating YY (observed vs predicted) and R2.
    """
    def _set_yy(
        self,
        label_metrics: LabelMetrics,
        split: str,
        y_obs: Union[pd.Series, np.ndarray],
        y_pred: Union[pd.Series, np.ndarray]
    ) -> None:
        """
        Set ground truth, prediction, and R2.

        Args:
            label_metrics (LabelMetrics): metrics of 'val' and 'test'
            split (str): 'val' or 'test'
            y_obs (Union[pd.Series, np.ndarray]): ground truth
            y_pred (Union[pd.Series, np.ndarray]): prediction

        Both values are stored as numpy arrays. np.asarray is used instead of
        ``.values`` so that plain ndarrays (as the old annotation promised) work.
        self.metrics_kind (= 'r2') is expected to be defined by the host class (RegEval).
        """
        y_obs = np.asarray(y_obs)
        y_pred = np.asarray(y_pred)
        label_metrics.set_label_metrics(split, 'y_obs', y_obs)
        label_metrics.set_label_metrics(split, 'y_pred', y_pred)
        label_metrics.set_label_metrics(split, self.metrics_kind, metrics.r2_score(y_obs, y_pred))

    def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate YY and R2 for label.

        Args:
            label_name (str): label name; ground truth column is ``<label_name>``,
                prediction column is ``pred_<label_name>``
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        pred_name = 'pred_' + label_name
        # Exact selection: a substring match would also drag in e.g. 'label_10' for 'label_1'.
        df_label = df_group[[label_name, pred_name, 'split']]
        label_metrics = LabelMetrics()
        for split in ['val', 'test']:
            df_split = df_label.query('split == @split')
            self._set_yy(label_metrics, split, df_split[label_name], df_split[pred_name])
        return label_metrics
| |
|
| |
|
class C_IndexMixin:
    """
    Computes the concordance index (C-Index) for survival labels.
    """
    def _set_c_index(
        self,
        label_metrics: LabelMetrics,
        split: str,
        periods: pd.Series,
        preds: pd.Series,
        labels: pd.Series
    ) -> None:
        """
        Compute the C-Index of one split and store it on label_metrics.

        Args:
            label_metrics (LabelMetrics): metrics of 'val' and 'test'
            split (str): 'val' or 'test'
            periods (pd.Series): observed periods (event/censoring times)
            preds (pd.Series): model prediction
            labels (pd.Series): event indicator passed as event_observed

        The metric name comes from self.metrics_kind (= 'c_index'), which the
        host class (DeepSurvEval) is expected to set.
        """
        # Imported lazily so lifelines is only needed when DeepSurv is evaluated.
        from lifelines.utils import concordance_index
        # lifelines treats higher scores as longer survival; the prediction is
        # presumably a risk score (higher = earlier event), hence the negation.
        negated_preds = -1 * preds
        score = concordance_index(periods, negated_preds, labels)
        label_metrics.set_label_metrics(split, self.metrics_kind, score)

    def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Compute the C-Index of one label for the 'val' and 'test' splits.

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        wanted = [col for col in df_group.columns if label_name in col]
        wanted.extend(['periods', 'split'])
        df_label = df_group[wanted]
        result = LabelMetrics()
        for split in ('val', 'test'):
            part = df_label.query('split == @split')
            self._set_c_index(
                result,
                split,
                part['periods'],
                part['pred_' + label_name],
                part[label_name]
            )
        return result
| |
|
| |
|
class MetricsMixin:
    """
    Class to calculate metrics and make summary.

    The host class is expected to provide ``cal_label_metrics`` (from a metric
    mixin) and the ``metrics_kind`` attribute. ``make_metrics`` additionally
    needs ``make_save_fig`` and ``fig_kind``.
    """
    def _cal_group_metrics(self, df_group: pd.DataFrame) -> Dict[str, LabelMetrics]:
        """
        Calculate metrics for each label of one group.

        Labels are the columns whose name starts with 'label'.

        Args:
            df_group (pd.DataFrame): likelihood for group

        Returns:
            Dict[str, LabelMetrics]: dictionary of label and its LabelMetrics
                eg. {label_1: LabelMetrics(), label_2: LabelMetrics(), ...}
        """
        label_list = list(df_group.columns[df_group.columns.str.startswith('label')])
        return {label_name: self.cal_label_metrics(label_name, df_group) for label_name in label_list}

    def cal_whole_metrics(self, df_likelihood: pd.DataFrame) -> Dict[str, Dict[str, LabelMetrics]]:
        """
        Calculate metrics for all groups.

        Args:
            df_likelihood (pd.DataFrame) : DataFrame of likelihood

        Returns:
            Dict[str, Dict[str, LabelMetrics]]: dictionary of group and dictionary of label and its LabelMetrics
                eg. {
                    groupA: {label_1: LabelMetrics(), label_2: LabelMetrics(), ...},
                    groupB: {label_1: LabelMetrics(), label_2: LabelMetrics(), ...},
                    ...}
        """
        whole_metrics = dict()
        for group in df_likelihood['group'].unique():
            df_group = df_likelihood.query('group == @group')
            whole_metrics[group] = self._cal_group_metrics(df_group)
        return whole_metrics

    def make_summary(
        self,
        whole_metrics: Dict[str, Dict[str, LabelMetrics]],
        likelihood_path: Path,
        metrics_kind: str
    ) -> pd.DataFrame:
        """
        Make summary: one row per group, metric values formatted with 2 decimals.

        Columns are 'datetime' (name of likelihood_path.parents[1]), 'weight'
        (likelihood file stem without 'likelihood_' plus '.pt'), 'group', and
        '<label>_val_<metrics_kind>' / '<label>_test_<metrics_kind>' per label.

        Args:
            whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
            likelihood_path (Path): path to likelihood
            metrics_kind (str): kind of metrics, ie, 'auc', 'r2', or 'c_index'

        Returns:
            pd.DataFrame: summary sorted by group (empty, with the base columns,
                when there are no groups)
        """
        _datetime = likelihood_path.parents[1].name
        _weight = likelihood_path.stem.replace('likelihood_', '') + '.pt'
        rows = []
        for group, group_metrics in whole_metrics.items():
            row = {'datetime': _datetime, 'weight': _weight, 'group': group}
            for label_name, label_metrics in group_metrics.items():
                _val_metrics = label_metrics.get_label_metrics('val', metrics_kind)
                _test_metrics = label_metrics.get_label_metrics('test', metrics_kind)
                row[label_name + '_val_' + metrics_kind] = f"{_val_metrics:.2f}"
                row[label_name + '_test_' + metrics_kind] = f"{_test_metrics:.2f}"
            rows.append(row)

        if not rows:
            # Sorting an empty, column-less frame by 'group' would raise KeyError.
            return pd.DataFrame(columns=['datetime', 'weight', 'group'])

        # Building the frame once avoids repeated pd.concat (quadratic copying).
        df_summary = pd.DataFrame(rows)
        return df_summary.sort_values('group')

    def print_metrics(self, df_summary: pd.DataFrame, metrics_kind: str) -> None:
        """
        Print metrics of every group through the module logger.

        Val/test columns are paired by name ('<label>_val_<kind>' with
        '<label>_test_<kind>') rather than by column position. Only the trailing
        '_val_' is stripped, so label names containing '_val' stay intact.

        Args:
            df_summary (pd.DataFrame): summary
            metrics_kind (str): kind of metrics, ie. 'auc', 'r2', or 'c_index'
        """
        val_suffix = '_val_' + metrics_kind
        test_suffix = '_test_' + metrics_kind
        val_columns = [
            column_name for column_name in df_summary.columns
            if isinstance(column_name, str)
            and column_name.startswith('label')
            and column_name.endswith(val_suffix)
        ]
        for _, row in df_summary.iterrows():
            logger.info(row['group'])
            for val_column in val_columns:
                label_name = val_column[:-len(val_suffix)]
                test_column = label_name + test_suffix
                _label_name = label_name + '_' + metrics_kind
                logger.info(f"{_label_name:<25} val_{metrics_kind}: {row[val_column]:>7}, test_{metrics_kind}: {row[test_column]:>7}")

    def update_summary(self, df_summary: pd.DataFrame, likelihood_path: Path) -> None:
        """
        Append summary to '<project_dir>/summary/summary.csv', creating it if needed.

        project_dir is likelihood_path.parents[3].

        Args:
            df_summary (pd.DataFrame): summary to be added to the previous summary
            likelihood_path (Path): path to likelihood
        """
        _project_dir = likelihood_path.parents[3]
        summary_dir = Path(_project_dir, 'summary')
        summary_path = Path(summary_dir, 'summary.csv')
        if summary_path.exists():
            df_prev = pd.read_csv(summary_path)
            df_updated = pd.concat([df_prev, df_summary], axis=0)
        else:
            summary_dir.mkdir(parents=True, exist_ok=True)
            df_updated = df_summary
        df_updated.to_csv(summary_path, index=False)

    def make_metrics(self, likelihood_path: Path) -> None:
        """
        Make metrics: compute, plot and save figures, summarize, log, and persist.

        Args:
            likelihood_path (Path): path to likelihood
        """
        df_likelihood = pd.read_csv(likelihood_path)
        whole_metrics = self.cal_whole_metrics(df_likelihood)
        self.make_save_fig(whole_metrics, likelihood_path, self.fig_kind)
        df_summary = self.make_summary(whole_metrics, likelihood_path, self.metrics_kind)
        self.print_metrics(df_summary, self.metrics_kind)
        self.update_summary(df_summary, likelihood_path)
| |
|
| |
|
class FigROCMixin:
    """
    Class to plot ROC.
    """
    def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt.Figure:
        """
        Plot the val and test ROC curves of every label of one group, one panel per label.

        Args:
            group (str): group name; formatted with an f-string, so non-string
                group values (e.g. integers read from CSV) also work
            group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics

        Returns:
            plt.Figure: figure with one ROC panel per label
        """
        label_list = list(group_metrics.keys())
        num_rows = 1
        num_cols = len(label_list)
        base_size = 7
        height = num_rows * base_size
        width = num_cols * height
        fig = plt.figure(figsize=(width, height))

        for i, label_name in enumerate(label_list):
            label_metrics = group_metrics[label_name]
            ax_i = fig.add_subplot(
                num_rows,
                num_cols,
                i + 1,
                title=f"{group}: {label_name}",
                xlabel='1 - Specificity',
                ylabel='Sensitivity',
                xmargin=0,
                ymargin=0
            )
            for split, marker in (('val', 'x'), ('test', 'o')):
                split_metrics = getattr(label_metrics, split)
                ax_i.plot(
                    split_metrics.fpr,
                    split_metrics.tpr,
                    label=f"AUC_{split} = {split_metrics.auc:.2f}",
                    marker=marker
                )
            ax_i.grid()
            ax_i.legend()
        fig.tight_layout()
        return fig
| |
|
| |
|
class FigYYMixin:
    """
    Class to plot YY-graph (observed vs predicted).
    """
    @staticmethod
    def _plot_yy_panel(ax, y_obs: np.ndarray, y_pred: np.ndarray, color: str, split: str) -> None:
        """
        Draw one observed-vs-predicted scatter with a red y = x reference line.

        The reference line spans the min..max of all observed and predicted
        values, padded by 1% of their range on each side.

        Args:
            ax: matplotlib Axes to draw on
            y_obs (np.ndarray): ground truth
            y_pred (np.ndarray): prediction
            color (str): scatter color
            split (str): 'val' or 'test', used as the scatter label
        """
        ax.scatter(y_obs, y_pred, color=color, label=split)
        y_values = np.concatenate([np.ravel(y_obs), np.ravel(y_pred)])
        y_min, y_max, y_range = np.amin(y_values), np.amax(y_values), np.ptp(y_values)
        pad = y_range * 0.01
        ax.plot([y_min - pad, y_max + pad], [y_min - pad, y_max + pad], color='red')

    def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt.Figure:
        """
        Plot YY-graphs of one group: for each label, a val panel then a test panel.

        Args:
            group (str): group name; formatted with an f-string, so non-string
                group values (e.g. integers read from CSV) also work
            group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics

        Returns:
            plt.Figure: figure with two panels per label
        """
        splits = ['val', 'test']
        split_colors = {
            'val': mcolors.TABLEAU_COLORS['tab:blue'],
            'test': mcolors.TABLEAU_COLORS['tab:orange'],
        }
        label_list = list(group_metrics.keys())
        num_splits = len(splits)
        num_rows = 1
        num_cols = len(label_list) * num_splits
        base_size = 7
        height = num_rows * base_size
        width = num_cols * height
        fig = plt.figure(figsize=(width, height))

        for i, label_name in enumerate(label_list):
            label_metrics = group_metrics[label_name]
            for j, split in enumerate(splits):
                ax = fig.add_subplot(
                    num_rows,
                    num_cols,
                    (i * num_splits) + j + 1,
                    title=f"{group}: {label_name}\n{split}: Observed-Predicted Plot",
                    xlabel='Observed',
                    ylabel='Predicted',
                    xmargin=0,
                    ymargin=0
                )
                split_metrics = getattr(label_metrics, split)
                self._plot_yy_panel(ax, split_metrics.y_obs, split_metrics.y_pred, split_colors[split], split)

        fig.tight_layout()
        return fig
| |
|
| |
|
class FigMixin:
    """
    Class to make and save figures.
    This class is for ROC and YY-graph; the host class supplies _plot_fig_group_metrics.
    """
    def make_save_fig(self, whole_metrics: Dict[str, Dict[str, LabelMetrics]], likelihood_path: Path, fig_kind: str) -> None:
        """
        Make and save one PNG per group.

        Files go to '<likelihood_path.parents[1]>/<fig_kind>/<group>_<fig_kind>_<weight>.png',
        where <weight> is the likelihood file stem without 'likelihood_'.

        Args:
            whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
            likelihood_path (Path): path to likelihood
            fig_kind (str): kind of figure, ie. 'roc' or 'yy'
        """
        _datetime_dir = likelihood_path.parents[1]
        save_dir = Path(_datetime_dir, fig_kind)
        save_dir.mkdir(parents=True, exist_ok=True)
        _fig_name = fig_kind + '_' + likelihood_path.stem.replace('likelihood_', '')
        for group, group_metrics in whole_metrics.items():
            fig = self._plot_fig_group_metrics(group, group_metrics)
            # f-string so non-string group values (e.g. ints from CSV) do not raise TypeError.
            save_path = Path(save_dir, f"{group}_{_fig_name}.png")
            try:
                fig.savefig(save_path)
            finally:
                # Close this exact figure (not just the "current" one), even if saving fails.
                plt.close(fig)
| |
|
| |
|
class ClsEval(MetricsMixin, ROCMixin, FigMixin, FigROCMixin):
    """
    Evaluator for classification: ROC figures plus AUC as the summary metric.
    """
    def __init__(self) -> None:
        self.fig_kind, self.metrics_kind = 'roc', 'auc'
| |
|
| |
|
class RegEval(MetricsMixin, YYMixin, FigMixin, FigYYMixin):
    """
    Evaluator for regression: observed-predicted (YY) figures plus R2 as the summary metric.
    """
    def __init__(self) -> None:
        self.fig_kind, self.metrics_kind = 'yy', 'r2'
| |
|
| |
|
class DeepSurvEval(MetricsMixin, C_IndexMixin):
    """
    Evaluator for DeepSurv: C-Index only, no figures.
    """
    def __init__(self) -> None:
        self.fig_kind = None
        self.metrics_kind = 'c_index'

    def make_metrics(self, likelihood_path: Path) -> None:
        """
        Compute, summarize, log, and persist the C-Index for a likelihood file.

        Overrides MetricsMixin.make_metrics and omits the make_save_fig() step,
        since no figure is produced for DeepSurv.

        Args:
            likelihood_path (Path): path to likelihood
        """
        kind = self.metrics_kind
        whole_metrics = self.cal_whole_metrics(pd.read_csv(likelihood_path))
        df_summary = self.make_summary(whole_metrics, likelihood_path, kind)
        self.print_metrics(df_summary, kind)
        self.update_summary(df_summary, likelihood_path)
| |
|
| |
|
def set_eval(task: str) -> Union[ClsEval, RegEval, DeepSurvEval]:
    """
    Build the evaluator that matches a task name.

    Args:
        task (str): 'classification', 'regression', or 'deepsurv'

    Returns:
        Union[ClsEval, RegEval, DeepSurvEval]: a fresh evaluator instance

    Raises:
        ValueError: if task is not one of the names above
    """
    candidates = (
        ('classification', ClsEval),
        ('regression', RegEval),
        ('deepsurv', DeepSurvEval),
    )
    for task_name, eval_class in candidates:
        if task == task_name:
            return eval_class()
    raise ValueError(f"Invalid task: {task}.")
| |
|