""" | |
EvalModelCommand class | |
============================== | |
""" | |

from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import dataclass

import scipy.stats
import torch

import textattack
from textattack import DatasetArgs, ModelArgs
from textattack.commands import TextAttackCommand
from textattack.model_args import HUGGINGFACE_MODELS, TEXTATTACK_MODELS

logger = textattack.shared.utils.logger


def _cb(s):
    return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")


@dataclass
class ModelEvalArgs(ModelArgs, DatasetArgs):
    random_seed: int = 765
    batch_size: int = 32
    num_examples: int = 5
    num_examples_offset: int = 0


class EvalModelCommand(TextAttackCommand):
    """The TextAttack model benchmarking module:

    A command line parser to evaluate a model from user specifications.
    """

    def get_preds(self, model, inputs):
        with torch.no_grad():
            preds = textattack.shared.utils.batch_model_predict(model, inputs)
        return preds
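
    # Note: ``test_model_on_dataset`` below calls ``model(batch_inputs)``
    # directly; this helper wraps ``batch_model_predict`` in ``torch.no_grad()``
    # for gradient-free batched inference.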

    def test_model_on_dataset(self, args):
        model = ModelArgs._create_model_from_args(args)
        dataset = DatasetArgs._create_dataset_from_args(args)
        if args.num_examples == -1:
            args.num_examples = len(dataset)

        preds = []
        ground_truth_outputs = []
        i = 0
        while i < min(args.num_examples, len(dataset)):
            # Take the next batch without reading past ``num_examples`` or the
            # end of the dataset.
            dataset_batch = dataset[i : min(args.num_examples, i + args.batch_size)]
            batch_inputs = []
            for text_input, ground_truth_output in dataset_batch:
                attacked_text = textattack.shared.AttackedText(text_input)
                batch_inputs.append(attacked_text.tokenizer_input)
                ground_truth_outputs.append(ground_truth_output)
            batch_preds = model(batch_inputs)

            if not isinstance(batch_preds, torch.Tensor):
                batch_preds = torch.Tensor(batch_preds)

            preds.extend(batch_preds)
            i += args.batch_size

        preds = torch.stack(preds).squeeze().cpu()
        ground_truth_outputs = torch.tensor(ground_truth_outputs).cpu()

        logger.info(f"Got {len(preds)} predictions.")

        if preds.ndim == 1:
            # if preds is just a list of numbers, assume regression for now
            # TODO integrate with `textattack.metrics` package
            pearson_correlation, _ = scipy.stats.pearsonr(ground_truth_outputs, preds)
            spearman_correlation, _ = scipy.stats.spearmanr(ground_truth_outputs, preds)

            logger.info(f"Pearson correlation = {_cb(pearson_correlation)}")
            logger.info(f"Spearman correlation = {_cb(spearman_correlation)}")
        else:
            guess_labels = preds.argmax(dim=1)
            successes = (guess_labels == ground_truth_outputs).sum().item()
            perc_accuracy = successes / len(preds) * 100.0
            perc_accuracy = "{:.2f}%".format(perc_accuracy)
            logger.info(f"Correct {successes}/{len(preds)} ({_cb(perc_accuracy)})")

    def run(self, args):
        args = ModelEvalArgs(**vars(args))
        textattack.shared.utils.set_seed(args.random_seed)

        # Default to 'all' if no model chosen.
        if not (args.model or args.model_from_huggingface or args.model_from_file):
            for model_name in list(HUGGINGFACE_MODELS.keys()) + list(
                TEXTATTACK_MODELS.keys()
            ):
                args.model = model_name
                self.test_model_on_dataset(args)
                logger.info("-" * 50)
        else:
            self.test_model_on_dataset(args)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "eval",
            help="evaluate a model with TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser = ModelArgs._add_parser_args(parser)
        parser = DatasetArgs._add_parser_args(parser)

        parser.add_argument("--random-seed", default=765, type=int)
        parser.add_argument(
            "--batch-size",
            type=int,
            default=32,
            help="The batch size for evaluating the model.",
        )
        parser.add_argument(
            "--num-examples",
            "-n",
            type=int,
            required=False,
            default=5,
            help="The number of examples to process, -1 for entire dataset",
        )
        parser.add_argument(
            "--num-examples-offset",
            "-o",
            type=int,
            required=False,
            default=0,
            help="The offset to start at in the dataset.",
        )
        parser.set_defaults(func=EvalModelCommand())
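
# Example invocation (a sketch, assuming the standard ``textattack`` CLI entry
# point and the ``--model`` / dataset flags contributed by ModelArgs and
# DatasetArgs; ``<model-name>`` is a placeholder):
#
#     textattack eval --model <model-name> --num-examples 100 --batch-size 64
#
# Passing ``--num-examples -1`` evaluates the entire dataset; omitting the
# model flags evaluates every built-in HuggingFace and TextAttack model in turn.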