from typing import List

from .card import TaskCard
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
from .formats import Format, SystemFormat
from .logging_utils import get_logger
from .operator import SourceSequentialOperator, StreamingOperator
from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
from .recipe import Recipe
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .templates import Template

logger = get_logger()


# Used to give meaningful name to recipe steps
class CreateDemosPool(SeparateSplit):
    pass


class AddDemosField(SpreadSplit):
    pass


class BaseRecipe(Recipe, SourceSequentialOperator):
    """Recipe that assembles its processing steps from a TaskCard.

    ``prepare()`` fills ``steps`` in this order: the card's loader, an
    ``AddFields`` step recording recipe metadata, an optional loader-limit
    refiner, the card's preprocess steps, the card's task, an optional
    task-input augmentor, an optional demos-pool split, the per-split
    refiners, the template, an optional demos field, the system prompt, the
    format, an optional model-input augmentor, and finally ``ToUnitxtGroup``.
    """

    card: TaskCard
    template: Template = None
    system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
    format: Format = Field(default_factory=SystemFormat)
    # When left as None, metrics come from card.task and postprocessors from
    # the template (see prepare_metrics_and_postprocessors).
    metrics: List[str] = NonPositionalField(default=None)
    postprocessors: List[str] = NonPositionalField(default=None)

    loader_limit: int = None

    max_train_instances: int = None
    max_validation_instances: int = None
    max_test_instances: int = None

    train_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    validation_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)

    demos_pool_size: int = None
    num_demos: int = 0
    demos_removed_from_data: bool = True

    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    # Falls back to card.sampler in prepare() when num_demos > 0.
    sampler: Sampler = None

    augmentor: Augmentor = OptionalField(default_factory=NullAugmentor)

    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def before_process_multi_stream(self):
        """Re-initialize the sampler's random generator, if a sampler is set."""
        super().before_process_multi_stream()
        if self.sampler:  # e.g. when num_demos is 0, the sampler may not be initialized
            self.sampler.init_new_random_generator()

    def verify(self):
        """Validate the demos and instance-limit settings.

        Raises:
            ValueError: If ``num_demos`` > 0 and ``demos_pool_size`` is unset,
                below 1, smaller than ``num_demos`` or larger than
                ``loader_limit``; or if any ``max_*_instances`` value exceeds
                ``loader_limit``.
        """
        super().verify()
        if self.num_demos > 0:
            if self.demos_pool_size is None or self.demos_pool_size < 1:
                raise ValueError(
                    "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
                )
            if self.demos_pool_size < self.num_demos:
                raise ValueError(
                    f"num_demos (got: {self.num_demos}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
                )
            if self.loader_limit and self.demos_pool_size > self.loader_limit:
                raise ValueError(
                    f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
                )

        if self.loader_limit:
            # Checked in the order test, validation, train so that the first
            # reported violation is the same as before the checks were folded
            # into a loop.
            for name in (
                "max_test_instances",
                "max_validation_instances",
                "max_train_instances",
            ):
                value = getattr(self, name)
                if value and value > self.loader_limit:
                    raise ValueError(
                        f"{name} should not exceed loader_limit ({self.loader_limit}), Got {name}={value}"
                    )

    def prepare_refiners(self):
        """Configure the train/validation/test refiners and append them to ``steps``."""
        for split, refiner, max_instances in (
            ("train", self.train_refiner, self.max_train_instances),
            ("validation", self.validation_refiner, self.max_validation_instances),
            ("test", self.test_refiner, self.max_test_instances),
        ):
            refiner.max_instances = max_instances
            refiner.apply_to_streams = [split]
            self.steps.append(refiner)

    def prepare_metrics_and_postprocessors(self):
        """Return ``(metrics, postprocessors)`` to pass to ``ToUnitxtGroup``.

        Values set explicitly on the recipe win; otherwise postprocessors are
        taken from the template and metrics from the card's task.
        """
        if self.postprocessors is None:
            postprocessors = self.template.get_postprocessors()
        else:
            postprocessors = self.postprocessors

        if self.metrics is None:
            metrics = self.card.task.metrics
        else:
            metrics = self.metrics
        return metrics, postprocessors

    def prepare(self):
        """Build the ordered list of steps for this recipe.

        Raises:
            ValueError: If ``num_demos`` > 0 and neither the recipe nor the
                card provides a sampler.
        """
        self.steps = [
            self.card.loader,
            AddFields(
                fields={
                    "recipe_metadata": {
                        "card": self.card,
                        "template": self.template,
                        "system_prompt": self.system_prompt,
                        "format": self.format,
                    }
                }
            ),
        ]

        if self.loader_limit:
            # The limit is set on the loader itself and also enforced by a
            # refiner step placed right after loading.
            self.card.loader.loader_limit = self.loader_limit
            logger.info(f"Loader line limit was set to {self.loader_limit}")
            self.steps.append(StreamRefiner(max_instances=self.loader_limit))

        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)

        self.steps.append(self.card.task)

        if self.augmentor.augment_task_input:
            self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
            self.steps.append(self.augmentor)

        if self.demos_pool_size is not None:
            self.steps.append(
                CreateDemosPool(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                    remove_targets_from_source_split=self.demos_removed_from_data,
                )
            )

        if self.num_demos > 0:
            if self.sampler is None:
                if self.card.sampler is None:
                    raise ValueError(
                        "Unexpected None value for card.sampler. "
                        "To use num_demos > 0, please set a sampler on the TaskCard."
                    )
                self.sampler = self.card.sampler
            self.sampler.set_size(self.num_demos)

        self.prepare_refiners()

        self.steps.append(self.template)
        if self.num_demos > 0:
            self.steps.append(
                AddDemosField(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=self.sampler,
                )
            )
        self.steps.append(self.system_prompt)
        self.steps.append(self.format)
        if self.augmentor.augment_model_input:
            self.steps.append(self.augmentor)

        metrics, postprocessors = self.prepare_metrics_and_postprocessors()

        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=metrics,
                postprocessors=postprocessors,
            )
        )


