"""
Metrics on AttackQueries
---------------------------------------------------------------------

"""

import numpy as np

from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric


class AttackQueries(Metric):
    """Metric reporting how many model queries attacks needed.

    Skipped results are excluded, so the average covers only
    attempted attacks.
    """

    def __init__(self):
        # Maps metric name -> value; filled in by ``calculate``.
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to number of queries in an attack.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics``, containing the key
            ``"avg_num_queries"``.
        """
        self.results = results
        # Skipped results did not run an attack, so their query counts
        # would distort the average.
        self.num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )

        self.all_metrics["avg_num_queries"] = self.avg_num_queries()

        return self.all_metrics

    def avg_num_queries(self):
        """Return the mean number of queries, rounded to 2 decimal places.

        Returns ``0.0`` when there are no non-skipped results. Without
        this guard, ``np.mean`` on an empty array would emit a
        ``RuntimeWarning`` and produce ``nan``.
        """
        if self.num_queries.size == 0:
            return 0.0
        # Convert the NumPy scalar to a plain float before rounding.
        return round(float(self.num_queries.mean()), 2)