import itertools
from abc import abstractmethod
from collections import Counter
from difflib import get_close_matches
from typing import Dict, List, Optional

from .artifact import Artifact
from .dict_utils import dict_get
from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
from .random_utils import new_random_generator
from .split_utils import (
    parse_random_mix_string,
    parse_slices_string,
    random_mix_streams,
    rename_split,
    slice_streams,
)
from .stream import EmptyStreamError, FaultyStreamError, MultiStream
from .type_utils import isoftype
from .utils import deepcopy


class Splitter(MultiStreamOperator):
    """Common base class for the split-manipulating operators in this module."""

    pass


class RenameSplits(Splitter):
    """Renames the streams (splits) of a multistream.

    ``mapper`` is handed to ``split_utils.rename_split``; presumably it maps
    existing stream names to their new names (confirm against rename_split).
    """

    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        generators = rename_split(multi_stream, self.mapper)
        # NOTE(review): the other splitters use MultiStream.from_generators;
        # this one calls the constructor directly — confirm that is what
        # rename_split's return value requires.
        return MultiStream(generators)


class SplitRandomMix(Splitter):
    """Splits a multistream into new streams (splits), whose names, source input stream, and amount of instances, are specified by arg 'mix'.

    The keys of arg 'mix' are the names of the new streams; the values are of the form:
    'name-of-source-stream[percentage-of-source-stream]', and several such terms may be joined with '+'.
    Each input instance, of any input stream, is selected exactly once for inclusion in any of the output streams.

    Examples:
    When processing a multistream made of two streams whose names are 'train' and 'test', by
    SplitRandomMix(mix =  { "train": "train[99%]",  "validation": "train[1%]",  "test": "test" })
    the output is a multistream, whose three streams are named 'train', 'validation', and 'test'.
    Output stream 'train' is made of randomly selected 99% of the instances of input stream 'train',
    output stream 'validation' is made of the remaining 1% instances of input 'train', and output stream 'test' is made
    of the whole of input stream 'test'.

    When processing the above input multistream by
    SplitRandomMix(mix =  { "train": "train[50%]+test[0.1]",  "validation": "train[50%]+test[0.2]",  "test": "test[0.7]" })
    the output is a multistream, whose three streams are named 'train', 'validation', and 'test'.
    Output stream 'train' is made of randomly selected 50% of the instances of input stream 'train' + randomly selected
    0.1 (i.e., 10%) of the instances of input stream 'test'.
    Output stream 'validation' is made of the remaining 50% instances of input 'train' + randomly selected 0.2 (i.e.,
    20%) of the original instances of input 'test', that were not selected for output 'train',
    and output stream 'test' is made of the remaining instances of input 'test'.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        mapping = {k: parse_random_mix_string(v) for k, v in self.mix.items()}
        generators = random_mix_streams(multi_stream, mapping)
        return MultiStream.from_generators(generators)


class SeparateSplit(Splitter):
    """Separates a split (e.g. train) into several splits (e.g. train1, train2).

    ``to_split_sizes`` must give the size of every target split, or of every
    target split except the last. If no size is given for the last split, it
    includes all the examples not allocated to any other split.

    When ``remove_targets_from_source_split`` is True (the default), the source
    split ``from_split`` is dropped from the output; otherwise it is kept whole
    and a target with the same name as ``from_split`` is not created.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]
    remove_targets_from_source_split: bool = True

    def verify(self):
        assert (
            len(self.to_split_names) == len(self.to_split_sizes)
            or len(self.to_split_names) == len(self.to_split_sizes) + 1
        ), f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Pass every existing stream through whole ((None, None) slice), except
        # the source split when it is meant to be consumed by the targets.
        mapping = {
            key: {key: [(None, None)]}
            for key in multi_stream.keys()
            if not self.remove_targets_from_source_split or key != self.from_split
        }
        so_far = 0
        # zip_longest pads the sizes with None when the last split is unsized.
        for name, size in itertools.zip_longest(
            self.to_split_names, self.to_split_sizes
        ):
            if self.remove_targets_from_source_split or name != self.from_split:
                # (offset, size) pair; size None presumably means "everything
                # remaining" — interpretation is up to slice_streams.
                mapping[name] = {self.from_split: [(so_far, size)]}
            if size:
                so_far += size
        generators = slice_streams(multi_stream, mapping)
        return MultiStream.from_generators(generators)


