dmayhem93 commited on
Commit
3f3a6cb
1 Parent(s): 649b644

Upload apply_delta.py

Browse files
Files changed (1) hide show
  1. apply_delta.py +44 -0
apply_delta.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 apply_delta.py --base /path/to/model_weights/llama-65b --target-model-path stabilityai/FreeWilly1-Delta-SafeTensor --delta models/FreeWilly1-Delta-SafeTensor
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+
11
+
12
+ def apply_delta(base_model_path, target_model_path, delta_path):
13
+ print("Loading base model")
14
+ base = AutoModelForCausalLM.from_pretrained(
15
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
16
+
17
+ print("Loading delta")
18
+ delta = AutoModelForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
19
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
20
+
21
+ base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
22
+
23
+ input_embeddings = base.get_input_embeddings().weight.data
24
+ output_embeddings = base.get_output_embeddings().weight.data
25
+
26
+ print("Applying delta")
27
+ for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
28
+ assert name in delta.state_dict()
29
+ param.data += delta.state_dict()[name]
30
+
31
+ print("Saving target model")
32
+ base.save_pretrained(target_model_path)
33
+ delta_tokenizer.save_pretrained(target_model_path)
34
+
35
+
36
+ if __name__ == "__main__":
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--base-model-path", type=str, required=True)
39
+ parser.add_argument("--target-model-path", type=str, required=True)
40
+ parser.add_argument("--delta-path", type=str, required=True)
41
+
42
+ args = parser.parse_args()
43
+
44
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)