import os
import torch
import numpy as np
import time
import pickle
from scripts import tabular_metrics
from scripts.tabular_metrics import calculate_score_per_method
from scripts.tabular_evaluation import evaluate
from priors.differentiable_prior import draw_random_style
from tqdm import tqdm
import random
from scripts.transformer_prediction_interface import get_params_from_config, load_model_workflow

# ===============================
# PUBLIC FUNCTIONS FOR EVALUATION
# ===============================


def eval_model_range(i_range, *args, **kwargs):
    """Run :func:`eval_model` once for every index in ``i_range``, forwarding all other arguments."""
    for i in i_range:
        eval_model(i, *args, **kwargs)


def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
               bptt_valid,
               bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
    """
    Differentiable model evaluation workflow. Evaluates a model and saves the results to disk.

    :param i: passed through to ``load_model_workflow`` (presumably a model/run index).
    :param e: passed through to ``load_model_workflow`` (presumably an epoch/checkpoint id).
    :param valid_datasets: datasets used to select the best style / softmax temperature.
    :param test_datasets: datasets used for the final evaluation.
    :param train_datasets: datasets used for gradient-based style optimization.
    :param eval_positions_valid: eval positions used during selection on ``valid_datasets``.
    :param eval_positions_test: eval positions used in the final evaluation.
    :param bptt_valid: ``bptt`` used during selection.
    :param bptt_test: ``bptt`` used in the final evaluation.
    :param add_name: passed through to ``load_model_workflow``.
    :param base_path: passed through to ``load_model_workflow``.
    :param device: torch device the model is moved to.
    :param eval_addition: passed through to ``load_model_workflow``.
    :param extra_tuning_args: extra keyword arguments forwarded to ``evaluate_differentiable_model``.
    :return: ``(r, model)`` where ``r`` is the list that was pickled to ``results_file``:
             ``[config, metrics, metrics_valid, style, temperature, optimization_route]``.
    """
    model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)

    params = {'bptt': bptt_valid
        , 'bptt_final': bptt_test
        , 'eval_positions': eval_positions_valid
        , 'eval_positions_test': eval_positions_test
        , 'valid_datasets': valid_datasets
        , 'test_datasets': test_datasets
        , 'train_datasets': train_datasets
        , 'verbose': True
        , 'device': device
              }

    params.update(get_params_from_config(c))

    start = time.time()
    metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(model, **params,
                                                                                                  **extra_tuning_args)
    print('Evaluation time: ', time.time() - start)

    print(results_file)
    config = c.copy()
    # These config entries are stripped before pickling (NOTE(review): presumably because they are not
    # picklable). Remove them *before* opening the file so a failure cannot leave a truncated result file,
    # and tolerate their absence.
    config.pop('num_features_used', None)
    config.pop('categorical_features_sampler', None)
    r = [config, metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
    with open(results_file, 'wb') as output:
        pickle.dump(r, output)

    return r, model


# ===============================
# INTERNAL HELPER FUNCTIONS
# ===============================


def _resolve_selection_metric(selection_metric):
    """Map a selection-metric name to ``(metric_function, 'min' | 'max')``; raise ValueError if unknown."""
    if selection_metric in ('auc', 'roc'):
        return tabular_metrics.auc_metric, 'max'
    # 'selection_metric' is accepted for backward compatibility with existing callers/configs.
    if selection_metric in ('ce', 'cross_entropy', 'selection_metric'):
        return tabular_metrics.cross_entropy, 'min'
    raise ValueError(f"Unknown selection_metric {selection_metric!r}; "
                     f"expected one of 'auc', 'roc', 'ce', 'cross_entropy'.")


def _draw_styles(style_prior, n_parallel_configurations, device):
    """Draw ``n_parallel_configurations`` random styles with ``draw_random_style`` and concatenate them on dim 0."""
    return torch.cat([draw_random_style(style_prior, device).detach() for _ in range(n_parallel_configurations)], 0)


def _zero_temperatures(n_parallel_configurations, device):
    """Return a 1-D tensor of ``n_parallel_configurations`` zeros on ``device`` (initial softmax temperatures)."""
    return torch.cat([torch.tensor([0.0]).to(device) for _ in range(n_parallel_configurations)], 0)


def _is_latest_best(tracked, min_max):
    """Return True if the last entry of ``tracked`` ties or beats every entry (lower is better for 'min')."""
    if min_max == 'min':
        return tracked[-1] <= min(tracked)
    return tracked[-1] >= max(tracked)


def _improves(candidate, best, min_max):
    """Return True if validation score ``candidate`` should replace ``best`` (``None`` means no best yet)."""
    if best is None:
        return True
    if np.isnan(candidate):
        return False
    if np.isnan(best):
        # A NaN best can never be beaten by a comparison, so any real score replaces it.
        return True
    if min_max == 'min':
        return candidate < best
    return candidate > best


def evaluate_differentiable_model(model
                                  , valid_datasets
                                  , test_datasets
                                  , train_datasets
                                  , N_draws=100
                                  , N_grad_steps=10
                                  , eval_positions=None
                                  , eval_positions_test=None
                                  , bptt=100
                                  , bptt_final=200
                                  , style=None
                                  , n_parallel_configurations=1
                                  , device='cpu'
                                  , selection_metric='auc'
                                  , final_splits=(1, 2, 3, 4, 5)
                                  , N_ensemble_configurations_list=(1, 5, 10, 20, 50, 100)
                                  , **kwargs):
    """
    Evaluation function for differentiable model evaluation.

    Draws ``N_draws`` random styles and optionally refines each with ``N_grad_steps`` gradient steps. It keeps
    the style/temperature with the best mean validation selection metric. It then evaluates that choice on
    ``test_datasets`` for every entry of ``N_ensemble_configurations_list`` and every split in
    ``final_splits``, and once more on ``valid_datasets``.

    :param model: indexable model bundle; ``model[2]`` is the network that is evaluated and ``model[3]`` is
                  passed to ``draw_random_style`` (NOTE(review): presumably the style prior/config).
    :param valid_datasets: datasets used for selection.
    :param test_datasets: datasets used for the final evaluation.
    :param train_datasets: datasets used for gradient-based style optimization.
    :param N_draws: number of random style initializations.
    :param N_grad_steps: gradient steps per initialization (0 disables optimization).
    :param eval_positions: eval positions used during selection.
    :param eval_positions_test: eval positions used in the final evaluation.
    :param bptt: ``bptt`` used during selection.
    :param bptt_final: ``bptt`` used in the final evaluation.
    :param style: unused (each draw samples a fresh style); kept for interface compatibility.
    :param n_parallel_configurations: number of styles drawn and concatenated per initialization.
    :param device: torch device.
    :param selection_metric: 'auc'/'roc' (maximized) or 'ce'/'cross_entropy' (minimized).
    :param final_splits: split numbers evaluated in the final test evaluation.
    :param N_ensemble_configurations_list: ensemble sizes evaluated in the final test evaluation.
    :param kwargs: forwarded to ``gradient_optimize_style`` (and from there to ``eval_step``/``evaluate``).
    :return: ``(result_test, result_valid, best_style, best_softmax_temperature, optimization_routes)``.
    :raises ValueError: if ``selection_metric`` is not recognised.
    """
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    diffable_metric = tabular_metrics.cross_entropy
    selection_metric, selection_metric_min_max = _resolve_selection_metric(selection_metric)
    evaluation_metric = selection_metric

    print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
          evaluation_metric)
    print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
    print('eval_positions', eval_positions)

    def evaluate_valid(style, softmax_temperature, results, results_tracked):
        # Append the per-position validation scores and their NaN-ignoring mean to the given lists.
        result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
                                 return_tensor=False, inference_mode=True, selection_metric=selection_metric,
                                 evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt,
                                 model=model[2])
        result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
        results += [result_valid]
        results_tracked += [np.nanmean(result_valid)]

    model[2].to(device)
    model[2].eval()

    results_on_valid, results_on_valid_tracked = [], []
    optimization_routes = []

    best_style = _draw_styles(model[3], n_parallel_configurations, device)
    best_softmax_temperature = _zero_temperatures(n_parallel_configurations, device)

    # Arguments eval_step needs inside gradient_optimize_style; explicit kwargs from the caller take precedence.
    grad_kwargs = {'selection_metric': selection_metric,
                   'evaluation_metric': evaluation_metric,
                   'eval_positions': eval_positions,
                   'bptt': bptt,
                   **kwargs}

    for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'):  # Evaluates N hparam draws
        style = _draw_styles(model[3], n_parallel_configurations, device)
        softmax_temperature = _zero_temperatures(n_parallel_configurations, device)

        evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)

        print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')

        # Without gradient steps the random draw itself is the candidate.
        candidate_style, candidate_temperature = style, softmax_temperature
        if N_grad_steps > 0:
            # gradient_optimize_style indexes model[2] itself, so model is passed only positionally.
            gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
                                                               , softmax_temperature=softmax_temperature
                                                               , train_datasets=train_datasets
                                                               , valid_datasets=valid_datasets
                                                               , selection_metric_min_max=selection_metric_min_max
                                                               , **grad_kwargs)
            optimization_routes += [gradient_optimize_result['optimization_route']]
            candidate_style = gradient_optimize_result['best_style']
            candidate_temperature = gradient_optimize_result['best_temperature']

            evaluate_valid(candidate_style, candidate_temperature, results_on_valid, results_on_valid_tracked)

            print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')

        if _is_latest_best(results_on_valid_tracked, selection_metric_min_max):
            best_style = candidate_style.clone()
            best_softmax_temperature = candidate_temperature.clone()
        torch.cuda.empty_cache()

    def final_evaluation():
        print('Running eval dataset with final params (no gradients)..')
        print(best_style, best_softmax_temperature)
        result_test = []
        for N_ensemble_configurations in N_ensemble_configurations_list:
            print(f'Running with {N_ensemble_configurations} ensemble_configurations')
            splits = []
            for split in final_splits:
                # N_ensemble_configurations must actually reach evaluate, otherwise every ensemble size
                # produces the same result.
                splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
                                     , return_tensor=False, eval_positions=eval_positions_test,
                                     bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
                                     , selection_metric=selection_metric, evaluation_metric=evaluation_metric
                                     , N_ensemble_configurations=N_ensemble_configurations)]
            result_test += [splits]

        print('Running valid dataset with final params (no gradients)..')
        result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
                                 , return_tensor=False, eval_positions=eval_positions_test,
                                 bptt=bptt_final, inference_mode=True, model=model[2]
                                 , selection_metric=selection_metric, evaluation_metric=evaluation_metric)

        return result_test, result_valid

    result_test, result_valid = final_evaluation()

    return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes


