Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional
import torch
from peft import IA3Config, PeftModel, get_peft_model
from swift.llm import MODEL_ARCH_MAPPING, ModelKeys
from swift.utils import find_all_linears
class Tuner:
    """Interface for pluggable tuners.

    Subclasses override the static hooks below. Every default implementation
    here raises ``NotImplementedError``.
    """

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Prepare a new model with a tuner.

        Args:
            args: The training arguments.
            model: The model instance.

        Returns:
            The wrapped model.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save the tuner weights (invoked when ``save_steps`` is reached).

        Args:
            model: The model wrapped by `prepare_model`.
            save_directory: The directory to save into.
            state_dict: Optional state dict supplied by the caller; how it is
                used is up to the subclass.
            safe_serialization: Whether to use safetensors.
            **kwargs: Extra keyword arguments for the subclass implementation.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a tuner checkpoint onto a model.

        Args:
            model: The original model instance.
            model_id: The model id or ckpt_dir to load.
            **kwargs: Extra keyword arguments for the subclass implementation.

        Returns:
            The wrapped model instance.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError
class PeftTuner(Tuner):
    """Tuner base whose save/load hooks delegate to PEFT-style model APIs.

    Saving calls ``model.save_pretrained`` and loading calls
    ``PeftModel.from_pretrained``; subclasses only need to supply
    ``prepare_model``.
    """

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save the model by calling its own ``save_pretrained`` method.

        Args:
            model: The model wrapped by `prepare_model`; must expose a
                ``save_pretrained`` method.
            save_directory: The directory to save into.
            state_dict: Optional state dict to save instead of the one the
                model would produce itself. Passed through as the
                ``state_dict`` keyword argument when it is not ``None``.
            safe_serialization: Whether to use safetensors.
            **kwargs: Forwarded unchanged to ``model.save_pretrained``.
        """
        if state_dict is not None:
            # Forward only when the caller supplied one, so the default call
            # (no state_dict) reaches ``model.save_pretrained`` exactly as before.
            kwargs['state_dict'] = state_dict
        model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a checkpoint via ``PeftModel.from_pretrained``.

        Args:
            model: The original model instance.
            model_id: The model id or ckpt_dir to load.
            **kwargs: Forwarded unchanged to ``PeftModel.from_pretrained``.

        Returns:
            The object returned by ``PeftModel.from_pretrained``.
        """
        return PeftModel.from_pretrained(model, model_id, **kwargs)
# A simple example tuner: IA3
class IA3(PeftTuner):
    """Example tuner that wraps a model with a PEFT ``IA3Config``."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Wrap ``model`` with IA3 adapters.

        The target modules are whatever ``find_all_linears(model)`` returns.
        The feed-forward pattern is built from the architecture's ``mlp`` key:
        the text after the ``'{}.'`` placeholder, surrounded by ``'.*'``.

        Args:
            args: The training arguments (not used here).
            model: The model instance; ``model.model_meta.model_arch`` must be
                a key of ``MODEL_ARCH_MAPPING``.

        Returns:
            The model returned by ``get_peft_model``.

        Raises:
            ValueError: If the architecture's ``mlp`` key is empty or does not
                contain the ``'{}.'`` placeholder.
        """
        model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
        mlp_pattern = model_arch.mlp
        pieces = mlp_pattern.split('{}.') if mlp_pattern else []
        if len(pieces) < 2:
            # Without this guard the lookup below fails with a bare IndexError
            # that does not say which architecture key is at fault.
            raise ValueError(f"Cannot derive IA3 feedforward modules: mlp key {mlp_pattern!r} of model_arch "
                             f"{model.model_meta.model_arch!r} does not contain the '{{}}.' placeholder.")
        feedforward_pattern = '.*' + pieces[1] + '.*'
        ia3_config = IA3Config(target_modules=find_all_linears(model), feedforward_modules=feedforward_pattern)
        return get_peft_model(model, ia3_config)
class DummyTuner(PeftTuner):
    """No-op tuner: hands the model back without wrapping it."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` unchanged.

        Args:
            args: The training arguments (not used here).
            model: The model instance.

        Returns:
            The very same ``model`` object that was passed in.
        """
        return model
# Register your own tuner here, then pass --train_type <key> to use it
extra_tuners = {
    'ia3': IA3,
    'dummy': DummyTuner,
}