File size: 5,494 Bytes
85e172b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177


import os
import sys
import argparse

import numpy as np

from tqdm import tqdm

import torch
device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu')

# load utility functions
from evaluation import eval_utils

from util import utils
from util import evaluation


def calculate_t3_intrinsic_dims(
        model_name,
        model,
        tok,
        hparams,
        edit_mode,
        theta,
        num_aug,
        layers,
        save_path,
        output_path,
        augmented_cache = None,
        cache_features = False,
    ):
    """ Theorem 3 intrinsic dimensionality of augmented prompt features for multiple samples.

    Scans `save_path` for per-sample edit pickles, computes per-layer features,
    intrinsic dimensionalities and false-positive rates for each sample, and
    writes one result pickle per case id into `output_path`. Samples whose
    output file already exists are skipped, so the function is resumable.

    Args:
        model_name: name of the model (used by FPR calculation).
        model, tok: loaded model and tokenizer.
        hparams: hyperparameter dict; must contain key 'activation'.
        edit_mode: mode of edit/attack passed through to feature extraction.
        theta: threshold for the intrinsic-dimension calculation.
        num_aug: number of augmentations per sample.
        layers: iterable of layer indices to evaluate.
        save_path: directory containing the input sample pickles.
        output_path: directory where result pickles are written.
        augmented_cache: optional cache of augmentations (passed through).
        cache_features: if True, also store raw features/masks in the output.
    """
    # load activation function
    activation = utils.load_activation(hparams['activation'])

    # find unique pickle files (dedup by basename, ignore perplexity dumps)
    pickle_paths = np.array([
        f for f in utils.path_all_files(save_path)
        if f.endswith('.pickle') and ('perplexity' not in f)
    ])
    _, unique_indices = np.unique(
        np.array([os.path.basename(f) for f in pickle_paths]), return_index=True)

    pickle_paths = pickle_paths[unique_indices]
    pickle_paths = utils.shuffle_list(pickle_paths)
    print('Number of pickle files:', len(pickle_paths))

    for pickle_path in tqdm(pickle_paths):

        # initialized before the try so the error message below is always safe,
        # even when loading the pickle itself fails
        case_id = None

        try:

            # find sample file
            edit_contents = utils.loadpickle(pickle_path)
            case_id = edit_contents['case_id']

            # skip samples that already have results (makes reruns cheap)
            output_file = os.path.join(output_path, f'{case_id}.pickle')
            if os.path.exists(output_file):
                print('Already exists:', output_file)
                continue

            # extract features and calculate intrinsic dims
            layer_features, layer_masks, intrinsic_dims = eval_utils.sample_t3_intrinsic_dims(
                model,
                tok,
                hparams,
                layers = layers,
                request = edit_contents['request'],
                edit_mode = edit_mode,
                num_aug = num_aug,
                theta = theta,
                augmented_cache = augmented_cache,
                verbose = False
            )

            # calculate false positive rates
            fpr_raw, fpr_ftd = eval_utils.calculate_fpr(
                model_name,
                layers,
                save_path,
                case_id,
                activation,
                layer_features,
                layer_masks,
                num_aug
            )

            # save results
            to_save = {
                'intrinsic_dims': intrinsic_dims,
                'layer_indices': layers,
                'fpr_raw': fpr_raw,
                'fpr_ftd': fpr_ftd,
                'num_aug': num_aug,
                # number of augmentations surviving filtering, per layer
                'num_filtered': [np.sum(layer_masks[l]) for l in layers],
            }

            if cache_features:
                to_save['layer_features'] = layer_features
                to_save['layer_masks'] = layer_masks

            utils.savepickle(output_file, to_save)

        # NOTE: was a bare `except:` that swallowed KeyboardInterrupt/SystemExit
        # and discarded the exception; keep the best-effort skip, but report it
        except Exception as e:
            print(f'Error processing {pickle_path} (case_id={case_id}): {e}')
            continue

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'],
        help='dataset for evaluation')

    # BUG FIX: the old default 'in-place' was not listed in choices. argparse
    # does not validate defaults, so omitting --edit_mode silently produced a
    # value that could never be passed explicitly. 'in-place' is added to
    # choices so the default stays unchanged but is now a legal, explicit value.
    parser.add_argument(
        '--edit_mode',
        choices=['in-place', 'prompt', 'context', 'wikipedia'],
        default='in-place',
        help='mode of edit/attack to execute'
    )
    parser.add_argument(
        '--num_aug', default=2000, type=int,
        help='number of augmentations per sample')
    parser.add_argument(
        '--static_context', type=str, default=None,
        help='optional static context; overrides hparams["static_context"]')
    parser.add_argument(
        '--augmented_cache', type=str, default=None,
        help='optional cache of augmented prompts (passed to feature extraction)')

    parser.add_argument(
        '--theta', default=0.005, type=float,
        help='theta for intrinsic dim calculation')

    parser.add_argument(
        '--cache_features', default=0, type=int,
        help='boolean switch to cache features (0 or 1)')

    parser.add_argument(
        '--save_path', type=str, default='./results/tmp/',
        help='directory containing per-sample input pickles')
    parser.add_argument(
        '--output_path', type=str, default='./results/dimensionality/',
        help='directory where result pickles are written')

    args = parser.parse_args()

    # boolean arguments (argparse has no clean bool type; 0/1 is the convention here)
    args.cache_features = bool(args.cache_features)

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    if args.static_context is not None:
        hparams['static_context'] = args.static_context

    # ensure results path exists
    args.save_path = os.path.join(args.save_path, f'{args.dataset}/{args.model}/')
    args.output_path = os.path.join(args.output_path, f'{args.edit_mode}/{args.dataset}/{args.model}/')
    utils.assure_path_exists(args.output_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(model_name=args.model)

    # calculate intrinsic dims
    calculate_t3_intrinsic_dims(
        args.model,
        model,
        tok,
        hparams,
        edit_mode = args.edit_mode,
        theta = args.theta,
        num_aug = args.num_aug,
        layers = evaluation.model_layer_indices[args.model],
        save_path = args.save_path,
        output_path = args.output_path,
        augmented_cache = args.augmented_cache,
        cache_features = args.cache_features
    )