"""

EvalModelCommand class
==============================

"""


from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import dataclass

import scipy.stats
import torch

import textattack
from textattack import DatasetArgs, ModelArgs
from textattack.commands import TextAttackCommand
from textattack.model_args import HUGGINGFACE_MODELS, TEXTATTACK_MODELS

logger = textattack.shared.utils.logger


def _cb(s):
    return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")


@dataclass
class ModelEvalArgs(ModelArgs, DatasetArgs):
    """Arguments for model evaluation, mirroring the flags of the ``eval``
    subcommand."""

    random_seed: int = 765
    batch_size: int = 32
    num_examples: int = 5
    num_examples_offset: int = 0


class EvalModelCommand(TextAttackCommand):
    """The TextAttack model benchmarking module:

    A command line parser to evaluatate a model from user
    specifications.
    """

    def get_preds(self, model, inputs):
        """Return the model's predictions for ``inputs``, computed with
        gradient tracking disabled."""
        with torch.no_grad():
            preds = textattack.shared.utils.batch_model_predict(model, inputs)
        return preds

    def test_model_on_dataset(self, args):
        model = ModelArgs._create_model_from_args(args)
        dataset = DatasetArgs._create_dataset_from_args(args)
        if args.num_examples == -1:
            args.num_examples = len(dataset)

        preds = []
        ground_truth_outputs = []
        # Honor --num-examples-offset: start at the requested offset and stop
        # after `num_examples` examples (or at the end of the dataset).
        i = args.num_examples_offset
        last_example = min(args.num_examples_offset + args.num_examples, len(dataset))
        while i < last_example:
            dataset_batch = dataset[i : min(last_example, i + args.batch_size)]
            batch_inputs = []
            for text_input, ground_truth_output in dataset_batch:
                attacked_text = textattack.shared.AttackedText(text_input)
                batch_inputs.append(attacked_text.tokenizer_input)
                ground_truth_outputs.append(ground_truth_output)
            batch_preds = model(batch_inputs)

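            # Model wrappers may return lists or numpy arrays; normalize to a
            # torch.Tensor before accumulating predictions.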
            if not isinstance(batch_preds, torch.Tensor):
                batch_preds = torch.Tensor(batch_preds)

            preds.extend(batch_preds)
            i += args.batch_size

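        # Collapse the per-batch predictions into a single CPU tensor;
        # squeeze() drops the singleton logit dimension produced by
        # single-output (regression) models.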
        preds = torch.stack(preds).squeeze().cpu()
        ground_truth_outputs = torch.tensor(ground_truth_outputs).cpu()

        logger.info(f"Got {len(preds)} predictions.")

        if preds.ndim == 1:
            # if preds is just a list of numbers, assume regression for now
            # TODO integrate with `textattack.metrics` package
            pearson_correlation, _ = scipy.stats.pearsonr(ground_truth_outputs, preds)
            spearman_correlation, _ = scipy.stats.spearmanr(ground_truth_outputs, preds)

            logger.info(f"Pearson correlation = {_cb(pearson_correlation)}")
            logger.info(f"Spearman correlation = {_cb(spearman_correlation)}")
        else:
            guess_labels = preds.argmax(dim=1)
            successes = (guess_labels == ground_truth_outputs).sum().item()
            perc_accuracy = successes / len(preds) * 100.0
            perc_accuracy = "{:.2f}%".format(perc_accuracy)
            logger.info(f"Correct {successes}/{len(preds)} ({_cb(perc_accuracy)})")

    def run(self, args):
        args = ModelEvalArgs(**vars(args))
        textattack.shared.utils.set_seed(args.random_seed)

        # Default to 'all' if no model chosen.
        if not (args.model or args.model_from_huggingface or args.model_from_file):
            for model_name in list(HUGGINGFACE_MODELS.keys()) + list(
                TEXTATTACK_MODELS.keys()
            ):
                args.model = model_name
                self.test_model_on_dataset(args)
                logger.info("-" * 50)
        else:
            self.test_model_on_dataset(args)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        # Despite the annotation, `main_parser` is the sub-parsers action
        # returned by `ArgumentParser.add_subparsers()`, which is what
        # provides `add_parser`.
        parser = main_parser.add_parser(
            "eval",
            help="evaluate a model with TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )

        parser = ModelArgs._add_parser_args(parser)
        parser = DatasetArgs._add_parser_args(parser)

        parser.add_argument("--random-seed", default=765, type=int)
        parser.add_argument(
            "--batch-size",
            type=int,
            default=32,
            help="The batch size for evaluating the model.",
        )
        parser.add_argument(
            "--num-examples",
            "-n",
            type=int,
            required=False,
            default=5,
            help="The number of examples to process, -1 for entire dataset",
        )
        parser.add_argument(
            "--num-examples-offset",
            "-o",
            type=int,
            required=False,
            default=0,
            help="The offset to start at in the dataset.",
        )

        parser.set_defaults(func=EvalModelCommand())
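

# A minimal programmatic sketch of invoking this command without the full
# TextAttack CLI. The model shortcut "lstm-mr" is illustrative and assumed to
# be registered in TEXTATTACK_MODELS. The `func` attribute set by
# `set_defaults` must be stripped before `run()` expands the namespace into
# the ModelEvalArgs dataclass, which has no `func` field.
if __name__ == "__main__":
    cli_parser = ArgumentParser()
    subcommands = cli_parser.add_subparsers(help="textattack commands")
    EvalModelCommand.register_subcommand(subcommands)
    cli_args = cli_parser.parse_args(
        ["eval", "--model", "lstm-mr", "--num-examples", "10"]
    )
    command = cli_args.func
    del cli_args.func  # ModelEvalArgs(**vars(args)) would reject this key
    command.run(cli_args)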