ZJUPeng's picture
add continuous
d6682b6
raw
history blame
1.07 kB
from typing import Any, Dict, List, Tuple
import torch
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .GRACE import GRACE
from .grace_hparams import GraceHyperParams
from .utils import tokenize
from ...util import nethook
import gradio as gr
def apply_grace_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: GraceHyperParams,
    num_steps: int,
    edit_lr: float,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> GRACE:
    """Run a GRACE edit on ``model`` and return the editor that performed it.

    Args:
        model: Model to edit. It is edited in place unless ``copy`` is true.
        tok: Tokenizer handed to ``tokenize`` to encode ``requests``.
        requests: Edit request(s), passed unchanged to ``tokenize``.
        hparams: GRACE hyper-parameters. Mutated in place:
            ``hparams.edit_lr`` is overwritten with ``edit_lr``.
        num_steps: Accepted for interface compatibility but currently
            unused. NOTE(review): presumably meant to control the number of
            edit iterations via ``hparams``; confirm the field name in
            GraceHyperParams/GRACE before wiring it up.
        edit_lr: Learning rate stored on ``hparams`` before the editor is
            constructed.
        copy: If true, a deep copy of ``model`` is edited instead of the
            caller's object.
        return_orig_weights: Accepted for interface compatibility; unused.
        keep_original_weight: Accepted for interface compatibility; unused.
        **kwargs: Ignored.

    Returns:
        The ``GRACE`` editor after ``edit`` has run. Only the editor is
        returned, not a ``(model, weights)`` tuple.
    """
    if copy:
        model = deepcopy(model)

    # Editing is pinned to CPU; the editor and the tokenized inputs share
    # this device.
    device = torch.device('cpu')

    # GRACE is built from ``hparams``, so the learning rate must be set first.
    hparams.edit_lr = edit_lr
    editor = GRACE(model=model, config=hparams, device=device)

    tokens = tokenize(requests, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)

    # Surface a notification in the Gradio UI once the edit finishes.
    gr.Info("Completed editing via GRACE!")
    return editor