|
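# WISE-style lifelong model editing. Edits are trained into a copy of one target
# layer's weight (a "side memory"); at inference time the adapter routes between
# the original weights, the current side memory, and previously merged memories
# based on activation distance.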
import copy |
|
import random |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
from .utils import parent_module, brackets_to_periods, EarlyStopMeter, EditingMeanAct |
|
import transformers |
|
import numpy as np |
|
from torch import Tensor |
|
from torch.nn import CrossEntropyLoss |
|
from transformers.activations import ACT2FN |
|
from .merge import slerp, GTA, linear |
|
import torch.nn as nn |
|
import gc |
|
|
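# Registry of weight-merging strategies from .merge; the entry named by
# config.merge_alg is used in WISEAdapter.merge_weight.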
|
merge_dict = { |
|
'slerp': slerp(), |
|
'ties': GTA('magnitude', 'sum', normalize=True), |
|
'magnitude_norm': GTA('magnitude', None, normalize=True), |
|
'magnitude': GTA('magnitude', None, normalize=False), |
|
'sign': GTA(None, 'sum', normalize=True), |
|
'dare_ties': GTA('rescaled_random', 'sum'), |
|
'dare_linear': GTA('random', None), |
|
'linear': linear() |
|
} |
|
|
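# Module-level replay buffers: edit_history stores the inputs of the current
# merge group, and merge_group_edit_history archives completed groups so they
# can be replayed during later edits (see the replay branch in WISE.edit).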
|
edit_history = [] |
|
merge_group_edit_history = [] |
|
|
|
def euc(query, key, config, act_mask=None, infer=False): |
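    """Activation distance between two layer outputs.

    Applies the model's activation function to both tensors and returns the L2
    distance, averaged over tokens (or a masked average when act_mask is given).
    At inference time, when the token dimension exceeds 100, only the largest
    per-row token distance is kept and averaged over the batch.
    """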
|
|
|
|
|
act_fn = ACT2FN[config.hidden_act] |
|
l2_norm = torch.norm(act_fn(key) - act_fn(query), dim=-1) |
|
if infer and l2_norm.size(1) > 100: |
|
topk = torch.topk(l2_norm, k=1, largest=True) |
|
return topk.values.mean() |
|
|
|
if act_mask is not None: |
|
return torch.sum(l2_norm * act_mask, dim=1) / torch.sum(act_mask, dim=1) |
|
else: |
|
return torch.mean(l2_norm, dim=-1) |
|
|
|
|
|
class WISE(torch.nn.Module): |
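    """Wraps `model`, replacing the layer named in config.inner_params with a
    WISEAdapter that holds a tunable copy of the layer weight (the side
    memory), and drives the editing / merging life cycle."""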
|
def __init__(self, config, model, device): |
|
super(WISE, self).__init__() |
|
self.config = config |
|
self.model = model |
|
|
if hasattr(self.model.config, 'hidden_act'): |
|
self.config.hidden_act = self.model.config.hidden_act |
|
elif hasattr(self.model.config, 'activation_function'): |
|
self.config.hidden_act = self.model.config.activation_function |
|
|
|
layer = config.inner_params[0] |
|
self.device = device |
|
self.adapter_layer = None |
|
self.original_layer = None |
|
|
|
|
|
suffixes = [".weight", ".bias"] |
|
self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer |
|
|
|
for n, p in self.model.named_parameters(): |
|
p.requires_grad = False |
|
|
|
        # GPT-2 stores its linear weights as transposed Conv1D modules, so the
        # adapter must know whether to apply them with addmm or F.linear.
        conv1D = isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel)
|
|
|
|
|
self.edit_module = parent_module(self.model, brackets_to_periods(self.layer)) |
|
self.layer_name = self.layer.rsplit(".", 1)[-1] |
|
adapter_layer = getattr(self.edit_module, self.layer_name) |
|
|
|
if type(adapter_layer) is not WISEAdapter: |
|
setattr(self.edit_module, self.layer_name, WISEAdapter(config, adapter_layer, conv1D=conv1D)) |
|
self.original_layer = copy.deepcopy(adapter_layer) |
|
print(f"New weights successfully inserted into {layer}") |
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
|
|
|
def __call__(self, **kwargs): |
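        """Forward through the wrapped model; outside of editing (and when
        retrieval routing is disabled), first flush a pending side weight into
        memory and merge it."""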
|
if not self.config.retrieve: |
|
if hasattr(self.get_adapter_layer(), 'editing') and not self.get_adapter_layer().editing: |
|
|
|
if not self.get_adapter_layer().original_layer.weight.equal(self.get_adapter_layer().new_weight) and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq: |
|
self.get_adapter_layer().memory_weight.append(self.get_adapter_layer().new_weight) |
|
if len(self.get_adapter_layer().memory_weight) > 0 and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq: |
|
                    print(f'Length of side memory: {len(self.get_adapter_layer().memory_weight)}')
|
self.get_adapter_layer().merge_weight() |
|
return self.model(**kwargs) |
|
|
|
def reset_layer(self): |
|
layer = getattr(self.edit_module, self.layer_name) |
|
del layer |
|
setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer) |
|
|
|
def get_adapter_layer(self): |
|
adapter_layer = getattr(self.edit_module, self.layer_name) |
|
        assert type(adapter_layer) is WISEAdapter, 'Adapter layer was not added correctly.'
|
return adapter_layer |
|
|
|
|
|
def generate(self, *args, **kwargs): |
|
setattr(eval(f"self.model.{self.layer}"), "key_id", -1) |
|
return self.model.generate(*args, **kwargs) |
|
|
|
def edit(self, config, tokens, act_mask=None, deact_mask=None): |
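        """Train the side weight on a single edit.

        Runs up to config.n_iter SGD steps on the adapter's new_weight, combining
        a token-level fine-tuning loss with activation-margin losses (plus replay
        losses when retrieval with replay is enabled). Every save_freq edits the
        side weight is archived, and every merge_freq edits the archive is merged.
        """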
|
|
|
global edit_history |
|
global merge_group_edit_history |
|
        edit_history.append([{k: v.to('cpu') for k, v in tokens.items()}, False])
|
|
|
last_prompt_token_loc = (tokens["labels"] == -100).sum(dim=-1) - 1 |
|
|
|
setattr(eval(f"self.model.{self.layer}"), "training", True) |
|
setattr(eval(f"self.model.{self.layer}"), "editing", True) |
|
self.get_adapter_layer().set_parameter_tunable() |
|
if getattr(eval(f"self.model.{self.layer}"), "editing_total_cnt") % self.config.save_freq == 0: |
|
self.get_adapter_layer().generate_activation_mask(self.config.mask_ratio) |
|
|
|
|
|
loss_meter = EarlyStopMeter() |
|
for i in range(config.n_iter): |
|
|
|
if i == 0: |
|
|
|
optimizer = torch.optim.SGD([self.get_adapter_layer().new_weight], config.edit_lr, weight_decay=1e-5) |
|
|
|
ft_loss = self.__cal_ft_loss(tokens, last_prompt_token_loc) |
|
|
|
act_loss = self.__cal_activation_loss(self.get_adapter_layer().original_layer_output, self.get_adapter_layer().new_weight_layer_output, |
|
config=config, act_mask=act_mask, deact_mask=deact_mask) |
|
loss = ft_loss + act_loss.to(ft_loss.device) |
|
|
|
if loss_meter.stop(): |
|
self.get_adapter_layer().save_editing_activation() |
|
break |
|
if i == config.n_iter - 1: |
|
self.get_adapter_layer().save_editing_activation() |
|
|
|
if self.config.retrieve and self.get_adapter_layer().merge_cnt > 0 and self.config.replay: |
|
memory_loss = [] |
|
                for group_history in merge_group_edit_history:
                    idx = 0
                    while True:
                        memo_input, is_used = group_history[idx]
                        if not is_used:
                            group_history[idx][1] = True
                            break
                        idx += 1
                        if idx == len(group_history):
                            # every edit in this group has been replayed once; reset the flags
                            for m in range(len(group_history)):
                                group_history[m][1] = False
                            idx = 0
|
|
|
                    memo_input = {k: v.to(self.config.device) for k, v in memo_input.items()}
|
self.model(**memo_input) |
|
|
|
memory_act_loss = self.__cal_memory_neg_activation_loss(self.get_adapter_layer().original_layer_output, |
|
self.get_adapter_layer().new_weight_layer_output, config=config, |
|
act_mask=act_mask, deact_mask=deact_mask) |
|
memory_loss.append(memory_act_loss.to(ft_loss.device)) |
|
del memo_input |
|
neg_memo_loss = torch.stack(memory_loss).mean() |
|
loss += neg_memo_loss |
|
if len(edit_history) > 0: |
|
memo_input = random.choice(edit_history)[0] |
|
                    memo_input = {k: v.to(self.config.device) for k, v in memo_input.items()}
|
self.model(**memo_input) |
|
|
|
pos_memo_loss = self.__cal_memory_pos_activation_loss(self.get_adapter_layer().original_layer_output, |
|
self.get_adapter_layer().new_weight_layer_output, config=config, |
|
act_mask=act_mask, deact_mask=deact_mask) |
|
del memo_input |
|
loss += pos_memo_loss.to(ft_loss.device) |
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
loss.backward() |
|
self.get_adapter_layer().mask_new_weight_gradient() |
|
|
|
if self.config.retrieve and self.get_adapter_layer().merge_cnt > 0 and self.config.replay: |
|
print( |
|
f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)} + {np.round(neg_memo_loss.item(), 3)} + {np.round(pos_memo_loss.item(), 3)}" |
|
) |
|
else: |
|
print( |
|
f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)}" |
|
) |
|
|
|
optimizer.step() |
|
loss_meter.update(loss.item()) |
|
|
|
            if isinstance(self.config.norm_constraint, float):
|
self.__norm_constraint(self.config.norm_constraint) |
|
|
|
|
|
setattr(eval(f"self.model.{self.layer}"), "editing", False) |
|
setattr(eval(f"self.model.{self.layer}"), "training", False) |
|
|
|
editing_total_cnt = getattr(eval(f"self.model.{self.layer}"), "editing_total_cnt") + 1 |
|
setattr(eval(f"self.model.{self.layer}"), "editing_total_cnt", editing_total_cnt) |
|
|
|
if self.config.save_freq is not None and editing_total_cnt % self.config.save_freq == 0: |
|
self.get_adapter_layer().save_weight() |
|
            print('Adding new weight to memory...')
|
if editing_total_cnt % self.config.merge_freq == 0: |
|
|
|
merge_group_edit_history.append(edit_history) |
|
edit_history = [] |
|
|
|
|
|
self.get_adapter_layer().merge_weight() |
|
                print(f'Merging (new, original) weights with {self.config.merge_alg}...')
|
|
|
def __norm_constraint(self, norm_constraint): |
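        """Clamp the side weight to stay within norm_constraint of the original weight, element-wise."""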
|
new_weight = self.get_adapter_layer().new_weight |
|
original_weight = self.get_adapter_layer().weight |
|
with torch.no_grad(): |
|
new_weight[...] = torch.clamp( |
|
new_weight, min=original_weight - norm_constraint, max=original_weight + norm_constraint |
|
) |
|
|
|
def __cal_ft_loss(self, tokens, last_prompt_token_loc): |
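        """Cross-entropy loss on the response tokens of the editing samples.

        The final k rows of the batch (used as out-of-scope examples by the
        activation loss) are excluded, and labels before the last prompt token
        are masked to -100 so that only the response tokens contribute.
        """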
|
k = 1 |
|
bs = tokens["input_ids"].shape[0] - k |
|
logits = self.model(**tokens).logits |
|
shift_logits = logits[:-k, :-1, :].contiguous() |
|
shift_labels = tokens['labels'][:-k, 1:].contiguous() |
|
|
|
|
|
|
|
|
|
label_mask = torch.zeros_like(shift_labels, dtype=torch.bool) |
|
|
|
for i, col_index in enumerate(last_prompt_token_loc[:-k]): |
|
label_mask[i, col_index-1:] = True |
|
|
|
shift_labels[~label_mask] = -100 |
|
|
|
log_probs = -nn.functional.log_softmax(shift_logits, dim=-1) |
|
|
|
if shift_labels.dim() == log_probs.dim() - 1: |
|
shift_labels = shift_labels.unsqueeze(-1) |
|
|
|
padding_mask = shift_labels.eq(-100) |
|
|
|
|
|
|
|
shift_labels = torch.clamp(shift_labels, min=0) |
|
|
|
nll_loss = log_probs.gather(dim=-1, index=shift_labels) |
|
nll_loss.masked_fill_(padding_mask, 0.0) |
|
|
|
num_active_elements = padding_mask.numel() - padding_mask.long().sum() |
|
nll_loss = nll_loss.sum() / num_active_elements |
|
|
|
return nll_loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __cal_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None, |
|
deact_mask=None): |
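        """Hinge losses on activation distances between the original and side weights.

        Pushes the in-scope distance above config.beta, keeps the out-of-scope
        distance below config.alpha, and enforces a margin of config.gamma between
        them; each term averages only its violating entries.
        """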
|
k = 1 |
|
if act_mask is not None: |
|
in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config, |
|
act_mask=act_mask) |
|
out_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config, |
|
act_mask=deact_mask) |
|
else: |
|
in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config) |
|
out_scope_dist = euc(original_layer_output[-k:, ...], new_weight_layer_output[-k:, ...], config) |
|
|
|
loss = out_scope_dist.view(-1,1) - in_scope_dist + config.gamma |
|
loss2 = out_scope_dist - config.alpha |
|
loss3 = config.beta - in_scope_dist |
|
        loss3 = loss3[loss3 > 0].mean() if (loss3 > 0).any() else torch.tensor(0.).to(original_layer_output.device)
        loss2 = loss2[loss2 > 0].mean() if (loss2 > 0).any() else torch.tensor(0.).to(original_layer_output.device)
        loss = loss[loss > 0].mean() if (loss > 0).any() else torch.tensor(0.).to(original_layer_output.device)
|
return loss + loss2 + loss3 |
|
|
|
def __cal_memory_pos_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None, |
|
deact_mask=None): |
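        """Replay loss for not-yet-merged edits: keep their activation distance above a fixed margin of 20."""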
|
k = 1 |
|
in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config) |
|
loss4 = 20 - in_scope_dist |
|
|
|
        return loss4[loss4 > 0].mean() if (loss4 > 0).any() else torch.tensor(0.)
|
|
|
def __cal_memory_neg_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None, |
|
deact_mask=None): |
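        """Replay loss for already-merged edits: push their activation distance on the current side weight below a fixed margin of 5."""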
|
k = 1 |
|
in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config) |
|
loss4 = in_scope_dist - 5 |
|
|
|
        return loss4[loss4 > 0].mean() if (loss4 > 0).any() else torch.tensor(0.)
|
|
|
class WISEAdapter(torch.nn.Module): |
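    """Drop-in replacement for the edited layer.

    Holds the frozen original layer, a tunable copy of its weight (new_weight,
    the side memory) and, optionally, a bank of saved / merged side weights;
    forward() routes between them by activation distance.
    """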
|
def __init__(self, config, layer, conv1D): |
|
super(WISEAdapter, self).__init__() |
|
|
|
self.layer = layer |
|
self.weight = self.layer.weight |
|
self.device = layer.weight.device |
|
self.config = config |
|
self.new_weight = copy.deepcopy(self.weight) |
|
self.original_layer = copy.deepcopy(self.layer) |
|
self.memory_weight = [] |
|
self.memory_mean_act = [] |
|
self.merge_cnt = 0 |
|
        assert not self.weight.requires_grad, 'The original layer must not be tunable.'
|
|
|
self.used_mask = None |
|
|
|
self.training = False |
|
self.editing = False |
|
self.conv1D = conv1D |
|
|
|
self.editing_mean_act = EditingMeanAct() |
|
self.editing_total_cnt = 0 |
|
|
|
def set_parameter_tunable(self): |
|
self.new_weight.requires_grad = True |
|
|
|
def save_weight(self): |
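        """Archive the current side weight (and, with retrieval, its activation statistics) and start a fresh copy."""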
|
self.memory_weight.append(copy.deepcopy(self.new_weight)) |
|
self.new_weight = copy.deepcopy(self.original_layer.weight) |
|
if self.config.retrieve: |
|
self.memory_mean_act.append(copy.deepcopy(self.editing_mean_act)) |
|
self.editing_mean_act = EditingMeanAct() |
|
|
|
def merge_weight(self): |
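        """Fold stored side weights together with the configured merge algorithm.

        Without retrieval, the merged result overwrites the live layer weight;
        with retrieval, the last merge_freq // save_freq archived weights are
        replaced by one merged entry that keeps the minimum of their activation
        statistics. When save_freq is None, the single side weight is merged
        directly into the live layer weight.
        """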
|
if self.config.save_freq is not None: |
|
if not self.config.retrieve: |
|
merge_alg = merge_dict[self.config.merge_alg] |
|
if self.original_layer.weight.equal(self.layer.weight): |
|
cur_new_weight = merge_alg.execute([self.config.weights / len(self.memory_weight) for _ in range(len(self.memory_weight))], self.original_layer.weight, self.memory_weight, densities=self.config.densities) |
|
else: |
|
cur_new_weight = merge_alg.execute([0.4 / len(self.memory_weight) for _ in range(len(self.memory_weight))] + [0.6], self.original_layer.weight, self.memory_weight + [self.layer.weight], densities=self.config.densities) |
|
self.layer.weight = torch.nn.Parameter(cur_new_weight.to(self.layer.weight.device), requires_grad=False) |
|
self.new_weight = copy.deepcopy(self.original_layer.weight) |
|
del self.memory_weight |
|
self.memory_weight = [] |
|
else: |
|
merge_alg = merge_dict[self.config.merge_alg] |
|
merge_num = self.config.merge_freq // self.config.save_freq |
|
assert len(self.memory_weight) >= merge_num |
|
new_merge_weight = merge_alg.execute([self.config.weights / merge_num for _ in range(merge_num)], self.original_layer.weight, self.memory_weight[-merge_num:], densities=self.config.densities) |
|
min_a = 1e9 |
|
for _ in range(merge_num): |
|
self.memory_weight.pop() |
|
edit_act = self.memory_mean_act.pop() |
|
min_a = min(min_a, edit_act.min_act()) |
|
self.new_weight = copy.deepcopy(self.original_layer.weight) |
|
self.memory_weight.append(new_merge_weight) |
|
self.memory_mean_act.append(EditingMeanAct(min_a=min_a)) |
|
                print(f'Side memories after merge: {len(self.memory_weight)}')
|
assert len(self.memory_mean_act) == len(self.memory_weight) |
|
self.merge_cnt += 1 |
|
else: |
|
merge_alg = merge_dict[self.config.merge_alg] |
|
cur_new_weight = merge_alg.execute(0.5, self.layer.weight, [self.new_weight], |
|
densities=self.config.densities) |
|
self.layer.weight = torch.nn.Parameter(cur_new_weight.to(self.layer.weight.device), requires_grad=False) |
|
self.new_weight = copy.deepcopy(self.original_layer.weight) |
|
|
|
def save_editing_activation(self): |
|
in_scope_dist = euc(self.original_layer_output[:-1, ...], self.new_weight_layer_output[:-1, ...], self.config) |
|
self.editing_mean_act.update(in_scope_dist.mean().item()) |
|
|
|
def generate_activation_mask(self, mask_ratio): |
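        """Sample a random 0/1 mask over the flattened weight; only entries set to 1 receive gradient updates."""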
|
p_grad = self.new_weight.reshape(-1) |
|
p_mask = np.random.choice([1, 0], size=p_grad.size()[0], p=[mask_ratio, 1 - mask_ratio]) |
|
p_mask = torch.from_numpy(p_mask).to(p_grad.device) |
|
self.weight_mask = p_mask |
|
|
|
def generate_non_overlapping_mask(self, mask_ratio): |
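        """Like generate_activation_mask, but draws only from indices not used by previous masks."""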
|
p_grad = self.new_weight.reshape(-1) |
|
mask_size = int(mask_ratio * p_grad.size()[0]) |
|
if self.used_mask is None: |
|
self.used_mask = np.zeros(p_grad.size()[0], dtype=bool) |
|
available_indices = np.where(~self.used_mask)[0] |
|
if len(available_indices) < mask_size: |
|
raise ValueError("Not enough unused elements to generate a new mask.") |
|
chosen_indices = np.random.choice(available_indices, size=mask_size, replace=False) |
|
mask_array = np.zeros(p_grad.size()[0], dtype=int) |
|
mask_array[chosen_indices] = 1 |
|
self.used_mask[chosen_indices] = True |
|
self.weight_mask = torch.from_numpy(mask_array).to(p_grad.device) |
|
|
|
def new_weight_forward(self, input: Tensor, weight) -> Tensor: |
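        """Apply `weight` to `input`: Conv1D semantics (addmm with the original bias) for GPT-2, a plain F.linear otherwise."""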
|
if self.conv1D: |
|
size_out = input.size()[:-1] + (weight.size(1),) |
|
input = torch.addmm(self.original_layer.bias, input.view(-1, input.size(-1)), weight) |
|
input = input.view(size_out) |
|
return input |
|
else: |
|
return F.linear(input, weight) |
|
|
|
def mask_new_weight_gradient(self): |
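        """Zero the gradient entries of new_weight that fall outside the active weight mask."""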
|
        assert self.new_weight.grad is not None, 'Gradient collection for new_weight failed: gradient not found.'
|
|
|
p_size = self.new_weight.grad.size() |
|
p_grad = self.new_weight.grad.reshape(-1) |
|
|
|
|
|
p_grad = p_grad * self.weight_mask |
|
self.new_weight.grad = p_grad.view(p_size).to(self.new_weight.grad.dtype) |
|
|
|
def forward(self, *args): |
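        """Route the forward pass.

        During editing, compute both the original and side-weight outputs so the
        losses can compare them. At inference, choose between the original layer,
        the live layer, the side weight and (with retrieval) the stored memories,
        using activation distances against the editing-time threshold.
        """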
|
if self.editing: |
|
layer_out = self.new_weight_forward(*args, self.new_weight) |
|
self.new_weight_layer_output = layer_out |
|
self.original_layer_output = self.original_layer(*args) |
|
else: |
|
if not self.config.retrieve: |
|
original_layer_output = self.original_layer(*args) |
|
layer_output = self.layer(*args) |
|
new_weight_layer_output = self.new_weight_forward(*args, self.new_weight) |
|
dist2 = euc(original_layer_output, new_weight_layer_output, self.config, infer=True) |
|
dist1 = euc(original_layer_output, layer_output, self.config, infer=True) |
|
threshold = self.editing_mean_act.min_act() * self.config.act_ratio |
|
|
|
if dist1.item() < threshold and dist2.item() < threshold: |
|
layer_out = original_layer_output |
|
elif dist1.item() > dist2.item(): |
|
layer_out = layer_output |
|
else: |
|
layer_out = new_weight_layer_output |
|
else: |
|
original_layer_output = self.original_layer(*args) |
|
new_weight_layer_output = self.new_weight_forward(*args, self.new_weight) |
|
dist1 = euc(original_layer_output, new_weight_layer_output, self.config, infer=True) |
|
threshold = self.editing_mean_act.min_act() * self.config.act_ratio |
|
min_dist = dist1 |
|
if min_dist.item() < threshold: |
|
layer_out = original_layer_output |
|
else: |
|
layer_out = new_weight_layer_output |
|
|
|
for i in range(len(self.memory_weight)): |
|
memory_retrieve_weight = self.memory_weight[i] |
|
memory_weight_layer_output = self.new_weight_forward(*args, memory_retrieve_weight) |
|
dist = euc(original_layer_output, memory_weight_layer_output, self.config, infer=True) |
|
if dist > min_dist and dist > self.memory_mean_act[i].min_act() * self.config.act_ratio: |
|
layer_out = memory_weight_layer_output |
|
min_dist = dist |
|
                        print(f'retrieve memory {i}: dist {dist}, threshold {self.memory_mean_act[i].min_act() * self.config.act_ratio}')
|
return layer_out |