Upload weight_diff.py

weight_diff.py (ADDED, +131 lines)
# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/main/weight_diff.py

from typing import Optional

import fire
import torch
import tqdm
import transformers


@torch.inference_mode()
def make_diff(
    path_raw: str, path_tuned: str, path_diff: str, device="cpu",  # "cuda" or "cpu"
):
    """Make the weight diff.

    This function is provided for full transparency about how the weight diff was created.

    Run:
        python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>
    """
    model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_tuned,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
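
    # Both checkpoints are loaded in float32 so that the diff computed below is
    # taken at full precision rather than in a half-precision dtype.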

    tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_tuned
    )
    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw
    )

    tokenizer_tuned.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer_tuned.padding_side = "left"  # Allow batched inference

    state_dict_tuned = model_tuned.state_dict()
    state_dict_raw = model_raw.state_dict()
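    # Compute the diff in place: after this loop each parameter of model_tuned
    # holds (tuned - raw), so saving model_tuned below saves only the diff.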
    for key in tqdm.tqdm(state_dict_tuned):
        state_dict_tuned[key].add_(-state_dict_raw[key])

    model_tuned.save_pretrained(path_diff)
    tokenizer_tuned.save_pretrained(path_diff)


@torch.inference_mode()
def recover(
    path_raw,
    path_diff,
    path_tuned: Optional[str] = None,
    device="cpu",
    test_inference=True,
):
    """Recover the original weights from the released weight diff.

    This function is provided for you to run.

    Things to do before running this:
    1. Convert Meta's released weights into huggingface format. Follow this guide:
        https://huggingface.co/docs/transformers/main/model_doc/llama
        You may refer to https://huggingface.co/huggyllama/llama-7b if you run into trouble with the conversion. (You should only use this repository if you have been granted access to the llama model.)
    2. Make sure you have cloned the released weight diff onto your local machine. The weight diff is located at:
        https://huggingface.co/Dynosaur/dynosaur-llama-7b-superni
    3. Run this function with the correct paths. E.g.,
        python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir>

    Additional notes:
    - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
    - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
        Next time you can load the recovered weights directly from `<your_path_tuned>`.
    """
    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_raw,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
        path_diff,
        device_map={"": torch.device(device)},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw
    )
    tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_diff
    )

    state_dict_recovered = model_recovered.state_dict()
    state_dict_raw = model_raw.state_dict()
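    # Invert the diff in place: adding the raw weights back turns the released
    # diff into the fine-tuned weights.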
    for key in tqdm.tqdm(state_dict_recovered):
        state_dict_recovered[key].add_(state_dict_raw[key])

    if path_tuned is not None:
        model_recovered.save_pretrained(path_tuned)
        tokenizer_recovered.save_pretrained(path_tuned)

    if test_inference:
        input_text = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\r\n\r\n"
            "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:"
        )
        inputs = tokenizer_recovered(input_text, return_tensors="pt")
        out = model_recovered.generate(inputs=inputs.input_ids, max_new_tokens=100)
        output_text = tokenizer_recovered.batch_decode(out, skip_special_tokens=True)[0]
        output_text = output_text[len(input_text) :]
        print(f"Input: {input_text}\nCompletion: {output_text}")

    return model_recovered, tokenizer_recovered


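# Dispatch helper: `python weight_diff.py <task> --<flag> <value> ...` calls the
# module-level function named <task> (make_diff or recover) with the CLI kwargs.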
def main(task, **kwargs):
    globals()[task](**kwargs)


if __name__ == "__main__":
    fire.Fire(main)
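
For reference, a minimal sketch of reusing a recovered checkpoint once it has been saved via `--path_tuned`. This snippet is not part of the uploaded file, and the path is a placeholder for whatever directory you chose.

import transformers

# Hypothetical placeholder: the directory passed as --path_tuned to recover.
path_tuned = "<your_path_tuned>"

model = transformers.AutoModelForCausalLM.from_pretrained(path_tuned)
tokenizer = transformers.AutoTokenizer.from_pretrained(path_tuned)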