File size: 1,054 Bytes
4943752 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
"""
Metrics on AttackQueries
---------------------------------------------------------------------
"""
import numpy as np
from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric
class AttackQueries(Metric):
    """Metric summarizing how many queries an attack used.

    Results that are ``SkippedAttackResult`` instances are excluded from
    the statistics.
    """

    def __init__(self):
        # Metric name -> value; populated by ``calculate``.
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to number of queries in an attack.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with its ``"avg_num_queries"`` entry
            set to the value of ``avg_num_queries()``.
        """
        self.results = results
        # Skipped results carry no meaningful query count, so drop them.
        self.num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        self.all_metrics["avg_num_queries"] = self.avg_num_queries()
        return self.all_metrics

    def avg_num_queries(self):
        """Return the mean number of queries, rounded to 2 decimal places.

        Reads ``self.num_queries``, which is set by ``calculate``; call that
        first.

        Returns:
            The rounded mean. When there is nothing to average (e.g. every
            result was skipped), returns NaN -- the same value ``mean()``
            of an empty array yields, but without NumPy's
            "Mean of empty slice" RuntimeWarning.
        """
        if self.num_queries.size == 0:
            return float("nan")
        return round(self.num_queries.mean(), 2)
|