imone commited on
Commit
e23cecb
2 Parent(s): 91c000f 0a61848

Merge branch 'main' of hf.co:imone/Llama-3-8B-fixed-special-embedding

Browse files
Files changed (1) hide show
  1. README.md +59 -0
README.md CHANGED
@@ -3,3 +3,62 @@ license: other
3
  license_name: llama3
4
  license_link: LICENSE
5
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  license_name: llama3
4
  license_link: LICENSE
5
  ---
6
+
7
+ The original Llama 3 8b (base) special token weights are zero, which might cause NaN gradients. This version re-initialized the weights of all the following special tokens to alleviate the problem.
8
+
9
+ ```
10
+ <|eot_id|>
11
+ <|start_header_id|>
12
+ <|end_header_id|>
13
+ ```
14
+
15
+ We set the weights of these tokens in `embed` and `lm_head` to be the mean of all other tokens.
16
+
17
+ Code for making this model:
18
+
19
+ ```python
20
+ import argparse
21
+
22
+ import transformers
23
+ import torch
24
+
25
+
26
+ def init_eot_embedding_llama3(model_path, output_dir, special_tokens=["<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"], mean_cutoff=128000, dtype=torch.bfloat16):
27
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
28
+ model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=dtype)
29
+
30
+ assert model.model.embed_tokens.weight.shape[0] >= mean_cutoff
31
+ assert model.lm_head.weight.shape[0] >= mean_cutoff
32
+
33
+ with torch.no_grad():
34
+ for token in special_tokens:
35
+ token_id = tokenizer.convert_tokens_to_ids(token)
36
+
37
+ print (f"Token {token} ID {token_id}")
38
+
39
+ model.model.embed_tokens.weight[token_id] = torch.mean(model.model.embed_tokens.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)
40
+ model.lm_head.weight[token_id] = torch.mean(model.lm_head.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)
41
+
42
+ # Save
43
+ tokenizer.save_pretrained(output_dir)
44
+ model.save_pretrained(output_dir)
45
+
46
+
47
+ def main():
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument(
50
+ "--model-path",
51
+ help="Location of model, or HuggingFace repo ID",
52
+ )
53
+ parser.add_argument(
54
+ "--output-dir",
55
+ help="Location to write resulting model and tokenizer",
56
+ )
57
+
58
+ init_eot_embedding_llama3(**vars(parser.parse_args()))
59
+
60
+
61
+ if __name__ == "__main__":
62
+ main()
63
+
64
+ ```