File size: 2,365 Bytes
c6e9c8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import List
from .card import TaskCard
from .dataclass import InternalField
from .formats import ICLFormat
from .instructions import Instruction
from .operator import SourceSequntialOperator, StreamingOperator
from .recipe import Recipe
from .renderers import StandardRenderer
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .templates import Template
class StandardRecipe(Recipe, SourceSequntialOperator):
    """Recipe that wires a ``TaskCard`` into a complete processing pipeline.

    ``prepare`` builds ``steps`` in this order:

    1. the card's loader;
    2. the card's ``preprocess_steps``, when they are not ``None``;
    3. the card's task;
    4. when ``demos_pool_size`` is set, a ``SeparateSplit`` over split
       ``demos_taken_from`` targeting ``demos_pool_name`` (sized
       ``demos_pool_size``) and ``demos_taken_from``;
    5. when ``num_demos`` is set, a ``SpreadSplit`` that reads the
       ``demos_pool_name`` stream and writes into ``demos_field`` using a
       sampler sized to ``num_demos``;
    6. a ``StandardRenderer`` built from ``instruction``, ``template``,
       ``format`` and ``demos_field``;
    7. a ``ToUnitxtGroup`` (group ``"unitxt"``) carrying the task's metrics
       and the renderer's postprocessors.
    """

    card: TaskCard
    template: Template
    instruction: Instruction = None
    format: ICLFormat = None
    # Size of the demos pool split off from ``demos_taken_from``; None skips it.
    demos_pool_size: int = None
    # Number of demonstrations the sampler is sized to; None skips demo sampling.
    num_demos: int = None
    # Name of the stream that holds the demos pool.
    demos_pool_name: str = "demos_pool"
    # Split from which the demos pool is carved.
    demos_taken_from: str = "train"
    # Instance field that receives the demonstrations (also passed to the renderer).
    demos_field: str = "demos"
    # When set, takes precedence over the card's sampler.
    sampler: Sampler = None
    # Pipeline operators; rebuilt from scratch by ``prepare``.
    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def prepare(self):
        """Build ``self.steps`` from the card and this recipe's settings.

        Raises:
            ValueError: if ``num_demos`` is set but neither this recipe nor
                the card provides a sampler.
        """
        self.steps = [self.card.loader]
        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)
        self.steps.append(self.card.task)

        if self.demos_pool_size is not None:
            # The source split is listed again as a target name, so it
            # presumably keeps the instances not placed in the pool — confirm
            # against SeparateSplit.
            self.steps.append(
                SeparateSplit(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                )
            )

        if self.num_demos is not None:
            # A sampler configured on the recipe overrides the card's sampler.
            sampler = self.sampler if self.sampler is not None else self.card.sampler
            if sampler is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise ValueError(
                    "num_demos is set but no sampler is available: set "
                    "StandardRecipe.sampler or provide a sampler on the card."
                )
            # NOTE(review): set_size looks like an in-place mutation, so a
            # sampler shared with the card (or other recipes) would be resized
            # too — confirm against Sampler.set_size.
            sampler.set_size(self.num_demos)
            self.steps.append(
                SpreadSplit(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=sampler,
                )
            )

        render = StandardRenderer(
            instruction=self.instruction,
            template=self.template,
            format=self.format,
            demos_field=self.demos_field,
        )
        self.steps.append(render)

        # Postprocessors come from the renderer and travel with the task's
        # metrics in the final grouping step.
        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=render.get_postprocessors(),
            )
        )
|