# Copyright (c) Alibaba, Inc. and its affiliates. import datetime as dt import os from dataclasses import dataclass, field from typing import Dict, List, Literal, Optional, Union from swift.utils import get_logger from .base_args import to_abspath from .deploy_args import DeployArguments logger = get_logger() @dataclass class EvalArguments(DeployArguments): """ EvalArguments is a dataclass that extends DeployArguments and is used to define the arguments required for evaluating a model. Args: eval_dataset (List[str]): List of evaluation datasets. Default is an empty list. eval_limit (Optional[str]): Limit number of each evaluation dataset. Default is None. local_dataset(bool): Download extra dataset from opencompass, default False. eval_generation_config (Optional[Union[Dict, str]]): The generation config for evaluation. Default is None. eval_output_dir (str): The eval output dir. temperature (float): The temperature. verbose (bool): Output verbose information. eval_url (str): The extra eval url, use this as --model. extra_eval_args (Optional[Union[Dict, str]]): Additional evaluation arguments. Default is an empty dict. """ eval_dataset: List[str] = field(default_factory=list) eval_limit: Optional[int] = None dataset_args: Optional[Union[Dict, str]] = None eval_generation_config: Optional[Union[Dict, str]] = field(default_factory=dict) eval_output_dir: str = 'eval_output' eval_backend: Literal['Native', 'OpenCompass', 'VLMEvalKit'] = 'Native' local_dataset: bool = False temperature: Optional[float] = 0. verbose: bool = False eval_num_proc: int = 16 extra_eval_args: Optional[Union[Dict, str]] = field(default_factory=dict) # If eval_url is set, ms-swift will not perform deployment operations and # will directly use the URL for evaluation. eval_url: Optional[str] = None def _init_eval_url(self): # [compat] if self.eval_url and 'chat/completions' in self.eval_url: self.eval_url = self.eval_url.split('/chat/completions', 1)[0] def __post_init__(self): super().__post_init__() self._init_eval_url() self._init_eval_dataset() self.dataset_args = self.parse_to_dict(self.dataset_args) self.eval_generation_config = self.parse_to_dict(self.eval_generation_config) self.eval_output_dir = to_abspath(self.eval_output_dir) logger.info(f'eval_output_dir: {self.eval_output_dir}') @staticmethod def list_eval_dataset(eval_backend=None): from evalscope.constants import EvalBackend from evalscope.benchmarks.benchmark import BENCHMARK_MAPPINGS from evalscope.backend.opencompass import OpenCompassBackendManager res = { EvalBackend.NATIVE: list(BENCHMARK_MAPPINGS.keys()), EvalBackend.OPEN_COMPASS: OpenCompassBackendManager.list_datasets(), } try: from evalscope.backend.vlm_eval_kit import VLMEvalKitBackendManager vlm_datasets = VLMEvalKitBackendManager.list_supported_datasets() res[EvalBackend.VLM_EVAL_KIT] = vlm_datasets except ImportError: # fix cv2 import error if eval_backend == 'VLMEvalKit': raise return res def _init_eval_dataset(self): if isinstance(self.eval_dataset, str): self.eval_dataset = [self.eval_dataset] all_eval_dataset = self.list_eval_dataset(self.eval_backend) dataset_mapping = {dataset.lower(): dataset for dataset in all_eval_dataset[self.eval_backend]} valid_dataset = [] for dataset in self.eval_dataset: if dataset.lower() not in dataset_mapping: raise ValueError( f'eval_dataset: {dataset} is not supported.\n' f'eval_backend: {self.eval_backend} supported datasets: {all_eval_dataset[self.eval_backend]}') valid_dataset.append(dataset_mapping[dataset.lower()]) self.eval_dataset = valid_dataset logger.info(f'eval_backend: {self.eval_backend}') logger.info(f'eval_dataset: {self.eval_dataset}') def _init_result_path(self, folder_name: str) -> None: self.time = dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') result_dir = self.ckpt_dir or f'result/{self.model_suffix}' os.makedirs(result_dir, exist_ok=True) self.result_jsonl = to_abspath(os.path.join(result_dir, 'eval_result.jsonl')) if not self.eval_url: super()._init_result_path('eval_result') def _init_torch_dtype(self) -> None: if self.eval_url: self.model_dir = self.eval_output_dir return super()._init_torch_dtype()