|
""" |
|
|
|
USEMetric class: |
|
------------------------------------------------------- |
|
Class for calculating the average Universal Sentence Encoder (USE) similarity
between the original and perturbed texts of successful AttackResults.
|
|
|
""" |
|
|
|
from textattack.attack_results import FailedAttackResult, SkippedAttackResult |
|
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder |
|
from textattack.metrics import Metric |
|
|
|
|
|
class USEMetric(Metric):
    """Average USE similarity between original and perturbed texts.

    Only successful attacks contribute; ``FailedAttackResult`` and
    ``SkippedAttackResult`` instances are ignored.
    """

    def __init__(self, **kwargs):
        # ``**kwargs`` is accepted for a uniform metric constructor interface;
        # it is not used here.
        self.use_obj = UniversalSentenceEncoder()
        # NOTE(review): a second encoder is installed as the first encoder's
        # ``model`` attribute, presumably so that ``_sim_score`` can delegate
        # encoding to it -- confirm against SentenceEncoder before changing.
        self.use_obj.model = UniversalSentenceEncoder()
        # Texts gathered by the most recent ``calculate`` call.
        self.original_candidates = []
        self.successful_candidates = []
        # Metric name -> value; populated by ``calculate``.
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates average USE similarity on all successful attacks.

        Args:
            results (list of ``AttackResult`` objects):
                Attack results for each instance in dataset. Failed and
                skipped results are ignored.

        Returns:
            dict: ``self.all_metrics`` with ``"avg_attack_use_score"`` set to
            the mean similarity rounded to 2 decimals, or ``0.0`` when no
            attack succeeded (instead of raising ``ZeroDivisionError``).

        Example::

            >> import textattack
            >> import transformers
            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
            >> attack_args = textattack.AttackArgs(
                num_examples=1,
                log_to_csv="log.csv",
                checkpoint_interval=5,
                checkpoint_dir="checkpoints",
                disable_stdout=True
            )
            >> attacker = textattack.Attacker(attack, dataset, attack_args)
            >> results = attacker.attack_dataset()
            >> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
        """
        self.results = results

        # Reset per-call state so repeated calls on the same instance do not
        # re-score (and average in) candidates from earlier result sets.
        self.original_candidates = []
        self.successful_candidates = []

        for result in self.results:
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(result.original_result.attacked_text)
            self.successful_candidates.append(result.perturbed_result.attacked_text)

        use_scores = [
            self.use_obj._sim_score(original, perturbed).item()
            for original, perturbed in zip(
                self.original_candidates, self.successful_candidates
            )
        ]

        # An average over zero successful attacks is undefined; report 0.0
        # rather than crashing with ZeroDivisionError.
        if use_scores:
            avg_score = round(sum(use_scores) / len(use_scores), 2)
        else:
            avg_score = 0.0
        self.all_metrics["avg_attack_use_score"] = avg_score

        return self.all_metrics
|
|