# Source: anonymous8/RPD-Demo — "initial commit" (4943752), 1.05 kB
"""
Metrics on AttackQueries
---------------------------------------------------------------------
"""
import numpy as np
from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric
class AttackQueries(Metric):
    """Metric reporting the average number of model queries per attack.

    Results that are ``SkippedAttackResult`` instances are excluded from the
    average.
    """

    def __init__(self):
        self.all_metrics = {}
        # Initialised empty so that ``avg_num_queries`` is well-defined even
        # if it is called before ``calculate``.
        self.results = []
        self.num_queries = np.array([])

    def calculate(self, results):
        """Calculates all metrics related to number of queries in an attack.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with the ``"avg_num_queries"`` entry
            set to the mean ``num_queries`` over non-skipped results
            (0.0 when there are none).
        """
        self.results = results
        # Skipped results are left out of the query statistics.
        self.num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        self.all_metrics["avg_num_queries"] = self.avg_num_queries()
        return self.all_metrics

    def avg_num_queries(self):
        """Return the mean of ``self.num_queries`` rounded to 2 decimals.

        Returns:
            float: the rounded mean, or 0.0 when no non-skipped results were
            recorded.
        """
        # ``mean()`` of an empty array emits a RuntimeWarning and yields nan,
        # which would then leak into the reported metrics; report 0.0 instead.
        if self.num_queries.size == 0:
            return 0.0
        return round(self.num_queries.mean(), 2)