# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from types import MethodType
from typing import List, Literal, Optional
import json
import torch
from torch import nn
from swift.utils import get_logger, patch_getattr
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput
logger = get_logger()
@dataclass
class ReftConfig(SwiftConfig):
"""
Train a model with Reft.
Paper: https://arxiv.org/pdf/2404.03592
Args:
model_type(`Optional[str]`): The model_type to find down_proj/layers.
layer_key(`Optional[str]`): Manually specify the layer key, for example `language_model.layers`.
layers (`Optional[List[int]]`): The layer number to inject.
r(`int`): The rank of Reft.
intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention',
'ConsreftIntervention', 'LobireftIntervention',
'DireftIntervention', 'NodireftIntervention']`): The intervention type,
default LoreftIntervention
args (`Optional[str]`): Other reft_args in json-string format
"""
model_type: Optional[str] = None
layer_key: Optional[str] = None
layers: Optional[List[int]] = None
r: int = 4
intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
'LobireftIntervention', 'DireftIntervention',
'NodireftIntervention'] = 'LoreftIntervention'
args: Optional[str] = None
def __post_init__(self):
from .mapping import SwiftTuners
self.swift_type = SwiftTuners.REFT
if self.args:
self.args = json.loads(self.args)
else:
self.args = {}
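# Example (a minimal sketch, assuming `from swift import Swift` and an already-loaded
# `model`; the layer indices are purely illustrative):
#   config = ReftConfig(layers=[8, 16, 24], r=4, intervention_type='LoreftIntervention')
#   model = Swift.prepare_model(model, config)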
class Reft(SwiftAdapter):
@staticmethod
def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str):
from swift.utils.import_utils import is_pyreft_available
if not is_pyreft_available():
            raise ImportError('Please install pyreft before using ReFT: `pip install pyreft`')
import pyreft
from pyreft import ReftModel
from pyreft.interventions import LowRankRotateLayer
from pyreft import (
NoreftIntervention,
LoreftIntervention,
ConsreftIntervention,
LobireftIntervention,
DireftIntervention,
NodireftIntervention,
)
intervention_mapping = {
'NoreftIntervention': NoreftIntervention,
'LoreftIntervention': LoreftIntervention,
'ConsreftIntervention': ConsreftIntervention,
'LobireftIntervention': LobireftIntervention,
'DireftIntervention': DireftIntervention,
'NodireftIntervention': NodireftIntervention,
}
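        # Let attribute lookups on the ReftModel wrapper fall through to the wrapped
        # `model` attribute, so code written against the base model keeps working.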
patch_getattr(ReftModel, 'model')
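        # The patched forwards below move each intervention module onto the device of
        # the incoming tensor before dispatching to the original forward, so ReFT also
        # works when the base model is sharded across devices.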
def forward(self, x):
self.to(x.device)
return self.forward_origin(x)
def forward2(self, base, source=None, subspaces=None):
self.to(base.device)
return self.forward_origin(base, source, subspaces)
        if not hasattr(LowRankRotateLayer, 'forward_origin'):
            # Patch each class only once: keep the original forward and replace it
            # with the device-aligning wrapper defined above.
            LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward
            LowRankRotateLayer.forward = forward
            for intervention_cls in intervention_mapping.values():
                intervention_cls.forward_origin = intervention_cls.forward
                intervention_cls.forward = forward2
module_list_key = config.layer_key
if module_list_key is None:
model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
module_list_key = model_key_mapping.module_list
logger.info(f'Applying Reft to module: {module_list_key}')
module_list: nn.ModuleList = model.get_submodule(module_list_key)
representations = []
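        # Build one intervention spec per selected layer; pyreft attaches each one
        # to the output of `<module_list_key>[idx]`.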
for idx, layer in enumerate(module_list):
if config.layers and idx not in config.layers:
continue
            intervention_config = {
                'layer': idx,
                'component': module_list_key + f'[{idx}].output',
                'low_rank_dimension': config.r,
                'intervention': intervention_mapping[config.intervention_type](
                    embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args),
            }
representations.append(intervention_config)
reft_config = pyreft.ReftConfig(representations=representations)
reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
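        # pyreft stores its own config under `.config`; stash it and expose the
        # underlying transformers config instead, which downstream Swift code expects.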
reft_model.reft_config = reft_model.config
reft_model.config = reft_model.model.config
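        # Translate transformers-style forward kwargs (`input_ids`, `attention_mask`, ...)
        # into the signature of pyreft's IntervenableModel.forward (`base`, `unit_locations`, ...).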
def _pre_forward_hook(module, args, kwargs):
if 'base' in kwargs:
return args, kwargs
            if 'input_ids' not in kwargs:
                raise ValueError('Inputs do not contain `input_ids`; the model may not support ReFT.')
# run intervened forward pass
unit_locations = None
if 'intervention_locations' in kwargs:
if kwargs['intervention_locations'].dim() == 3:
unit_locations = {
'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
}
else:
# this is dummy for lora only baseline
unit_locations = {'sources->base': (None, 0)}
            # `attention_mask` and `labels` may be absent (e.g. at inference), so use `.get`.
            kwargs = {
                'base': {
                    'input_ids': kwargs['input_ids'],
                    'attention_mask': kwargs.get('attention_mask'),
                },
                'unit_locations': unit_locations,
                'labels': kwargs.get('labels'),
                'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
return args, kwargs
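        # pyreft's forward returns a (base_outputs, intervened_outputs) tuple;
        # keep only the intervened outputs so callers see a normal model output.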
def _post_forward_hook(module, args, kwargs, outputs):
return outputs[1]
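        # Wrap `generate` the same way: remap kwargs to the pyreft signature and
        # return only the intervened generation result.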
def _generate(self, **kwargs):
# run intervened forward pass
unit_locations = None
if 'intervention_locations' in kwargs:
if kwargs['intervention_locations'].dim() == 3:
unit_locations = {
'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
}
else:
# this is dummy for lora only baseline
unit_locations = {'sources->base': (None, 0)}
            _kwargs = {
                'base': {
                    'input_ids': kwargs.pop('input_ids'),
                    # `attention_mask` may be absent, so pop with a default.
                    'attention_mask': kwargs.pop('attention_mask', None),
                },
                'unit_locations': unit_locations,
                'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
_kwargs = {**_kwargs, **kwargs}
return self.generate_origin(**_kwargs)[1]
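        # Install the wrappers: patch `generate` and register the forward hooks that
        # adapt inputs and outputs around pyreft's forward.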
reft_model.generate_origin = reft_model.generate
reft_model.generate = MethodType(_generate, reft_model)
reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True)
        def save_callback(swift_model, model_dir, adapter_name):
            # Persist only the intervention weights; the frozen base model is not saved.
            reft_model.save_intervention(save_directory=model_dir, include_model=False)

        def mark_trainable_callback(model):
            # No-op: pyreft manages parameter trainability itself
            # (base model frozen, interventions trainable).
            return

        def load_callback(swift_model, model_dir, adapter_name):
            reft_model.load_intervention(model_dir, include_model=False)
return SwiftOutput(
model=reft_model,
config=config,
mark_trainable_callback=mark_trainable_callback,
save_callback=save_callback,
load_callback=load_callback)
@staticmethod
def has_additional_modules():
return True
@staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
        assert activate, 'ReFT does not support deactivation'