Student0809's picture
Add files using upload-large-folder tool
cb2428f verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass
from typing import Literal, Optional
import torch
import torch.distributed as dist
from swift.utils import get_logger, init_process_group, set_default_ddp_config
from .base_args import BaseArguments, to_abspath
from .merge_args import MergeArguments
# Module-level logger; ExportArguments._init_output_dir uses it to report the resolved output directory.
logger = get_logger()
@dataclass
class ExportArguments(MergeArguments, BaseArguments):
    """Arguments for exporting a model: quantization, Ollama export, Megatron conversion, merging and hub push.

    ExportArguments is a dataclass that inherits from MergeArguments and BaseArguments.

    Args:
        output_dir (Optional[str]): Directory where the output will be saved. If None, a default of
            `<ckpt_parent>/<ckpt_name>-<suffix>` is derived from the checkpoint directory (or
            `./<model_suffix>`), with the suffix chosen by the active export mode. The result is
            converted to an absolute path.
        quant_method (Literal['awq', 'gptq', 'bnb']): Quantization method. Must be given together
            with `quant_bits`.
        quant_n_samples (int): Number of samples for quantization.
        max_length (int): Sequence length for quantization.
        quant_batch_size (int): Batch size for quantization. A value of -1 is converted to None.
        group_size (int): Group size for awq/gptq quantization.
        to_ollama (bool): Export the model to Ollama format.
        to_mcore (bool): Megatron conversion flag (default output suffix 'mcore'); presumably converts
            the model to Megatron-Core format.
        to_hf (bool): Megatron conversion flag (default output suffix 'hf'); presumably converts a
            Megatron-Core checkpoint back to Hugging Face format.
        mcore_model (Optional[str]): Megatron-Core model path. When `to_mcore` or `to_hf` is set it is
            passed through `to_abspath(..., check_path_exist=True)`.
        thread_count (Optional[int]): Megatron conversion option; not read in this class.
        test_convert_precision (bool): Megatron conversion option; not read in this class (presumably
            checks numerical precision after conversion).
        push_to_hub (bool): Flag to indicate if the output should be pushed to the model hub.
        hub_model_id (Optional[str]): Model ID for the hub, as 'user_name/repo_name' or 'repo_name'.
        hub_private_repo (bool): Flag to indicate if the hub repository is private.
        commit_message (str): Commit message for pushing to the hub.
        to_peft_format (bool): Compatibility flag. In this class it only selects the 'peft' suffix for the
            default output directory.
        exist_ok (bool): If False, raise FileExistsError when the resolved `output_dir` already exists.
    """
    output_dir: Optional[str] = None

    # awq/gptq
    quant_method: Literal['awq', 'gptq', 'bnb'] = None
    quant_n_samples: int = 256
    max_length: int = 2048
    quant_batch_size: int = 1
    group_size: int = 128

    # ollama
    to_ollama: bool = False

    # megatron
    to_mcore: bool = False
    to_hf: bool = False
    mcore_model: Optional[str] = None
    thread_count: Optional[int] = None
    test_convert_precision: bool = False

    # push to ms hub
    push_to_hub: bool = False
    # 'user_name/repo_name' or 'repo_name'
    hub_model_id: Optional[str] = None
    hub_private_repo: bool = False
    commit_message: str = 'update files'

    # compat
    to_peft_format: bool = False
    exist_ok: bool = False

    def _init_output_dir(self):
        """Resolve `output_dir` to an absolute path, deriving a default when it is None.

        When `output_dir` is None, the default is `<ckpt_parent>/<ckpt_name>-<suffix>`. The suffix
        comes from the first enabled mode, in this order:
        to_peft_format, quant_method, to_ollama, merge_lora, to_mcore, to_hf.
        If none of these is enabled, `output_dir` stays None and the method returns immediately.

        Raises:
            FileExistsError: If `exist_ok` is False and the resolved directory already exists.
        """
        if self.output_dir is None:
            ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
            # Normalize first so a trailing separator (e.g. 'ckpt/') does not make `os.path.split`
            # return an empty name, which would produce a bogus directory such as 'ckpt/-merged'.
            ckpt_dir, ckpt_name = os.path.split(os.path.normpath(ckpt_dir))
            if self.to_peft_format:
                suffix = 'peft'
            elif self.quant_method:
                suffix = f'{self.quant_method}-int{self.quant_bits}'
            elif self.to_ollama:
                suffix = 'ollama'
            elif self.merge_lora:
                suffix = 'merged'
            elif self.to_mcore:
                suffix = 'mcore'
            elif self.to_hf:
                suffix = 'hf'
            else:
                # No export mode that needs an output directory: leave output_dir as None.
                return

            self.output_dir = os.path.join(ckpt_dir, f'{ckpt_name}-{suffix}')

        self.output_dir = to_abspath(self.output_dir)
        if not self.exist_ok and os.path.exists(self.output_dir):
            raise FileExistsError(f'args.output_dir: `{self.output_dir}` already exists.')
        logger.info(f'args.output_dir: `{self.output_dir}`')

    def __post_init__(self):
        """Validate export options, prepare Megatron conversion, then run base initialization.

        Raises:
            ValueError: If only one of `quant_bits` / `quant_method` is given, or if awq/gptq
                quantization is requested without a dataset.
            FileExistsError: Propagated from `_init_output_dir` when the output directory exists
                and `exist_ok` is False.
        """
        # -1 is accepted on the command line as "unset" and stored as None.
        if self.quant_batch_size == -1:
            self.quant_batch_size = None
        if self.quant_bits and self.quant_method is None:
            raise ValueError('Please specify the quantization method using `--quant_method awq/gptq/bnb`.')
        if self.quant_method and self.quant_bits is None:
            raise ValueError('Please specify `--quant_bits`.')
        # awq/gptq default to float16 when no dtype was requested (presumably what their kernels expect).
        if self.quant_method in {'gptq', 'awq'} and self.torch_dtype is None:
            self.torch_dtype = torch.float16
        if self.to_mcore or self.to_hf:
            self.mcore_model = to_abspath(self.mcore_model, check_path_exist=True)
            # Megatron conversion presumably needs a torch.distributed process group;
            # create a default one only if the caller has not already initialized it.
            if not dist.is_initialized():
                set_default_ddp_config()
                init_process_group()

        # Must run before `_init_output_dir`, which reads attributes such as ckpt_dir / model_suffix
        # that are not defined in this class (assumed to be set by the base initialization).
        BaseArguments.__post_init__(self)
        self._init_output_dir()
        # `not` also covers a None dataset, which `len()` would reject with a TypeError.
        if self.quant_method in {'gptq', 'awq'} and not self.dataset:
            raise ValueError(f'self.dataset: {self.dataset}, Please input the quant dataset.')