File size: 2,682 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional
import torch
from peft import IA3Config, PeftModel, get_peft_model
from swift.llm import MODEL_ARCH_MAPPING, ModelKeys
from swift.utils import find_all_linears
class Tuner:
    """Interface for a custom tuner.

    All three hooks are static and raise ``NotImplementedError`` here;
    concrete tuners override the ones they need.
    """

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Wrap ``model`` with this tuner.

        Args:
            args: The training arguments.
            model: The model instance to wrap.

        Returns:
            The wrapped model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save the tuner; invoked when ``save_steps`` is reached.

        Args:
            model: The model previously wrapped by ``prepare_model``.
            save_directory: The directory to save into.
            state_dict: Optional state dict supplied by the caller
                (presumably the weights to persist -- confirm against the
                trainer that calls this hook).
            safe_serialization: Whether to use safetensors.
            **kwargs: Extra options for the concrete implementation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a saved tuner from ``model_id`` onto ``model``.

        Args:
            model: The original (unwrapped) model instance.
            model_id: The model id or checkpoint directory to load.
            **kwargs: Extra options for the concrete implementation.

        Returns:
            The wrapped model instance.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class PeftTuner(Tuner):
    """Tuner base for models that expose the PEFT save/load API.

    Saving is delegated to the model's own ``save_pretrained`` and loading to
    ``PeftModel.from_pretrained``. ``prepare_model`` is not defined here, so
    subclasses must provide it.
    """

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save ``model`` into ``save_directory`` via ``model.save_pretrained``.

        Args:
            model: The wrapped model; it must provide ``save_pretrained``.
            save_directory: The directory to save into.
            state_dict: Optional state dict to save. It is forwarded to
                ``model.save_pretrained`` only when it is not ``None``.
            safe_serialization: Whether to use safetensors.
            **kwargs: Passed through to ``model.save_pretrained`` unchanged.
        """
        if state_dict is not None:
            # Previously this argument was silently dropped. Forward it only
            # when the caller actually supplied one, so the no-state-dict call
            # is exactly what it used to be.
            kwargs['state_dict'] = state_dict
        model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a saved adapter onto ``model`` with ``PeftModel.from_pretrained``.

        Args:
            model: The original (unwrapped) model instance.
            model_id: The model id or checkpoint directory to load.
            **kwargs: Passed through to ``PeftModel.from_pretrained``.

        Returns:
            Whatever ``PeftModel.from_pretrained`` returns for these arguments.
        """
        return PeftModel.from_pretrained(model, model_id, **kwargs)
# Here is a simple example of a custom tuner, using IA3
class IA3(PeftTuner):
    """Example tuner that wraps the model with a PEFT IA3 configuration."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Wrap ``model`` with IA3 via ``get_peft_model``.

        Args:
            args: The training arguments (not used by this implementation).
            model: The model instance; ``model.model_meta.model_arch`` is used
                to look up its entry in ``MODEL_ARCH_MAPPING``.

        Returns:
            The result of ``get_peft_model(model, ia3_config)``.
        """
        arch_keys: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
        linear_names = find_all_linears(model)
        # Take the segment right after the first '{}.' in the MLP key
        # (presumably '{}' is the layer-index placeholder -- TODO confirm).
        mlp_suffix = arch_keys.mlp.split('{}.')[1]
        feedforward_pattern = '.*' + mlp_suffix + '.*'
        config = IA3Config(target_modules=linear_names, feedforward_modules=feedforward_pattern)
        return get_peft_model(model, config)
class DummyTuner(PeftTuner):
    """Tuner whose ``prepare_model`` leaves the model unchanged."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` as-is.

        Args:
            args: The training arguments (not used).
            model: The model instance.

        Returns:
            The very same ``model`` object that was passed in.
        """
        # Nothing to wrap: hand the input straight back.
        return model
# Register your own tuner here, then pass its key via `--train_type xxx` to use it.
extra_tuners = {
    'ia3': IA3,
    'dummy': DummyTuner,
}
|