Dynosaur committed
Commit ad1341d
1 Parent(s): cf8c3b6

Upload weight_diff.py

Files changed (1):
  1. weight_diff.py +131 -0
weight_diff.py ADDED
@@ -0,0 +1,131 @@
+ # Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/main/weight_diff.py
+
+ from typing import Optional
+
+ import fire
+ import torch
+ import tqdm
+ import transformers
+
+
+ @torch.inference_mode()
+ def make_diff(
+     path_raw: str, path_tuned: str, path_diff: str, device="cpu",  # "cuda" or "cpu"
+ ):
+     """Make the weight diff.
+
+     This function is given to present full transparency of how the weight diff was created.
+
+     Run:
+         python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>
+     """
+     model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+         path_tuned,
+         device_map={"": torch.device(device)},
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+     )
+     model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+         path_raw,
+         device_map={"": torch.device(device)},
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+     )
+
+     tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+         path_tuned
+     )
+     tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+         path_raw
+     )
+
+     tokenizer_tuned.pad_token_id = (
+         0  # unk. we want this to be different from the eos token
+     )
+     tokenizer_tuned.padding_side = "left"  # Allow batched inference
+
+     state_dict_tuned = model_tuned.state_dict()
+     state_dict_raw = model_raw.state_dict()
+     for key in tqdm.tqdm(state_dict_tuned):
+         state_dict_tuned[key].add_(-state_dict_raw[key])
+
+     model_tuned.save_pretrained(path_diff)
+     tokenizer_tuned.save_pretrained(path_diff)
+
+
+ @torch.inference_mode()
+ def recover(
+     path_raw,
+     path_diff,
+     path_tuned: Optional[str] = None,
+     device="cpu",
+     test_inference=True,
+ ):
+     """Recover the original weights from the released weight diff.
+
+     This function is given for you to run.
+
+     Things to do before running this:
+     1. Convert Meta's released weights into huggingface format. Follow this guide:
+         https://huggingface.co/docs/transformers/main/model_doc/llama
+         You may refer to https://huggingface.co/huggyllama/llama-7b if you run into trouble during the conversion. (You should only use that repository if you have been granted access to the llama model.)
+     2. Make sure you have cloned the released weight diff onto your local machine. The weight diff is located at:
+         https://huggingface.co/Dynosaur/dynosaur-llama-7b-superni
+     3. Run this function with the correct paths. E.g.,
+         python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir>
+
+     Additional notes:
+     - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
+     - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
+       Next time you can load the recovered weights directly from `<your_path_tuned>`.
+     """
+     model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+         path_raw,
+         device_map={"": torch.device(device)},
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+     )
+     model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+         path_diff,
+         device_map={"": torch.device(device)},
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+     )
+
+     tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+         path_raw
+     )
+     tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+         path_diff
+     )
+
+     state_dict_recovered = model_recovered.state_dict()
+     state_dict_raw = model_raw.state_dict()
+     for key in tqdm.tqdm(state_dict_recovered):
+         state_dict_recovered[key].add_(state_dict_raw[key])
+
+     if path_tuned is not None:
+         model_recovered.save_pretrained(path_tuned)
+         tokenizer_recovered.save_pretrained(path_tuned)
+
+     if test_inference:
+         input_text = (
+             "Below is an instruction that describes a task. "
+             "Write a response that appropriately completes the request.\r\n\r\n"
+             "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:"
+         )
+         inputs = tokenizer_recovered(input_text, return_tensors="pt")
+         out = model_recovered.generate(inputs=inputs.input_ids, max_new_tokens=100)
+         output_text = tokenizer_recovered.batch_decode(out, skip_special_tokens=True)[0]
+         output_text = output_text[len(input_text) :]
+         print(f"Input: {input_text}\nCompletion: {output_text}")
+
+     return model_recovered, tokenizer_recovered
+
+
+ def main(task, **kwargs):
+     globals()[task](**kwargs)
+
+
+ if __name__ == "__main__":
+     fire.Fire(main)
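
The core of both functions is elementwise tensor arithmetic over matching state-dict keys: make_diff stores tuned minus raw, and recover adds raw back. A minimal self-contained sketch of that round trip, using one small random tensor in place of an actual checkpoint (all values here are illustrative, not from the repository):

    import torch

    raw = torch.randn(4, 4)                 # stand-in for one raw (base) weight tensor
    tuned = raw + 0.01 * torch.randn(4, 4)  # stand-in for its fine-tuned version

    diff = tuned - raw       # what make_diff saves to path_diff
    recovered = diff + raw   # what recover reconstructs in place

    assert torch.allclose(recovered, tuned)

Loading every checkpoint as torch.float32 keeps the rounding error of this subtract-then-add round trip negligible; doing the same in half precision could introduce visible drift in the recovered weights.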
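Besides the fire CLI shown in the docstrings, recover can be called directly from Python. A hedged sketch, assuming the script is importable and the directory paths (placeholders below) point at the step-1 and step-2 checkpoints from the docstring:

    from weight_diff import recover

    model, tokenizer = recover(
        path_raw="path/to/llama-7b-hf",                 # step 1: LLaMA converted to huggingface format
        path_diff="path/to/dynosaur-llama-7b-superni",  # step 2: local clone of the weight diff
        path_tuned="path/to/recovered",                 # optional: persist the recovered weights
        test_inference=False,                           # skip the sample generation
    )

Because recover returns the patched model and tokenizer, setting path_tuned is only needed if you want to reload the result later with from_pretrained instead of re-running the recovery.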