from typing import Any, Dict
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .WISE import WISE
from .utils import tokenize, get_context_templates
from .wise_hparams import WISEHyperParams
import gradio as gr

def apply_wise_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: WISEHyperParams,
    num_steps: int,
    edit_lr: float,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> WISE:
    """Run a single WISE edit described by `request` on `model` and return the WISE editor (moved to CPU)."""
    if copy:
        model = deepcopy(model)
    weights_copy = {}  # kept for interface parity; original weights are not collected here

    # Override the loaded hyperparameters with the values chosen in the UI.
    hparams.n_iter = num_steps
    hparams.edit_lr = edit_lr

    context_templates = get_context_templates(model, tok, length_params=[[5, 5], [10, 5]], device=hparams.device)
    editor = WISE(model=model, config=hparams, device=hparams.device)

    print(
        f"Executing WISE algorithm for the update: "
        f"[{request['prompt']}] -> [{request['target_new']}]"
    )

    # Tokenize the edit request (with context templates) and build the
    # activation / deactivation masks used by WISE's side-memory routing.
    tokens, act_mask, deact_mask = tokenize(request, tokenizer=tok, device=hparams.device, context_templates=context_templates, hparams=hparams)
    editor.edit(config=hparams, tokens=tokens, act_mask=act_mask, deact_mask=deact_mask)

    editor.to('cpu')
    gr.Info("Completed editing via WISE!")
    return editor
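

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only): how apply_wise_to_model might be called
# outside the Gradio UI. The checkpoint name, config path, and the
# `WISEHyperParams.from_hparams` loader are assumptions, not part of this
# module; the real `tokenize` pipeline may also expect extra request fields
# (e.g. a locality prompt) beyond the two referenced above.
#
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical checkpoint
# tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# hparams = WISEHyperParams.from_hparams("hparams/WISE/llama-7b.yaml")      # assumed config loader
# request = {"prompt": "The president of the USA is", "target_new": "Joe Biden"}
# editor = apply_wise_to_model(model, tok, request, hparams, num_steps=30, edit_lr=1.0)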