def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
    """
    Run ``evaluate`` for the transformer with ``used_style`` on ``ds`` and attach aggregated scores.

    After evaluation, ``calculate_score_per_method`` is called twice on the result: once with
    ``selection_metric`` under the name 'select' and once with ``evaluation_metric`` under the name 'eval',
    both with mean aggregation.

    :param ds: list of datasets.
    :param used_style: style tensor passed to ``evaluate``.
    :param selection_metric: metric used by ``evaluate`` and for the 'select' scores.
    :param evaluation_metric: metric used for the 'eval' scores.
    :param eval_positions: eval positions passed to ``evaluate`` and the score aggregation.
    :param return_tensor: if False, evaluation runs under ``torch.no_grad()``.
    :param kwargs: forwarded to ``evaluate``; ``verbose`` defaults to True but may be overridden.
    :return: the result object returned by ``evaluate`` with the aggregated scores added.
    """
    # Let callers override verbose instead of raising "multiple values for keyword argument 'verbose'".
    kwargs.setdefault('verbose', True)

    def step():
        return evaluate(datasets=ds,
                        method='transformer'
                        , overwrite=True
                        , style=used_style
                        , eval_positions=eval_positions
                        , metric_used=selection_metric
                        , save=False
                        , path_interfix=None
                        , base_path=None
                        , **kwargs)

    if return_tensor:
        r = step()
    else:
        with torch.no_grad():
            r = step()

    calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
    calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')

    return r


