import json
from dataclasses import dataclass
from types import MethodType
from typing import List, Literal, Optional

import torch
from torch import nn

from swift.utils import get_logger, patch_getattr
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class ReftConfig(SwiftConfig):
    """
    Train a model with ReFT.
    Paper: https://arxiv.org/pdf/2404.03592

    Args:
        model_type (`Optional[str]`): The model_type used to look up the down_proj/layers keys.
        layer_key (`Optional[str]`): Manually specify the layer key, for example `language_model.layers`.
        layers (`Optional[List[int]]`): The indices of the layers to inject into; defaults to all layers.
        r (`int`): The rank of ReFT.
        intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention',
            'ConsreftIntervention', 'LobireftIntervention',
            'DireftIntervention', 'NodireftIntervention']`): The intervention type.
            Defaults to `LoreftIntervention`.
        args (`Optional[str]`): Extra ReFT intervention arguments in JSON-string format.
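
    Example (a minimal sketch; the extra key in `args` is an illustrative assumption,
    not an exhaustive list of what the chosen pyreft intervention accepts):
        >>> config = ReftConfig(layers=[2, 4], r=8, args='{"dropout": 0.1}')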
    """

    model_type: Optional[str] = None
    layer_key: Optional[str] = None
    layers: Optional[List[int]] = None
    r: int = 4
    intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
                               'LobireftIntervention', 'DireftIntervention',
                               'NodireftIntervention'] = 'LoreftIntervention'
    args: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.REFT
        if self.args:
            self.args = json.loads(self.args)
        else:
            self.args = {}


class Reft(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str):
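        """Wrap `model` with pyreft and return it as a `SwiftOutput`.

        One intervention of `config.intervention_type` is attached to the output of each
        selected layer, and `forward`/`generate` are patched so the wrapped model keeps
        the usual HF-style `input_ids`/`attention_mask` interface.
        """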
        from swift.utils.import_utils import is_pyreft_available
        if not is_pyreft_available():
            raise ImportError('Please install pyreft before using ReFT: `pip install pyreft`')

        import pyreft
        from pyreft import (ReftModel, NoreftIntervention, LoreftIntervention, ConsreftIntervention,
                            LobireftIntervention, DireftIntervention, NodireftIntervention)
        from pyreft.interventions import LowRankRotateLayer

        intervention_mapping = {
            'NoreftIntervention': NoreftIntervention,
            'LoreftIntervention': LoreftIntervention,
            'ConsreftIntervention': ConsreftIntervention,
            'LobireftIntervention': LobireftIntervention,
            'DireftIntervention': DireftIntervention,
            'NodireftIntervention': NodireftIntervention,
        }

        # Let attribute lookups on the ReftModel wrapper fall through to the wrapped `model`.
        patch_getattr(ReftModel, 'model')
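
        # pyreft's intervention modules may live on a different device than the inputs;
        # patch their forwards to first move themselves to the input tensor's device.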
        def forward(self, x):
            self.to(x.device)
            return self.forward_origin(x)

        def forward2(self, base, source=None, subspaces=None):
            self.to(base.device)
            return self.forward_origin(base, source, subspaces)

        if not hasattr(LowRankRotateLayer, 'forward_origin'):
            LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward
            LowRankRotateLayer.forward = forward
            NoreftIntervention.forward_origin = NoreftIntervention.forward
            NoreftIntervention.forward = forward2
            LoreftIntervention.forward_origin = LoreftIntervention.forward
            LoreftIntervention.forward = forward2
            ConsreftIntervention.forward_origin = ConsreftIntervention.forward
            ConsreftIntervention.forward = forward2
            LobireftIntervention.forward_origin = LobireftIntervention.forward
            LobireftIntervention.forward = forward2
            DireftIntervention.forward_origin = DireftIntervention.forward
            DireftIntervention.forward = forward2
            NodireftIntervention.forward_origin = NodireftIntervention.forward
            NodireftIntervention.forward = forward2

        module_list_key = config.layer_key
        if module_list_key is None:
            model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
            module_list_key = model_key_mapping.module_list
        logger.info(f'Applying Reft to module: {module_list_key}')
        module_list: nn.ModuleList = model.get_submodule(module_list_key)
        # Build one intervention per selected layer, hooked on the layer's output.
        representations = []
        for idx, _ in enumerate(module_list):
            if config.layers and idx not in config.layers:
                continue
            intervention_config = {
                'layer': idx,
                'component': module_list_key + f'[{idx}].output',
                'low_rank_dimension': config.r,
                'intervention': intervention_mapping[config.intervention_type](
                    embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args)
            }
            representations.append(intervention_config)

        reft_config = pyreft.ReftConfig(representations=representations)
        reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
        # Keep pyreft's config reachable while exposing the base model's config,
        # which downstream HF utilities expect to find at `model.config`.
        reft_model.reft_config = reft_model.config
        reft_model.config = reft_model.model.config

        def _pre_forward_hook(module, args, kwargs):
            # Already in pyreft's calling convention; pass through unchanged.
            if 'base' in kwargs:
                return args, kwargs

            if 'input_ids' not in kwargs:
                raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.')

            # Map HF-style kwargs to pyreft's `sources->base` calling convention.
            unit_locations = None
            if 'intervention_locations' in kwargs:
                if kwargs['intervention_locations'].dim() == 3:
                    unit_locations = {
                        'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
                    }
                else:
                    unit_locations = {'sources->base': (None, 0)}
            kwargs = {
                'base': {
                    'input_ids': kwargs['input_ids'],
                    'attention_mask': kwargs['attention_mask']
                },
                'unit_locations': unit_locations,
                'labels': kwargs.get('labels'),
                'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            return args, kwargs

        def _post_forward_hook(module, args, kwargs, outputs):
            # pyreft returns a (base_outputs, intervened_outputs) pair; keep the intervened outputs.
            return outputs[1]

        def _generate(self, **kwargs):
            # Same kwargs translation as `_pre_forward_hook`, for generation. Pop the
            # translated keys so they are not forwarded to the original generate.
            intervention_locations = kwargs.pop('intervention_locations', None)
            unit_locations = None
            if intervention_locations is not None:
                if intervention_locations.dim() == 3:
                    unit_locations = {
                        'sources->base': (None, intervention_locations.permute(1, 0, 2).tolist())
                    }
                else:
                    unit_locations = {'sources->base': (None, 0)}

            _kwargs = {
                'base': {
                    'input_ids': kwargs.pop('input_ids'),
                    'attention_mask': kwargs.pop('attention_mask')
                },
                'unit_locations': unit_locations,
                'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            # Forward any remaining kwargs (e.g. generation settings) untouched.
            _kwargs = {**_kwargs, **kwargs}
            return self.generate_origin(**_kwargs)[1]

        reft_model.generate_origin = reft_model.generate
        reft_model.generate = MethodType(_generate, reft_model)
        reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
        reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True)

        def save_callback(swift_model, model_dir, adapter_name):
            reft_model.save_intervention(save_directory=model_dir, include_model=False)

        def mark_trainable_callback(model):
            # No-op: pyreft already freezes the base model and marks only the interventions trainable.
            return

        def load_callback(swift_model, model_dir, adapter_name):
            reft_model.load_intervention(model_dir, include_model=False)

        return SwiftOutput(
            model=reft_model,
            config=config,
            mark_trainable_callback=mark_trainable_callback,
            save_callback=save_callback,
            load_callback=load_callback)

    @staticmethod
    def has_additional_modules():
        return True

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
        assert activate, 'ReFT does not support deactivation'