|
|
|
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from peft import IA3Config, PeftModel, get_peft_model |
|
|
|
|
|
from swift.llm import MODEL_ARCH_MAPPING, ModelKeys |
|
|
from swift.utils import find_all_linears |
|
|
|
|
|
|
|
|
class Tuner:
    """Interface for pluggable tuners.

    Every hook is a static method that raises ``NotImplementedError`` here;
    concrete tuners are expected to override them.
    """

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Wrap a model with this tuner.

        Args:
            args: The training arguments.
            model: The model instance to wrap.

        Returns:
            The wrapped model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @staticmethod
    def save_pretrained(model: torch.nn.Module,
                        save_directory: str,
                        state_dict: Optional[dict] = None,
                        safe_serialization: bool = True,
                        **kwargs) -> None:
        """Save the tuner checkpoint; invoked when a save step is reached.

        Args:
            model: The model previously wrapped by `prepare_model`.
            save_directory: The directory to write the checkpoint to.
            state_dict: Optional state dict handed over by the caller.
            safe_serialization: Whether to use safetensors.
            **kwargs: Extra options for the concrete implementation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a tuner checkpoint onto a model.

        Args:
            model: The original (unwrapped) model instance.
            model_id: The model id or ckpt_dir to load from.
            **kwargs: Extra options for the concrete implementation.

        Returns:
            The wrapped model instance.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class PeftTuner(Tuner):
    """Tuner base whose checkpoints are saved and loaded through PEFT."""

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save the checkpoint by delegating to ``model.save_pretrained``.

        Args:
            model: The wrapped model; it must expose ``save_pretrained``.
            save_directory: The directory to write the checkpoint to.
            state_dict: Optional state dict to save. It is forwarded to
                ``model.save_pretrained`` only when provided, so the model's
                own default applies otherwise.
            safe_serialization: Whether to use safetensors.
            **kwargs: Passed through to ``model.save_pretrained`` unchanged.
        """
        if state_dict is not None:
            # A caller-supplied state dict used to be dropped here; forward it
            # so the weights the caller asked for are the ones written.
            kwargs['state_dict'] = state_dict
        model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a checkpoint onto ``model`` via ``PeftModel.from_pretrained``.

        Args:
            model: The original model instance.
            model_id: The model id or ckpt_dir to load from.
            **kwargs: Passed through to ``PeftModel.from_pretrained``.

        Returns:
            The value returned by ``PeftModel.from_pretrained``.
        """
        return PeftModel.from_pretrained(model, model_id, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
class IA3(PeftTuner):
    """Tuner that wraps a model with a PEFT ``IA3Config``."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Build an IA3 config for ``model`` and wrap the model with it.

        Args:
            args: The training arguments (not read here).
            model: The model instance; ``model.model_meta.model_arch`` is
                used as the key into ``MODEL_ARCH_MAPPING``.

        Returns:
            The result of ``get_peft_model`` for the IA3 config.
        """
        arch_keys: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
        linear_names = find_all_linears(model)
        # Take the part of the mlp key after the '{}.' placeholder and match
        # it anywhere inside a module name.
        mlp_suffix = arch_keys.mlp.split('{}.')[1]
        feedforward_pattern = f'.*{mlp_suffix}.*'
        config = IA3Config(target_modules=linear_names, feedforward_modules=feedforward_pattern)
        return get_peft_model(model, config)
|
|
|
|
|
|
|
|
class DummyTuner(PeftTuner):
    """Tuner whose ``prepare_model`` leaves the model untouched."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` as-is, without wrapping it.

        Args:
            args: The training arguments (not read here).
            model: The model instance.

        Returns:
            The same ``model`` object that was passed in.
        """
        return model
|
|
|
|
|
|
|
|
|
|
|
# Maps each tuner name to the class that implements it.
extra_tuners = {
    'ia3': IA3,
    'dummy': DummyTuner,
}
|
|
|