|
""" |
|
|
|
Metrics on AttackQueries |
|
--------------------------------------------------------------------- |
|
|
|
""" |
|
|
|
import numpy as np |
|
|
|
from textattack.attack_results import SkippedAttackResult |
|
from textattack.metrics import Metric |
|
|
|
|
|
class AttackQueries(Metric):
    """Metric summarizing how many queries the attacks used.

    Results that are ``SkippedAttackResult`` instances are excluded from
    every statistic computed here.
    """

    def __init__(self):
        # Metric name -> value; filled in and returned by ``calculate``.
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to number of queries in an attack.

        Args:
            results (list of ``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with ``"avg_num_queries"`` set to the
            mean ``num_queries`` over non-skipped results, rounded to two
            decimals (``0.0`` when there are no non-skipped results).
        """
        self.results = results
        # Skipped results never ran an attack, so their query counts are
        # left out of the statistics.
        self.num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        self.all_metrics["avg_num_queries"] = self.avg_num_queries()

        return self.all_metrics

    def avg_num_queries(self):
        """Return the mean of ``self.num_queries`` rounded to two decimals.

        Relies on ``self.num_queries`` having been populated by
        ``calculate``.

        Returns:
            float: the rounded mean, or ``0.0`` when there are no
            non-skipped results (instead of NumPy's ``nan`` plus a
            "Mean of empty slice" RuntimeWarning).
        """
        if self.num_queries.size == 0:
            return 0.0
        # Convert the NumPy scalar to a builtin float so the metrics dict
        # holds plain Python values for printing/serialization.
        return round(float(self.num_queries.mean()), 2)
|
|