File size: 2,368 Bytes
fe70438
d08fbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe70438
 
 
 
d08fbc6
 
 
 
 
 
 
 
fe70438
d08fbc6
 
 
 
 
 
88c61d3
 
 
fe70438
 
 
 
 
 
 
d08fbc6
 
 
 
 
 
 
 
 
fe70438
 
 
 
 
 
 
d08fbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from abc import abstractmethod
from typing import Dict, Union

from .dataclass import NonPositionalField
from .formats import Format
from .fusion import FixedFusion, WeightedFusion
from .operator import SourceOperator
from .standard import StandardRecipe
from .stream import MultiStream
from .system_prompts import SystemPrompt


class BaseBenchmark(SourceOperator):
    """Common base for benchmarks: a source operator carrying optional overrides.

    Every override field below defaults to ``None``. In ``Benchmark.reset`` a
    ``None`` value means "leave each subset's own setting alone", while a
    non-None value is copied onto every subset. Concrete subclasses must
    implement ``reset``.
    """

    # Optional Format to impose on subsets; None keeps each subset's own.
    format: Format = NonPositionalField(default=None)
    # Optional demo count to impose on subsets (presumably the number of
    # in-context demonstrations -- confirm against StandardRecipe).
    num_demos: int = NonPositionalField(default=None)
    # Optional SystemPrompt to impose on subsets.
    system_prompt: SystemPrompt = NonPositionalField(default=None)
    # Optional loader limit to impose on subsets (presumably caps how many
    # instances the loader reads -- confirm).
    loader_limit: int = NonPositionalField(default=None)

    @abstractmethod
    def reset(self):
        """(Re)apply this benchmark's overrides; subclasses must implement.

        NOTE(review): @abstractmethod is only enforced when SourceOperator's
        metaclass derives from abc.ABCMeta -- confirm.
        """
        pass


class Benchmark(BaseBenchmark):
    """A benchmark built from named subsets (recipes or nested benchmarks).

    During ``prepare`` every non-None override inherited from
    ``BaseBenchmark`` (``num_demos``, ``format``, ``system_prompt``,
    ``loader_limit``) is copied onto each subset, which is then reset.
    ``process`` fuses the subsets into one MultiStream, using either a global
    cap (``max_total_samples``) or a per-subset cap
    (``max_samples_per_subset``); setting both is rejected by ``verify``.
    """

    subsets: Dict[str, Union[StandardRecipe, BaseBenchmark]]

    max_total_samples: int = None
    max_samples_per_subset: int = None

    def verify(self):
        """Validate the configuration; the two sampling caps are exclusive.

        Raises:
            ValueError: if both ``max_total_samples`` and
                ``max_samples_per_subset`` are set.
        """
        super().verify()
        has_total_cap = self.max_total_samples is not None
        has_per_subset_cap = self.max_samples_per_subset is not None
        if has_total_cap and has_per_subset_cap:
            raise ValueError("Set either max_total_samples or max_samples_per_subset")

    def prepare_args(self):
        """Replace ``subsets`` with a shallow copy held in a plain dict."""
        self.subsets = dict(self.subsets)

    def reset(self):
        """Copy the non-None overrides onto every subset, then reset each one.

        When no override is set at all, subsets are left untouched and are
        not reset.
        """
        # Order matters only for mirroring the original assignment sequence.
        pending = [
            (attr_name, attr_value)
            for attr_name, attr_value in (
                ("num_demos", self.num_demos),
                ("format", self.format),
                ("system_prompt", self.system_prompt),
                ("loader_limit", self.loader_limit),
            )
            if attr_value is not None
        ]
        if not pending:
            return

        for child in self.subsets.values():
            for attr_name, attr_value in pending:
                setattr(child, attr_name, attr_value)
            child.reset()

    def prepare(self):
        """Run the base preparation, then push overrides down to subsets."""
        super().prepare()
        self.reset()

    def process(self) -> MultiStream:
        """Fuse all subsets into a single MultiStream.

        Uses ``WeightedFusion`` bounded by ``max_total_samples`` when that cap
        is set; otherwise ``FixedFusion`` with ``max_samples_per_subset`` as
        the per-subset limit (which may be None).
        """
        if self.max_total_samples is not None:
            fusion = WeightedFusion(
                subsets=self.subsets, max_total_samples=self.max_total_samples
            )
        else:
            fusion = FixedFusion(
                subsets=self.subsets,
                max_instances_per_subset=self.max_samples_per_subset,
            )
        return fusion()