from abc import abstractmethod
from typing import Dict, Union

from .dataclass import NonPositionalField
from .formats import Format
from .fusion import FixedFusion, WeightedFusion
from .operator import SourceOperator
from .standard import StandardRecipe
from .stream import MultiStream
from .system_prompts import SystemPrompt


class BaseBenchmark(SourceOperator):
    """Base class for benchmark sources.

    Declares four optional settings (``format``, ``num_demos``,
    ``system_prompt`` and ``loader_limit``), each defaulting to ``None``,
    and requires subclasses to implement :meth:`reset`.
    """

    format: Format = NonPositionalField(default=None)
    num_demos: int = NonPositionalField(default=None)
    system_prompt: SystemPrompt = NonPositionalField(default=None)
    loader_limit: int = NonPositionalField(default=None)

    @abstractmethod
    def reset(self):
        """Re-initialize the benchmark; concrete subclasses must override."""


class Benchmark(BaseBenchmark):
    """Benchmark made of named subsets that are fused into one source.

    Each subset is either a ``StandardRecipe`` or another ``BaseBenchmark``.
    """

    subsets: Dict[str, Union[StandardRecipe, BaseBenchmark]]

    # Mutually exclusive limits; ``verify`` rejects setting both.
    max_total_samples: int = None
    max_samples_per_subset: int = None

    def verify(self):
        """Run the parent verification, then check the sample limits.

        Raises:
            ValueError: if both ``max_total_samples`` and
                ``max_samples_per_subset`` are set.
        """
        super().verify()
        if self.max_total_samples is None or self.max_samples_per_subset is None:
            return
        raise ValueError("Set either max_total_samples or max_samples_per_subset")

    def prepare_args(self):
        """Replace ``subsets`` with a plain ``dict`` built from it."""
        self.subsets = dict(self.subsets)

    def reset(self):
        """Push this benchmark's non-``None`` settings down to every subset.

        For each subset, the settings that are set on this benchmark are
        copied onto it and the subset's own ``reset`` is then called. When
        none of the four settings is set, the subsets are left untouched and
        their ``reset`` is not called.
        """
        # Same assignment order as the settings are applied to each subset.
        setting_names = ("num_demos", "format", "system_prompt", "loader_limit")
        overrides = [
            (name, getattr(self, name))
            for name in setting_names
            if getattr(self, name) is not None
        ]
        if not overrides:
            return

        for subset in self.subsets.values():
            for name, value in overrides:
                setattr(subset, name, value)
            subset.reset()

    def prepare(self):
        """Run the parent preparation, then propagate settings via ``reset``."""
        super().prepare()
        self.reset()

    def process(self) -> MultiStream:
        """Fuse the subsets and return the result of calling the fusion operator.

        ``WeightedFusion`` is used when ``max_total_samples`` is set;
        otherwise ``FixedFusion`` is used with ``max_samples_per_subset``
        (which may itself be ``None``).
        """
        if self.max_total_samples is not None:
            fusion = WeightedFusion(
                subsets=self.subsets,
                max_total_samples=self.max_total_samples,
            )
            return fusion()

        fusion = FixedFusion(
            subsets=self.subsets,
            max_instances_per_subset=self.max_samples_per_subset,
        )
        return fusion()