|
|
|
|
|
from typing import Optional |
|
|
|
import fire |
|
import torch |
|
import tqdm |
|
import transformers |
|
|
|
|
|
@torch.inference_mode()
def make_diff(
    path_raw: str, path_tuned: str, path_diff: str, device="cpu",
):
    """Make the weight diff.

    This function is given to present full transparency of how the weight diff was created.

    Run:
        python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>

    Args:
        path_raw: Location of the raw (base) model checkpoint.
        path_tuned: Location of the fine-tuned model checkpoint and its tokenizer.
        path_diff: Output directory that receives the diff weights and the tokenizer.
        device: Device onto which both models are loaded.

    Raises:
        KeyError: If the tuned model has state-dict entries that the raw model lacks.
    """
    model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_tuned,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    # Only the tuned tokenizer is saved with the diff, so the raw tokenizer is
    # not loaded at all.
    tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_tuned
    )
    # Pin the padding configuration that is saved alongside the diff.
    tokenizer_tuned.pad_token_id = 0
    tokenizer_tuned.padding_side = "left"

    state_dict_tuned = model_tuned.state_dict()
    state_dict_raw = model_raw.state_dict()

    # Fail before touching any weight if the two checkpoints do not line up.
    missing_keys = [key for key in state_dict_tuned if key not in state_dict_raw]
    if missing_keys:
        raise KeyError(
            f"{len(missing_keys)} state-dict entries of the tuned model are missing "
            f"from the raw model, e.g. {missing_keys[0]!r}"
        )

    # In-place subtraction turns `model_tuned` itself into the diff (tuned - raw)
    # without allocating a negated copy of every raw tensor.
    for key in tqdm.tqdm(state_dict_tuned):
        state_dict_tuned[key].sub_(state_dict_raw[key])

    model_tuned.save_pretrained(path_diff)
    tokenizer_tuned.save_pretrained(path_diff)
|
|
|
|
|
@torch.inference_mode()
def recover(
    path_raw,
    path_diff,
    path_tuned: Optional[str] = None,
    device="cpu",
    test_inference=True
):
    """Recover the original weights from the released weight diff.

    This function is given for you to run.

    Things to do before running this:
        1. Convert Meta's released weights into huggingface format. Follow this guide:
            https://huggingface.co/docs/transformers/main/model_doc/llama
            You may refer to https://huggingface.co/huggyllama/llama-7b if you get some trouble in the conversion. (You should only use this repository if you have been granted access to the llama model.)
        2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
            https://huggingface.co/Dynosaur/dynosaur-llama-7b-superni
        3. Run this function with the correct paths. E.g.,
            python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir>

    Additional notes:
        - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
        - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
            Next time you can load the recovered weights directly from `<your_path_tuned>`.

    Args:
        path_raw: Location of the raw (base) model checkpoint.
        path_diff: Location of the released weight diff and its tokenizer.
        path_tuned: Optional output directory; when given, the recovered model
            and tokenizer are saved there.
        device: Device onto which both models are loaded.
        test_inference: When true, generate a completion for a fixed prompt and
            print it.

    Returns:
        The tuple `(model_recovered, tokenizer_recovered)`.

    Raises:
        KeyError: If the diff has state-dict entries that the raw model lacks.
    """
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_diff,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    # Only the tokenizer shipped with the diff is used and returned, so the raw
    # tokenizer is not loaded at all.
    tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_diff
    )

    state_dict_recovered = model_recovered.state_dict()
    state_dict_raw = model_raw.state_dict()

    # Fail before touching any weight if the two checkpoints do not line up.
    missing_keys = [key for key in state_dict_recovered if key not in state_dict_raw]
    if missing_keys:
        raise KeyError(
            f"{len(missing_keys)} state-dict entries of the weight diff are missing "
            f"from the raw model, e.g. {missing_keys[0]!r}"
        )

    # In-place addition turns the loaded diff back into the tuned weights
    # (diff + raw) inside `model_recovered` itself.
    for key in tqdm.tqdm(state_dict_recovered):
        state_dict_recovered[key].add_(state_dict_raw[key])

    if path_tuned is not None:
        model_recovered.save_pretrained(path_tuned)
        tokenizer_recovered.save_pretrained(path_tuned)

    if test_inference:
        input_text = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\r\n\r\n"
            "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:"
        )
        inputs = tokenizer_recovered(input_text, return_tensors="pt")
        # The tokenizer produces CPU tensors; move the prompt to wherever the
        # model lives so that `--device cuda` does not hit a device mismatch.
        input_ids = inputs.input_ids.to(model_recovered.device)
        out = model_recovered.generate(inputs=input_ids, max_new_tokens=100)
        output_text = tokenizer_recovered.batch_decode(out, skip_special_tokens=True)[0]
        # Drop the echoed prompt so that only the completion is printed.
        output_text = output_text[len(input_text) :]
        print(f"Input: {input_text}\nCompletion: {output_text}")

    return model_recovered, tokenizer_recovered
|
|
|
|
|
def main(task, **kwargs):
    """Dispatch a command-line task to the matching function in this module.

    Args:
        task: Name of the function to run; either "make_diff" or "recover".
        **kwargs: Keyword arguments forwarded unchanged to the selected function.

    Raises:
        ValueError: If `task` is not one of the supported task names.
    """
    supported_tasks = ("make_diff", "recover")
    # Restrict dispatch to the real tasks: a bare globals() lookup would also
    # resolve imported modules or `main` itself, and a typo would surface as an
    # opaque KeyError.
    if task not in supported_tasks:
        raise ValueError(
            f"Unknown task {task!r}; expected one of: {', '.join(supported_tasks)}"
        )
    # The task's return value is intentionally discarded, so this returns None.
    globals()[task](**kwargs)
|
|
|
|
|
# Script entry point: expose `main` on the command line through Fire, so the
# first argument selects the task and the remaining flags become its kwargs.
if __name__ == "__main__":
    fire.Fire(component=main)
|
|