import os
from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.distributed as dist

from swift.utils import get_logger, init_process_group, set_default_ddp_config
from .base_args import BaseArguments, to_abspath
from .merge_args import MergeArguments

logger = get_logger()


@dataclass
class ExportArguments(MergeArguments, BaseArguments):
    """
    ExportArguments is a dataclass that inherits from MergeArguments and BaseArguments.

    Args:
        output_dir (Optional[str]): Directory where the output will be saved.
        quant_method (Optional[str]): Quantization method; one of 'awq', 'gptq' or 'bnb'.
        quant_n_samples (int): Number of calibration samples used for quantization.
        max_length (int): Maximum sequence length used for quantization.
        quant_batch_size (int): Batch size used for quantization.
        group_size (int): Group size used for quantization.
        to_ollama (bool): Flag to indicate whether to export the model to the Ollama format.
        push_to_hub (bool): Flag to indicate whether the output should be pushed to the model hub.
        hub_model_id (Optional[str]): Model ID for the hub.
        hub_private_repo (bool): Flag to indicate whether the hub repository is private.
        commit_message (str): Commit message used when pushing to the hub.
        to_peft_format (bool): Flag to indicate whether the output should be converted to the PEFT format.
            This argument has no effect for now.
        exist_ok (bool): If False, raise an error when output_dir already exists.
    """
    output_dir: Optional[str] = None

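    # quantization (awq/gptq/bnb)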
    quant_method: Optional[Literal['awq', 'gptq', 'bnb']] = None
    quant_n_samples: int = 256
    max_length: int = 2048
    quant_batch_size: int = 1
    group_size: int = 128

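    # ollama export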
    to_ollama: bool = False

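    # mcore (Megatron) / hf weight conversion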
    to_mcore: bool = False
    to_hf: bool = False
    mcore_model: Optional[str] = None
    thread_count: Optional[int] = None
    test_convert_precision: bool = False

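    # pushing to the model hub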
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None
    hub_private_repo: bool = False
    commit_message: str = 'update files'

    to_peft_format: bool = False
    exist_ok: bool = False

    def _init_output_dir(self):
        if self.output_dir is None:
            ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
            ckpt_dir, ckpt_name = os.path.split(ckpt_dir)
            if self.to_peft_format:
                suffix = 'peft'
            elif self.quant_method:
                suffix = f'{self.quant_method}-int{self.quant_bits}'
            elif self.to_ollama:
                suffix = 'ollama'
            elif self.merge_lora:
                suffix = 'merged'
            elif self.to_mcore:
                suffix = 'mcore'
            elif self.to_hf:
                suffix = 'hf'
            else:
                return

            self.output_dir = os.path.join(ckpt_dir, f'{ckpt_name}-{suffix}')
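            # e.g. a checkpoint at 'output/checkpoint-100' with quant_method='awq' and
            # quant_bits=4 yields output_dir 'output/checkpoint-100-awq-int4'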

        self.output_dir = to_abspath(self.output_dir)
        if not self.exist_ok and os.path.exists(self.output_dir):
            raise FileExistsError(f'args.output_dir: `{self.output_dir}` already exists.')
        logger.info(f'args.output_dir: `{self.output_dir}`')

    def __post_init__(self):
        if self.quant_batch_size == -1:
            self.quant_batch_size = None
        if self.quant_bits and self.quant_method is None:
            raise ValueError('Please specify the quantization method using `--quant_method awq/gptq/bnb`.')
        if self.quant_method and self.quant_bits is None:
            raise ValueError('Please specify `--quant_bits`.')
        if self.quant_method in {'gptq', 'awq'} and self.torch_dtype is None:
            self.torch_dtype = torch.float16

        if self.to_mcore or self.to_hf:
            self.mcore_model = to_abspath(self.mcore_model, check_path_exist=True)
            if not dist.is_initialized():
                set_default_ddp_config()
                init_process_group()

        BaseArguments.__post_init__(self)
        self._init_output_dir()
        if self.quant_method in {'gptq', 'awq'} and len(self.dataset) == 0:
            raise ValueError(f'self.dataset: {self.dataset}, Please input the quant dataset.')
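
# A minimal usage sketch, illustrative only. ExportArguments is normally populated from the
# command line (e.g. via `swift export`); the flags below assume the usual one-to-one mapping
# between dataclass fields (including inherited ones such as `quant_bits` and `dataset`) and
# CLI options, and the angle-bracket values are placeholders, not real names:
#
#   swift export \
#       --model <model_id_or_checkpoint_dir> \
#       --quant_method awq \
#       --quant_bits 4 \
#       --dataset <calibration_dataset> \
#       --output_dir <output_dir>
#
# Programmatic construction would look roughly like:
#
#   args = ExportArguments(quant_method='awq', quant_bits=4, dataset=['<calibration_dataset>'])
#   # __post_init__ then validates the quantization settings and derives output_dir if unset.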