from functools import lru_cache
from typing import Any, Dict, List, Union

from datasets import DatasetDict

from .artifact import fetch_artifact
from .dataset_utils import get_dataset_artifact
from .logging_utils import get_logger
from .metric_utils import _compute
from .operator import SourceOperator

logger = get_logger()


def load(source: Union[SourceOperator, str]) -> DatasetDict:
    """Run a source operator and return its output as a dataset.

    Args:
        source: Either a ``SourceOperator`` instance, or a string that is
            resolved through ``fetch_artifact`` (the first element of the
            returned pair is used as the operator).

    Returns:
        The result of calling the operator and converting its output with
        ``to_dataset()``.

    Raises:
        AssertionError: If ``source`` is neither a ``SourceOperator`` nor a
            string (only checked when assertions are enabled).
    """
    assert isinstance(
        source, (SourceOperator, str)
    ), "source must be a SourceOperator or a string"

    if not isinstance(source, str):
        operator = source
    else:
        # fetch_artifact returns a pair; only the first item is needed here.
        operator, _ = fetch_artifact(source)

    stream = operator()
    return stream.to_dataset()


def load_dataset(dataset_query: str) -> DatasetDict:
    """Build a dataset from a textual dataset query.

    Every occurrence of ``"sys_prompt"`` in the query is rewritten to
    ``"instruction"`` before the query is handed to ``get_dataset_artifact``.
    NOTE(review): presumably this keeps an older query keyword working --
    confirm against the recipe definition.

    Args:
        dataset_query: Query string identifying the dataset recipe.

    Returns:
        The output of the resolved artifact, converted with ``to_dataset()``.
    """
    normalized_query = dataset_query.replace("sys_prompt", "instruction")
    recipe = get_dataset_artifact(normalized_query)
    stream = recipe()
    return stream.to_dataset()


def evaluate(predictions, data) -> List[Dict[str, Any]]:
    """Score ``predictions`` against ``data``.

    Thin wrapper around ``_compute``: ``data`` is forwarded as its
    ``references`` argument and the result is returned unchanged.
    """
    scores = _compute(predictions=predictions, references=data)
    return scores


@lru_cache
def _get_produce_with_cache(recipe_query):
    """Return the ``produce`` attribute of the artifact for ``recipe_query``.

    Results are memoized by ``lru_cache``, so repeated calls with the same
    query reuse the previously resolved artifact instead of calling
    ``get_dataset_artifact`` again.
    """
    artifact = get_dataset_artifact(recipe_query)
    return artifact.produce


def produce(instance_or_instances, recipe_query):
    """Process one instance, or a list of instances, with a recipe.

    Args:
        instance_or_instances: A single instance or a ``list`` of instances.
            Anything that is not a ``list`` is treated as a single instance.
        recipe_query: Query identifying the recipe; used as the cache key for
            the resolved ``produce`` callable.

    Returns:
        For a list input, whatever the recipe's ``produce`` returns for that
        list. For a single instance, the first element of the result obtained
        by processing a one-element list.
    """
    produce_fn = _get_produce_with_cache(recipe_query)

    if isinstance(instance_or_instances, list):
        return produce_fn(instance_or_instances)

    # Single instance: wrap it for the batch API, then unwrap the only result.
    batch_result = produce_fn([instance_or_instances])
    return batch_result[0]