import time
import os
from pathlib import Path

from tqdm import tqdm
import random
import numpy as np

import torch
from torch import nn

from utils import torch_nanmean
from datasets import *
from model_builder import load_model
from scripts.tabular_baselines import get_scoring_string
from scripts import tabular_metrics
from scripts.transformer_prediction_interface import *
from scripts.baseline_prediction_interface import *
"""
===============================
PUBLIC FUNCTIONS FOR EVALUATION
===============================
"""


def eval_model(i, e, valid_datasets, test_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
    metrics_test, config_sample, model_path = eval_model_on_ds(i, e, test_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
    metrics_valid, _, _ = eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
    return {'mean_auc_test': metrics_test['mean_roc_at_1000'], 'mean_auc_valid': metrics_valid['mean_roc_at_1000'], 'mean_ce_test': metrics_test['mean_ce_at_1000'], 'mean_ce_valid': metrics_valid['mean_ce_at_1000'], 'config_sample': config_sample, 'model_path': model_path}

def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):

    # Example usage: eval_model_on_ds(i, 0, valid_datasets, [1024], 100000, add_name=model_string, base_path=base_path)
    def check_file(e):
        model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
        model_path = os.path.join(base_path, model_file)
        # print('Evaluate ', model_path)
        results_file = os.path.join(base_path,
                                    f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
        if not Path(model_path).is_file():  # or Path(results_file).is_file():
            # print('checkpoint exists: ', Path(model_file).is_file(), ', results are written:', Path(results_file).is_file())
            return None, None, None
        return model_file, model_path, results_file

    if e == -1:  # use the last available checkpoint if e == -1
        model_file, model_path, results_file = None, None, None
        for e_ in range(100, -1, -1):
            model_file_, model_path_, results_file_ = check_file(e_)
            if model_file_ is not None:
                e = e_
                model_file, model_path, results_file = model_file_, model_path_, results_file_
                break
    else:
        model_file, model_path, results_file = check_file(e)

    if model_file is None:
        raise FileNotFoundError(f'No checkpoint found for {add_name} (n_{i}, epoch {e}) in {base_path}')

    model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
    print(model[2].style_encoder)

    params = {'max_features': config_sample['num_features']
        , 'rescale_features': config_sample["normalize_by_used_features"]
        , 'normalize_to_ranking': config_sample["normalize_to_ranking"]
        , 'normalize_with_sqrt': config_sample.get("normalize_with_sqrt", False)
              }
    metrics_valid = evaluate(datasets=valid_datasets, model=model[2], method='transformer', device=device, overwrite=True,
                             extend_features=True
                             # the style keyword is omitted here; the transformer was trained with a style that is simply empty
                             , save=False
                             , metric_used=tabular_metrics.cross_entropy
                             , return_tensor=True
                             , verbose=False
                             , eval_positions=eval_positions
                             , bptt=bptt
                             , base_path=None
                             , inference_mode=True
                             , **params
                             , **kwargs)

    tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics_valid, valid_datasets, eval_positions)
    tabular_metrics.calculate_score_per_method(tabular_metrics.cross_entropy, 'ce', metrics_valid, valid_datasets, eval_positions)

    return metrics_valid, config_sample, model_path
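
# Usage sketch (illustrative only; `model_string`, `base_path`, `valid_datasets` and
# `test_datasets` are assumed to exist and follow the conventions above):
#
#     res = eval_model(i=0, e=-1,
#                      valid_datasets=valid_datasets, test_datasets=test_datasets,
#                      eval_positions=[1000], bptt=2000,
#                      add_name=model_string, base_path=base_path, device='cuda')
#     print(res['mean_auc_valid'], res['mean_auc_test'])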


def evaluate(datasets, bptt, eval_positions, metric_used, model
             , verbose=False
             , return_tensor=False
             , **kwargs):
    """
    Evaluates a list of datasets for a model function.

    :param datasets: List of datasets
    :param bptt: maximum sequence length
    :param eval_positions: List of positions where to evaluate models
    :param verbose: If True, is verbose.
    :param metric_used: Which metric is optimized for.
    :param return_tensor: Wheater to return results as a pytorch.tensor or numpy, this is only relevant for transformer.
    :param kwargs:
    :return:
    """
    overall_result = {'metric_used': get_scoring_string(metric_used)
                      , 'bptt': bptt
                      , 'eval_positions': eval_positions}

    aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0

    # For each dataset
    for [ds_name, X, y, categorical_feats, _, _] in (tqdm(datasets, desc='Iterate over datasets') if verbose else datasets):
        dataset_bptt = min(len(X), bptt)
        # if verbose and dataset_bptt < bptt:
        #    print(f'Dataset too small for given sequence length, reducing to {len(X)} ({bptt})')

        aggregated_metric, num = torch.tensor(0.0), 0
        ds_result = {}

        for eval_position in eval_positions:
            eval_position_real = int(dataset_bptt * 0.5) if 2 * eval_position > dataset_bptt else eval_position
            eval_position_bptt = int(eval_position_real * 2.0)

            r = evaluate_position(X, y, model=model
                        , num_classes=len(torch.unique(y))
                        , categorical_feats = categorical_feats
                        , bptt = eval_position_bptt
                        , ds_name=ds_name
                        , eval_position = eval_position_real
                        , metric_used = metric_used
                        ,**kwargs)

            if r is None:
                continue

            _, outputs, ys, best_configs, time_used = r

            if torch.is_tensor(outputs):
                ys = ys.to(outputs.device)  # move labels to the same device as the outputs

            ys = ys.T
            ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
            ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
            ds_result[f'{ds_name}_ys_at_{eval_position}'] = ys
            ds_result[f'{ds_name}_time_at_{eval_position}'] = time_used

            new_metric = torch_nanmean(torch.stack([metric_used(ys[i], outputs[i]) for i in range(ys.shape[0])]))

            if not return_tensor:
                make_scalar = lambda x: float(x.detach().cpu().numpy()) if (torch.is_tensor(x) and (len(x.shape) == 0)) else x
                new_metric = make_scalar(new_metric)
                ds_result = {k: make_scalar(ds_result[k]) for k in ds_result.keys()}

            lib = torch if return_tensor else np
            if not lib.isnan(new_metric).any():
                aggregated_metric, num = aggregated_metric + new_metric, num + 1

        overall_result.update(ds_result)
        if num > 0:
            aggregated_metric_datasets, num_datasets = (aggregated_metric_datasets + (aggregated_metric / num)), num_datasets + 1

    overall_result['mean_metric'] = aggregated_metric_datasets / num_datasets

    return overall_result
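
# Usage sketch (hypothetical names, not executed here): calling `evaluate` with a
# baseline method. `my_datasets` is assumed to follow the
# [ds_name, X, y, categorical_feats, ...] layout iterated above, and `clf_fn` to be a
# baseline understood by `baseline_predict`; the extra keyword arguments are passed
# through to `evaluate_position`.
#
#     results = evaluate(datasets=my_datasets, bptt=1024, eval_positions=[512],
#                        metric_used=tabular_metrics.auc_metric, model=clf_fn,
#                        method='my_baseline', overwrite=True, save=False,
#                        base_path=None, path_interfix='', max_time=300)
#     print(results['mean_metric'])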

"""
===============================
INTERNAL HELPER FUNCTIONS
===============================
"""

def check_file_exists(path):
    """Checks if a pickle file exists. Returns None if not, else returns the unpickled file."""
    if (os.path.isfile(path)):
        print(f'loading results from {path}')
        with open(path, 'rb') as f:
            return np.load(f, allow_pickle=True).tolist()
    return None

def generate_valid_split(X, y, bptt, eval_position, split_number=1):
    """Generates a deteministic train-(test/valid) split. Both splits must contain the same classes and all classes in
    the entire datasets. If no such split can be sampled in 7 passes, returns None.

    :param X: torch tensor, feature values
    :param y: torch tensor, class values
    :param bptt: Number of samples in train + test
    :param eval_position: Number of samples in train, i.e. from which index values are in test
    :param split_number: The split id
    :return:
    """
    done, seed = False, 13

    torch.manual_seed(split_number)
    perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
    X, y = X[perm], y[perm]

    while not done:
        if seed > 20:
            return None, None  # No valid split could be generated within 8 attempts
        random.seed(seed)
        i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
        y_ = y[i:i + bptt]

        # Check that the sampled window contains all classes of the dataset and that the
        # train and test halves contain the same set of classes.
        done = len(torch.unique(y_)) == len(torch.unique(y))
        done = done and torch.all(torch.unique(y_) == torch.unique(y))
        done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
        done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
        seed = seed + 1

    eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
    eval_ys = torch.stack([y[i:i + bptt].clone()], 1)

    return eval_xs, eval_ys
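
# Shape sketch (illustrative, assuming a synthetic dataset): the returned tensors are
# stacked with a singleton batch dimension, so `eval_xs` has shape (bptt, 1, num_features)
# and `eval_ys` has shape (bptt, 1); the first `eval_position` rows act as training data.
#
#     X_demo = torch.rand(200, 5)                    # 200 samples, 5 features
#     y_demo = torch.randint(0, 3, (200,)).float()   # 3 classes
#     xs, ys = generate_valid_split(X_demo, y_demo, bptt=128, eval_position=64)
#     # xs.shape == (128, 1, 5), ys.shape == (128, 1) if a valid split was found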


def evaluate_position(X, y, categorical_feats, model, bptt
                      , eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
                      , max_time=300, split_number=1
                      , per_step_normalization=False, **kwargs):
    """
    Evaluates a dataset with a 'bptt' number of training samples.

    :param X: Dataset X
    :param y: Dataset labels
    :param categorical_feats: Indices of categorical features.
    :param model: Model function
    :param bptt: Sequence length.
    :param eval_position: Number of training samples.
    :param overwrite: Wheater to ove
    :param overwrite: If True, results on disk are overwritten.
    :param save:
    :param path_interfix: Used for constructing path to write on disk.
    :param method: Model name.
    :param ds_name: Datset name.
    :param fetch_only: Wheater to calculate or only fetch results.
    :param per_step_normalization:
    :param kwargs:
    :return:
    """

    if save:
        path = os.path.join(base_path, f'results/tabular/{path_interfix}/results_{method}_{ds_name}_{eval_position}_{bptt}_{split_number}.npy')

    ## Load results if on disk
    if not overwrite:
        result = check_file_exists(path)
        if result is not None:
            if not fetch_only:
                print(f'Loaded saved result for {path}')
            return result
        elif fetch_only:
            print(f'Could not load saved result for {path}')
            return None

    ## Generate data splits
    eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position, split_number=split_number)
    if eval_xs is None:
        print(f"No dataset split could be generated for {ds_name} (bptt={bptt})")
        return None

    # Remap arbitrary class values to contiguous integer labels 0..n_classes-1
    eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)

    start_time = time.time()

    if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
        outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, categorical_feats=categorical_feats, **kwargs), None
    else:
        _, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
                                                    , eval_pos=eval_position
                                                    , max_time=max_time, **kwargs)

    eval_ys = eval_ys[eval_position:]
    if outputs is None:
        return None

    if torch.is_tensor(outputs): # Transfers data to cpu for saving
        outputs = outputs.cpu()
        eval_ys = eval_ys.cpu()

    ds_result = None, outputs, eval_ys, best_configs, time.time() - start_time

    if save:
        with open(path, 'wb') as f:
            np.save(f, ds_result)
            print(f'saved results to {path}')

    return ds_result
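
# Usage sketch (hypothetical names, not executed here): evaluating a single dataset
# directly, assuming `trained_model` is the nn.Module returned by load_model(...)[2]
# and X, y are torch tensors. The first `eval_position` samples of the drawn window are
# used for training, the remaining `bptt - eval_position` for evaluation.
#
#     r = evaluate_position(X, y, categorical_feats=[], model=trained_model, bptt=2000,
#                           eval_position=1000, overwrite=True, save=False,
#                           base_path=None, path_interfix='', method='transformer',
#                           ds_name='my_dataset')
#     if r is not None:
#         _, outputs, ys, best_configs, elapsed = r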