class SliceSplit(Splitter):
    """Builds new streams from slices of existing ones.

    Each value of ``slices`` is parsed with ``split_utils.parse_slices_string``
    and the result is applied with ``split_utils.slice_streams``; the keys are
    the names of the output streams.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        mapping = {k: parse_slices_string(v) for k, v in self.slices.items()}
        generators = slice_streams(multi_stream, mapping)
        return MultiStream.from_generators(generators)


def get_random_generator_based_on_instance(instance):
    """Return a random generator sub-seeded with the instance's ``input_fields``.

    Presumably this makes sampling reproducible per instance (see
    ``random_utils.new_random_generator``). Raises KeyError if ``instance``
    has no ``input_fields``.
    """
    return new_random_generator(sub_seed={**instance["input_fields"]})


class Sampler(Artifact):
    """Abstract strategy for choosing instances out of a pool of candidates."""

    @abstractmethod
    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        pass

    def filter_source_by_instance(
        self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
    ) -> List[Dict[str, object]]:
        """Return the pool without the items whose ``input_fields`` equal the instance's.

        Raises:
            ValueError: if ``instance`` has no ``input_fields``.
        """
        if "input_fields" not in instance:
            raise ValueError(f"'input_fields' field is missing from '{instance}'.")
        return [
            item
            for item in instances_pool
            if item["input_fields"] != instance["input_fields"]
        ]


class RandomSampler(Sampler):
    """Selects a random sample of instances."""

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        instances_pool = list(instances_pool)
        random_generator = get_random_generator_based_on_instance(instance)
        return random_generator.sample(instances_pool, sample_size)


class FixedIndicesSampler(Sampler):
    """Selects a fix set of samples based on a list of indices.

    The first ``sample_size`` entries of ``indices`` are looked up in the pool,
    in order.
    """

    indices: List[int]

    def verify(self):
        assert isoftype(
            self.indices, List[int]
        ), f"'indices' of {self.__class__.__name__} must be List[int]. Value {self.indices} is of type {type(self.indices)}"
        super().verify()

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        num_instances = len(instances_pool)

        instances = []
        for index in self.indices[0:sample_size]:
            if index >= num_instances:
                raise ValueError(
                    f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool ( of size {num_instances})"
                )
            instances.append(instances_pool[index])
        return instances


class CloseTextSampler(Sampler):
    """Selects the samples of instances which are the closest textual match to the given instance.

    Comparison is done based on a given field in the instance's ``input_fields``,
    using ``difflib.get_close_matches``.
    """

    field: str

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        field = f"input_fields/{self.field}"
        value = dict_get(instance, field)

        instances_pool = list(instances_pool)

        # Get the 'sample_size' closest matching texts based on the field.
        options = [dict_get(instance_in_pool, field) for instance_in_pool in instances_pool]
        closest_matches = get_close_matches(value, options, n=sample_size, cutoff=0)

        # Randomly select 'sample_size' instances among those whose text is one of
        # the closest matches (several instances may share the same text in the
        # given field, and the returned order is randomized too).
        instances_pool = [
            instance_in_pool
            for instance_in_pool in instances_pool
            if dict_get(instance_in_pool, field) in closest_matches
        ]
        random_generator = get_random_generator_based_on_instance(instance)
        return random_generator.sample(instances_pool, sample_size)


class DiverseLabelsSampler(Sampler):
    """Selects a balanced sample of instances based on an output field.

    (used for selecting demonstrations in in-context learning)

    The field must contain a list of values, e.g. ['dog'], ['cat'], ['dog','cat','cow'].
    The balancing is done such that each value or combination of values
    appears as equally as possible in the samples.

    The ``choices`` input field determines which values should be considered.

    Example:
        If choices is ['dog','cat'], then the following combinations will be considered.
        []
        ['cat']
        ['dog']
        ['dog','cat']

    If the instance contains a value not in the choices, it is ignored. For example,
    if choices is ['dog','cat'] and the instance field is ['dog','cat','cow'], then 'cow'
    is ignored and the instance is considered as ['dog','cat'].

    Attributes:
        choices: name of the ``input_fields`` entry holding the values to balance on
            (a list, or a single string).
        labels: name of the ``reference_fields`` entry holding the labels that must
            be balanced (must be a list).
        include_empty_label: whether exemplars matching none of the choices (the ``[]``
            combination) may be sampled.
    """

    choices: str = "choices"
    labels: str = "labels"
    include_empty_label: bool = True

    def prepare(self):
        super().prepare()
        # Filled lazily on the first call to sample().
        self.labels_cache = None

    def exemplar_repr(self, exemplar):
        """Return the string form of the choices present among the exemplar's labels.

        The listed values keep the order of ``choices``.

        Raises:
            ValueError: if required fields are missing or have an unexpected type.
        """
        if "input_fields" not in exemplar:
            raise ValueError(f"'input_fields' field is missing from '{exemplar}'.")
        inputs = exemplar["input_fields"]
        if self.choices not in inputs:
            raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
        choices = inputs[self.choices]
        if not isinstance(choices, list):
            if isinstance(choices, str):
                choices = [choices]
            else:
                raise ValueError(
                    f"Unexpected input choices value '{choices}'. Expected a list or a string."
                )

        if "reference_fields" not in exemplar:
            raise ValueError(f"'reference_fields' field is missing from '{exemplar}'.")
        outputs = exemplar["reference_fields"]
        if self.labels not in outputs:
            raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")

        exemplar_outputs = exemplar["reference_fields"][self.labels]
        if not isinstance(exemplar_outputs, list):
            raise ValueError(
                f"Unexpected exemplar_outputs value '{exemplar_outputs}'. Expected a list."
            )

        return str([choice for choice in choices if choice in exemplar_outputs])

    def divide_by_repr(self, exemplars_pool):
        """Group exemplars by their ``exemplar_repr`` label combination."""
        labels = {}
        for exemplar in exemplars_pool:
            label_repr = self.exemplar_repr(exemplar)
            if label_repr == "[]" and not self.include_empty_label:
                continue
            if label_repr not in labels:
                labels[label_repr] = []
            labels[label_repr].append(exemplar)
        return labels

    def _label_pools_for(self, instance):
        """Return the cached label groups with exemplars equal to ``instance`` removed."""
        # The cache is built once from the first call's (already filtered) pool,
        # so it may still contain the current instance; drop it here using the
        # same input_fields comparison as Sampler.filter_source_by_instance.
        if instance is None or "input_fields" not in instance:
            return self.labels_cache
        input_fields = instance["input_fields"]
        return {
            label: [
                exemplar
                for exemplar in exemplars
                if exemplar["input_fields"] != input_fields
            ]
            for label, exemplars in self.labels_cache.items()
        }

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Return ``sample_size`` exemplars spread round-robin over label combinations.

        Raises:
            ValueError: if ``sample_size`` exceeds the pool size, or the label
                groups cannot supply ``sample_size`` exemplars.
        """
        if self.labels_cache is None:
            self.labels_cache = self.divide_by_repr(instances_pool)
        label_pools = self._label_pools_for(instance)
        all_labels = list(label_pools.keys())
        random_generator = get_random_generator_based_on_instance(instance)
        random_generator.shuffle(all_labels)

        if sample_size > len(instances_pool):
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances {len(instances_pool)}"
            )

        # Hand out one slot per label per round, in shuffled label order, until
        # sample_size slots are allocated or every label group is exhausted.
        total_allocated = 0
        allocations = Counter()
        while total_allocated < sample_size:
            allocated_this_round = False
            for label in all_labels:
                if total_allocated >= sample_size:
                    break
                if len(label_pools[label]) - allocations[label] > 0:
                    allocations[label] += 1
                    total_allocated += 1
                    allocated_this_round = True
            if not allocated_this_round:
                # Without this guard the loop would spin forever.
                raise ValueError(
                    f"Request sample size {sample_size} is greater than number of instances available for label balancing {total_allocated}"
                )

        result = []
        for label, allocation in allocations.items():
            sample = random_generator.sample(label_pools[label], allocation)
            result.extend(sample)

        random_generator.shuffle(result)
        return result


