|
""" |
|
|
|
Metrics on perturbed words |
|
--------------------------------------------------------------------- |
|
|
|
""" |
|
|
|
import numpy as np |
|
|
|
from textattack.attack_results import FailedAttackResult, SkippedAttackResult |
|
from textattack.metrics import Metric |
|
|
|
|
|
class WordsPerturbed(Metric): |
|
def __init__(self): |
|
self.total_attacks = 0 |
|
self.all_num_words = None |
|
self.perturbed_word_percentages = None |
|
self.num_words_changed_until_success = 0 |
|
self.all_metrics = {} |
|
|
|
def calculate(self, results): |
|
"""Calculates all metrics related to perturbed words in an attack. |
|
|
|
Args: |
|
results (``AttackResult`` objects): |
|
Attack results for each instance in dataset |
|
""" |
|
|
|
self.results = results |
|
self.total_attacks = len(self.results) |
|
self.all_num_words = np.zeros(len(self.results)) |
|
self.perturbed_word_percentages = np.zeros(len(self.results)) |
|
self.num_words_changed_until_success = np.zeros(2**16) |
|
self.max_words_changed = 0 |
|
|
|
for i, result in enumerate(self.results): |
|
self.all_num_words[i] = len(result.original_result.attacked_text.words) |
|
|
|
if isinstance(result, FailedAttackResult) or isinstance( |
|
result, SkippedAttackResult |
|
): |
|
continue |
|
|
|
num_words_changed = len( |
|
result.original_result.attacked_text.all_words_diff( |
|
result.perturbed_result.attacked_text |
|
) |
|
) |
|
self.num_words_changed_until_success[num_words_changed - 1] += 1 |
|
self.max_words_changed = max( |
|
self.max_words_changed or num_words_changed, num_words_changed |
|
) |
|
if len(result.original_result.attacked_text.words) > 0: |
|
perturbed_word_percentage = ( |
|
num_words_changed |
|
* 100.0 |
|
/ len(result.original_result.attacked_text.words) |
|
) |
|
else: |
|
perturbed_word_percentage = 0 |
|
|
|
self.perturbed_word_percentages[i] = perturbed_word_percentage |
|
|
|
self.all_metrics["avg_word_perturbed"] = self.avg_number_word_perturbed_num() |
|
self.all_metrics["avg_word_perturbed_perc"] = self.avg_perturbation_perc() |
|
self.all_metrics["max_words_changed"] = self.max_words_changed |
|
self.all_metrics[ |
|
"num_words_changed_until_success" |
|
] = self.num_words_changed_until_success |
|
|
|
return self.all_metrics |
|
|
|
def avg_number_word_perturbed_num(self): |
|
average_num_words = self.all_num_words.mean() |
|
average_num_words = round(average_num_words, 2) |
|
return average_num_words |
|
|
|
def avg_perturbation_perc(self): |
|
self.perturbed_word_percentages = self.perturbed_word_percentages[ |
|
self.perturbed_word_percentages > 0 |
|
] |
|
average_perc_words_perturbed = self.perturbed_word_percentages.mean() |
|
average_perc_words_perturbed = round(average_perc_words_perturbed, 2) |
|
return average_perc_words_perturbed |
|
|