Text Generation
Transformers
PyTorch
English
llama
causal-lm
Inference Endpoints
text-generation-inference
pvduy commited on
Commit
90c0bd3
1 Parent(s): 20da0a6

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +54 -0
README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apply Delta weights
2
+
3
+ ```python
4
+
5
+ """
6
+ Usage:
7
+ python3 apply_delta.py --base /path/to/model_weights/llama-13b --target stable-vicuna-13b --delta pvduy/stable-vicuna-13b-delta
8
+ """
9
+ import argparse
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+
15
+
16
+ def apply_delta(base_model_path, target_model_path, delta_path):
17
+ print("Loading base model")
18
+ base = AutoModelForCausalLM.from_pretrained(
19
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+
21
+ print("Loading delta")
22
+ delta = AutoModelForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
23
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
24
+
25
+ DEFAULT_PAD_TOKEN = "[PAD]"
26
+ base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
27
+ num_new_tokens = base_tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
28
+
29
+ base.resize_token_embeddings(len(base_tokenizer))
30
+ input_embeddings = base.get_input_embeddings().weight.data
31
+ output_embeddings = base.get_output_embeddings().weight.data
32
+ input_embeddings[-num_new_tokens:] = 0
33
+ output_embeddings[-num_new_tokens:] = 0
34
+
35
+ print("Applying delta")
36
+ for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
37
+ assert name in delta.state_dict()
38
+ param.data += delta.state_dict()[name]
39
+
40
+ print("Saving target model")
41
+ base.save_pretrained(target_model_path)
42
+ delta_tokenizer.save_pretrained(target_model_path)
43
+
44
+
45
+ if __name__ == "__main__":
46
+ parser = argparse.ArgumentParser()
47
+ parser.add_argument("--base-model-path", type=str, required=True)
48
+ parser.add_argument("--target-model-path", type=str, required=True)
49
+ parser.add_argument("--delta-path", type=str, required=True)
50
+
51
+ args = parser.parse_args()
52
+
53
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
54
+ ```