metric / standard.py
Elron's picture
Upload standard.py with huggingface_hub
a024d9a verified
raw
history blame
10.9 kB
from typing import List
from .card import TaskCard
from .dataclass import Field, InternalField, OptionalField
from .formats import Format, SystemFormat
from .logging_utils import get_logger
from .operator import SourceSequentialOperator, StreamingOperator
from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
from .recipe import Recipe
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .templates import Template
logger = get_logger()
# The two subclasses below add no behavior; they exist only to give recipe steps meaningful names.
class CreateDemosPool(SeparateSplit):
    """Name-only subclass of ``SeparateSplit``.

    The recipe uses this step to carve a demonstrations pool out of the
    split demos are taken from; it adds no behavior of its own.
    """
class AddDemosField(SpreadSplit):
    """Name-only subclass of ``SpreadSplit``.

    The recipe uses this step, together with a sampler, to fill the demos
    field of instances from the demos pool stream; it adds no behavior of
    its own.
    """
class BaseRecipe(Recipe, SourceSequentialOperator):
    """Recipe that turns a task card into an ordered list of stream operators.

    ``prepare`` fills ``steps`` in this order: the card's loader, an
    ``AddFields`` step recording recipe metadata, an optional loader-limit
    refiner, the card's preprocess steps, the card's task, an optional
    task-input augmentor, an optional demos-pool split, the per-split
    refiners, the template, an optional demos field step, the system prompt,
    the format, an optional model-input augmentor, and finally
    ``ToUnitxtGroup``.
    """

    # Source of the loader, preprocess steps, task and (fallback) sampler.
    card: TaskCard
    template: Template = None
    system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
    format: Format = Field(default_factory=SystemFormat)

    # Cap set on the loader and enforced again by a StreamRefiner.
    loader_limit: int = None

    # Per-split caps, applied through the matching refiner below.
    max_train_instances: int = None
    max_validation_instances: int = None
    max_test_instances: int = None

    train_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    validation_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)

    # Demonstrations configuration.
    demos_pool_size: int = None
    num_demos: int = 0
    demos_removed_from_data: bool = True

    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    sampler: Sampler = None

    augmentor: Augmentor = OptionalField(default_factory=NullAugmentor)

    # Populated by prepare(); not meant to be supplied by the user.
    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def before_process_multi_stream(self):
        """Re-seed the sampler, if one is set, before streams are processed."""
        super().before_process_multi_stream()
        active_sampler = self.sampler
        # e.g. when num_demos is 0, the sampler may not be initialized
        if active_sampler:
            active_sampler.init_new_random_generator()

    def verify(self):
        """Check that the demos and instance limits are mutually consistent.

        Raises:
            ValueError: if ``num_demos`` is positive but ``demos_pool_size``
                is missing, non-positive, smaller than ``num_demos`` or
                larger than ``loader_limit``; or if any ``max_*_instances``
                exceeds ``loader_limit``.
        """
        super().verify()

        if self.num_demos > 0:
            pool_size = self.demos_pool_size
            if pool_size is None or pool_size < 1:
                raise ValueError(
                    "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
                )
            if self.num_demos > pool_size:
                raise ValueError(
                    f"num_demos (got: {self.num_demos}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
                )
            if self.loader_limit and pool_size > self.loader_limit:
                raise ValueError(
                    f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
                )

        # The remaining checks only make sense when a loader limit is set.
        if not self.loader_limit:
            return

        test_cap = self.max_test_instances
        if test_cap and test_cap > self.loader_limit:
            raise ValueError(
                f"max_test_instances should not exceed loader_limit ({self.loader_limit}), Got max_test_instances={self.max_test_instances}"
            )

        validation_cap = self.max_validation_instances
        if validation_cap and validation_cap > self.loader_limit:
            raise ValueError(
                f"max_validation_instances should not exceed loader_limit ({self.loader_limit}), Got max_validation_instances={self.max_validation_instances}"
            )

        train_cap = self.max_train_instances
        if train_cap and train_cap > self.loader_limit:
            raise ValueError(
                f"max_train_instances should not exceed loader_limit ({self.loader_limit}), Got max_train_instances={self.max_train_instances}"
            )

    def prepare_refiners(self):
        """Configure the train/validation/test refiners and append them to ``steps``."""
        per_split = (
            ("train", self.train_refiner, self.max_train_instances),
            ("validation", self.validation_refiner, self.max_validation_instances),
            ("test", self.test_refiner, self.max_test_instances),
        )
        for split_name, refiner, cap in per_split:
            refiner.max_instances = cap
            refiner.apply_to_streams = [split_name]
            self.steps.append(refiner)

    def prepare(self):
        """Build ``steps`` from the card and this recipe's settings.

        Raises:
            ValueError: if ``num_demos`` is positive and neither the recipe
                nor the card provides a sampler.
        """
        card = self.card
        recipe_metadata = {
            "card": card,
            "template": self.template,
            "system_prompt": self.system_prompt,
            "format": self.format,
        }
        self.steps = [
            card.loader,
            AddFields(fields={"recipe_metadata": recipe_metadata}),
        ]

        if self.loader_limit:
            # The limit is pushed into the loader and also enforced downstream.
            card.loader.loader_limit = self.loader_limit
            logger.info(f"Loader line limit was set to {self.loader_limit}")
            self.steps.append(StreamRefiner(max_instances=self.loader_limit))

        if card.preprocess_steps is not None:
            self.steps.extend(card.preprocess_steps)

        self.steps.append(card.task)

        if self.augmentor.augment_task_input:
            self.augmentor.set_task_input_fields(card.task.augmentable_inputs)
            self.steps.append(self.augmentor)

        if self.demos_pool_size is not None:
            demos_pool_step = CreateDemosPool(
                from_split=self.demos_taken_from,
                to_split_names=[self.demos_pool_name, self.demos_taken_from],
                to_split_sizes=[int(self.demos_pool_size)],
                remove_targets_from_source_split=self.demos_removed_from_data,
            )
            self.steps.append(demos_pool_step)

        with_demos = self.num_demos > 0

        if with_demos:
            if self.sampler is None:
                # Fall back to the sampler declared on the card.
                card_sampler = card.sampler
                if card_sampler is None:
                    raise ValueError(
                        "Unexpected None value for card.sampler. "
                        "To use num_demos > 0, please set a sampler on the TaskCard."
                    )
                self.sampler = card_sampler
            self.sampler.set_size(self.num_demos)

        self.prepare_refiners()

        self.steps.append(self.template)
        if with_demos:
            demos_step = AddDemosField(
                source_stream=self.demos_pool_name,
                target_field=self.demos_field,
                sampler=self.sampler,
            )
            self.steps.append(demos_step)
        self.steps.append(self.system_prompt)
        self.steps.append(self.format)
        if self.augmentor.augment_model_input:
            self.steps.append(self.augmentor)

        final_step = ToUnitxtGroup(
            group="unitxt",
            metrics=card.task.metrics,
            postprocessors=self.template.get_postprocessors(),
        )
        self.steps.append(final_step)
