# Source: HuggingFace Space upload by ZekunXi — commit 8124a18 ("Add application file"), 1.08 kB
from typing import Any, Dict, List, Tuple
import torch
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .GRACE import GRACE
from .grace_hparams import GraceHyperParams
from .utils import tokenize
from ...util import nethook
def apply_grace_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: GraceHyperParams,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """Apply the GRACE editing method to ``model`` for the given edit requests.

    Moves the model to the GPU selected by ``hparams.device``, wraps it in a
    :class:`GRACE` editor, tokenizes ``requests`` and performs the edit.

    Args:
        model: The causal LM to edit.
        tok: Tokenizer matching ``model``.
        requests: Edit requests, passed straight to ``tokenize``.
        hparams: GRACE hyperparameters; ``hparams.device`` selects the CUDA id.
        copy: If True, edit a ``deepcopy`` of the model instead of the
            original. NOTE: the copy is made *after* the model has been moved
            to the GPU, so both the original and the copy reside on the GPU.
        return_orig_weights: Accepted for interface compatibility but ignored —
            GRACE stores edits in its adaptor, so no weights are saved.
        keep_original_weight: Accepted for interface compatibility; has no
            effect here (``weights_copy`` is always empty).
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        A tuple ``(editor, weights_copy)``. Despite the annotation, the first
        element is the GRACE editor wrapping the model (the caller-facing
        editable object in this framework — TODO confirm against callers),
        and ``weights_copy`` is always an empty dict.
    """
    # Build the target device once and use it everywhere below.
    device = torch.device(f'cuda:{hparams.device}')
    model.to(device)

    if copy:
        # Clone so the caller's original model is left unedited (both copies
        # are on the GPU at this point — see the docstring note on `copy`).
        model = deepcopy(model)

    # GRACE keeps edits in its adaptor rather than overwriting base weights,
    # so there are never original weights to hand back.
    weights_copy: Dict[str, Any] = {}

    editor = GRACE(model=model, config=hparams, device=device)
    tokens = tokenize(requests, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)
    editor.to(device)
    return editor, weights_copy