def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets,
                            learning_rate=0.03, optimize_all=False,
                            limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True,
                            selection_metric_min_max='max', **kwargs):
    """
    Optimize 'style' with Adam on 'train_datasets' and keep the best step as judged on 'valid_datasets'.

    :param model: indexable model bundle; ``model[2]`` is the evaluated network.
    :param init_style: initial style tensor (copied; the caller's tensor is not modified).
    :param steps: number of optimization steps.
    :param softmax_temperature: initial softmax temperature tensor (copied; not modified).
    :param train_datasets: sequence of datasets; up to ``N_datasets_sampled`` are sampled each step.
    :param valid_datasets: datasets used to pick the best step.
    :param learning_rate: Adam learning rate.
    :param optimize_all: if True, optimize ``model[2].parameters()`` instead of style and temperature.
    :param limit_style: if True, clamp the style to [-1.74, 1.74] after every step.
    :param N_datasets_sampled: maximum number of training datasets sampled per step.
    :param optimize_softmax_temperature: whether the softmax temperature receives gradients.
    :param selection_metric_min_max: 'min' if lower validation selection scores are better, else 'max'.
    :param kwargs: forwarded to ``eval_step`` (must provide ``selection_metric``, ``evaluation_metric`` and
                   ``eval_positions``).
    :return: dict with 'best_style', 'best_temperature' and 'optimization_route' (per-step train and
             valid selection/loss values).
    """
    # Clone so the optimizer's in-place updates never alias the caller's tensors.
    grad_style = torch.nn.Parameter(init_style.detach().clone(), requires_grad=True)

    best_style = grad_style.detach().clone()
    best_temperature = softmax_temperature.detach().clone()
    best_selection_metric, best_diffable_metric = None, None
    softmax_temperature = torch.nn.Parameter(softmax_temperature.detach().clone(),
                                             requires_grad=optimize_softmax_temperature)
    variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
    optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)

    optimization_route_selection, optimization_route_diffable = [], []
    optimization_route_selection_valid, optimization_route_diffable_valid = [], []

    def eval_opt(ds, return_tensor=True, inference_mode=False):
        result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
                           , inference_mode=inference_mode, model=model[2], **kwargs)

        return result['mean_metric'], result['mean_select']

    def eval_all_datasets(datasets, propagate=True):
        # Evaluate one dataset at a time, backpropagating each loss when propagate is set. Datasets with a
        # NaN loss are skipped entirely.
        selection_metrics_this_step, diffable_metrics_this_step = [], []
        for ds in datasets:
            diffable_metric, selection_metric = eval_opt([ds], inference_mode=(not propagate))
            if torch.isnan(diffable_metric).any():
                continue
            if propagate and diffable_metric.requires_grad:
                diffable_metric.backward()
            selection_metrics_this_step += [selection_metric]
            diffable_metrics_this_step += [float(diffable_metric.detach().cpu().numpy())]
        return np.nanmean(diffable_metrics_this_step), np.nanmean(selection_metrics_this_step)

    # random.sample raises ValueError if asked for more items than exist.
    n_sampled = min(N_datasets_sampled, len(train_datasets))

    for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
        optimizer.zero_grad()

        # Select subset of datasets (reseeded with the step index so the subset is reproducible)
        random.seed(t)
        train_datasets_ = random.sample(train_datasets, n_sampled)

        # Get score on train
        diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
        optimization_route_selection += [float(selection_metric_train)]
        optimization_route_diffable += [float(diffable_metric_train)]

        # Get score on valid
        diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
        optimization_route_selection_valid += [float(selection_metric_valid)]
        optimization_route_diffable_valid += [float(diffable_metric_valid)]

        if _improves(selection_metric_valid, best_selection_metric, selection_metric_min_max):
            print('New best', best_selection_metric, selection_metric_valid)
            best_style = grad_style.detach().clone()
            best_temperature = softmax_temperature.detach().clone()
            best_selection_metric, best_diffable_metric = selection_metric_valid, diffable_metric_valid

        optimizer.step()

        if limit_style:
            # Clamp in place so grad_style stays the Parameter the optimizer updates; rebinding it to a
            # detached copy would silently stop all further optimization of the style.
            with torch.no_grad():
                grad_style.clamp_(-1.74, 1.74)

        print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
              f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')

    print(f'Return best:{best_style} {best_selection_metric}')

    return {'best_style': best_style, 'best_temperature': best_temperature
        , 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
                                 'test_select': optimization_route_selection_valid,
                                 'test_loss': optimization_route_diffable_valid}}