class Sample(InstanceOperatorWithMultiStreamAccess):
    """Attaches to each instance a sample of instances drawn from another stream.

    The sample, chosen by ``sampler`` from stream ``from_stream`` (excluding the
    instance itself), is stored under ``to_field``. Subclasses decide the sample
    size via ``get_sample_size``.
    """

    from_stream: str
    to_field: str
    sampler: Sampler

    def prepare(self):
        self.local_cache = None
        # NOTE(review): super().prepare() is not called here, unlike
        # DiverseLabelsSampler.prepare — confirm this is intentional.
        self.sampler.prepare()

    @abstractmethod
    def get_sample_size(self, instance) -> int:
        pass

    def process(
        self, instance: Dict[str, object], multi_stream: MultiStream
    ) -> Dict[str, object]:
        """Sample from ``from_stream`` and store the result in ``instance[to_field]``.

        Raises:
            ValueError: if fewer than the needed number of candidates remain.
            EmptyStreamError: if reading ``from_stream`` raises FaultyStreamError.
        """
        sample_size = self.get_sample_size(instance)
        try:
            # Materialize (and deep-copy) the source stream once, on first use.
            if self.local_cache is None:
                self.local_cache = deepcopy(list(multi_stream[self.from_stream]))

            source_stream = self.local_cache
            source_stream = self.sampler.filter_source_by_instance(
                source_stream, instance
            )
            if len(source_stream) < sample_size:
                raise ValueError(
                    f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}."
                )
            sampled_instances = self.sampler.sample(
                sample_size=sample_size, instances_pool=source_stream, instance=instance
            )
            instance[self.to_field] = sampled_instances
            return instance
        except FaultyStreamError as e:
            raise EmptyStreamError(
                f"Unable to fetch instances from '{self.from_stream}' to '{self.to_field}', due to {e.__class__.__name__}: {e}"
            ) from e


class ConstantSizeSample(Sample):
    """Sample operator that always draws ``sample_size`` instances."""

    sample_size: int

    def get_sample_size(self, instance) -> int:
        return self.sample_size


class RandomSizeSample(Sample):
    """Sample operator whose size is picked from ``sample_sizes`` per instance.

    The choice uses the instance-based random generator.
    """

    sample_sizes: List[int]

    def get_sample_size(self, instance) -> int:
        random_generator = get_random_generator_based_on_instance(instance)
        return random_generator.choice(self.sample_sizes)