class StandardRecipeWithIndexes(BaseRecipe):
    """BaseRecipe whose template may be chosen by index/key into ``card.templates``."""

    template_card_index: int = None

    def prepare(self):
        """Resolve the template from ``template_card_index`` if given, then build the steps.

        Raises:
            AssertionError: If both ``template`` and ``template_card_index``
                are set, or if neither is set.
            ValueError: If ``template_card_index`` cannot be looked up in
                ``card.templates``.
        """
        # Raised explicitly rather than via `assert` so the validation is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if self.template_card_index is not None and self.template is not None:
            raise AssertionError(
                f"Specify either template ({self.template}) or template_card_index ({self.template_card_index}) but not both"
            )
        if self.template_card_index is None and self.template is None:
            raise AssertionError(
                "Specify either template or template_card_index in card"
            )

        if self.template_card_index is not None:
            try:
                self.template = self.card.templates[self.template_card_index]
            except Exception as e:
                # Deliberately broad: any lookup failure (missing key, index
                # out of range, wrong key type, no templates at all) is
                # reported uniformly as a ValueError listing the valid options.
                templates = self.card.templates
                if isinstance(templates, dict):
                    options = list(templates.keys())
                elif templates is None:
                    # A card without templates has no valid options; avoid
                    # len(None) masking the real error with a TypeError.
                    options = []
                else:
                    options = list(range(len(templates)))
                raise ValueError(
                    f"template_card_index '{self.template_card_index}' is not defined in card. Possible template_card_index options: {options}"
                ) from e

        super().prepare()


class StandardRecipe(StandardRecipeWithIndexes):
    """This class represents a standard recipe for data processing and preparation.

    This class can be used to prepare a recipe with all necessary steps, refiners and renderers included.
    It allows to set various parameters and steps in a sequential manner for preparing the recipe.

    Attributes:
        card (TaskCard): TaskCard object associated with the recipe.
        template (Template, optional): Template object to be used for the recipe.
        system_prompt (SystemPrompt, optional): SystemPrompt object to be used for the recipe.
        loader_limit (int, optional): Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
        format (SystemFormat, optional): SystemFormat object to be used for the recipe.
        metrics (List[str]): list of catalog metrics to use with this recipe.
        postprocessors (List[str]): list of catalog processors to apply at post processing. (Not recommended to use from here)
        train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
        max_train_instances (int, optional): Maximum training instances for the refiner.
        validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
        max_validation_instances (int, optional): Maximum validation instances for the refiner.
        test_refiner (StreamRefiner, optional): Test refiner to be used in the recipe.
        max_test_instances (int, optional): Maximum test instances for the refiner.
        demos_pool_size (int, optional): Size of the demos pool.
        num_demos (int, optional): Number of demos to be used.
        demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
        demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train".
        demos_field (str, optional): Field name for demos. Default is "demos".
        demos_removed_from_data (bool, optional): whether to remove the demos from the source data, Default is True
        sampler (Sampler, optional): Sampler object to be used in the recipe.
        steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
        augmentor (Augmentor) : Augmentor to be used to pseudo randomly augment the source text
        template_card_index (int, optional): Index of template card to be used for preparing the recipe.

    Methods:
        prepare(): This overridden method is used for preparing the recipe
            by arranging all the steps, refiners, and renderers in a sequential manner.

    Raises:
        AssertionError: If both template and template_card_index are specified at the same time,
            or if neither of them is specified.
        ValueError: If template_card_index is not defined in the card's templates.
    """

    pass