class StandardRecipeWithIndexes(BaseRecipe):
    """Recipe whose template may be picked from the card by index or key.

    Exactly one of ``template`` and ``template_card_index`` must be given.
    When ``template_card_index`` is given, the template is looked up in
    ``card.templates`` before the base recipe is prepared.
    """

    # Index into card.templates (or key, when card.templates is a dict).
    template_card_index: int = None

    def prepare(self):
        """Resolve the template if needed, then build the recipe steps.

        Raises:
            AssertionError: if both, or neither, of ``template`` and
                ``template_card_index`` are specified.
            ValueError: if ``template_card_index`` cannot be resolved in
                ``card.templates``; the message lists the valid options.
        """
        assert (
            self.template_card_index is None or self.template is None
        ), f"Specify either template ({self.template}) or template_card_index ({self.template_card_index}) but not both"
        assert not (
            self.template_card_index is None and self.template is None
        ), "Specify either template or template_card_index in card"

        if self.template_card_index is not None:
            templates = self.card.templates
            try:
                self.template = templates[self.template_card_index]
            except Exception as e:
                # Kept broad on purpose: the templates container is a
                # project type and may fail with more than KeyError/IndexError.
                if isinstance(templates, dict):
                    options = list(templates.keys())
                elif templates is None:
                    # A card without templates has no valid options; avoid a
                    # secondary TypeError from len(None) masking this error.
                    options = []
                else:
                    options = list(range(len(templates)))
                # The message names the real parameter, template_card_index.
                raise ValueError(
                    f"template_card_index '{self.template_card_index}' is not defined in card. Possible template_card_index options: {options}"
                ) from e

        super().prepare()
class StandardRecipe(StandardRecipeWithIndexes):
    """Standard recipe for data processing and preparation.

    Use this class to prepare a recipe with all the necessary steps, refiners
    and renderers included. Its parameters configure the steps, which are then
    arranged sequentially when the recipe is prepared.

    Attributes:
        card (TaskCard): TaskCard object associated with the recipe.
        template (Template, optional): Template object to be used for the recipe.
        system_prompt (SystemPrompt, optional): SystemPrompt object to be used for the recipe.
        loader_limit (int, optional): Maximum number of instances per stream to be
            returned from the loader (used to reduce loading time in large datasets).
        format (SystemFormat, optional): SystemFormat object to be used for the recipe.
        train_refiner (StreamRefiner, optional): Refiner applied to the train stream.
        max_train_instances (int, optional): Maximum instances for the train refiner.
        validation_refiner (StreamRefiner, optional): Refiner applied to the validation stream.
        max_validation_instances (int, optional): Maximum instances for the validation refiner.
        test_refiner (StreamRefiner, optional): Refiner applied to the test stream.
        max_test_instances (int, optional): Maximum instances for the test refiner.
        demos_pool_size (int, optional): Size of the demos pool.
        num_demos (int, optional): Number of demos to be used. Default is 0.
        demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
        demos_taken_from (str, optional): Split the demos are taken from. Default is "train".
        demos_field (str, optional): Field name for demos. Default is "demos".
        demos_removed_from_data (bool, optional): Whether to remove the demos from
            the source data. Default is True.
        sampler (Sampler, optional): Sampler object to be used in the recipe.
        steps (List[StreamingOperator], optional): Operators making up the recipe;
            filled in by ``prepare()``.
        augmentor (Augmentor): Augmentor to be used to pseudo randomly augment the source text.
        template_card_index (int, optional): Index of the template in the card to be
            used for preparing the recipe.

    Methods:
        prepare(): Inherited; arranges all the steps, refiners and renderers
            in a sequential manner.

    Raises:
        AssertionError: If both, or neither, of template and template_card_index are specified.
        ValueError: If template_card_index is not defined in the card.
    """