Ethan Shen committed
Commit dda1539
1 Parent(s): bc4858e

Initial commit

Files changed (41)
  1. .gitignore +3 -0
  2. LICENSE +126 -0
  3. app.py +97 -0
  4. params/g15_d3_mixed.json +27 -0
  5. params/g20_d3_mixed.json +27 -0
  6. params/g5_d3_mixed.json +27 -0
  7. params/p15_d10_mixed.json +26 -0
  8. params/p15_d2_mixed.json +26 -0
  9. params/p15_d3_mixed.json +26 -0
  10. params/p15_d3_ngram4_mixed.json +22 -0
  11. params/p15_d4_mixed.json +26 -0
  12. params/p15_d5_mixed.json +26 -0
  13. params/p15_d6_mixed.json +26 -0
  14. params/p25_d3_mixed.json +26 -0
  15. params/p40_d3_mixed.json +12 -0
  16. params/p5_d3_mixed.json +26 -0
  17. requirements.txt +11 -0
  18. superposed/llama/__init__.py +6 -0
  19. superposed/llama/__pycache__/__init__.cpython-312.pyc +0 -0
  20. superposed/llama/__pycache__/generation.cpython-312.pyc +0 -0
  21. superposed/llama/__pycache__/model.cpython-312.pyc +0 -0
  22. superposed/llama/__pycache__/superpose.cpython-312.pyc +0 -0
  23. superposed/llama/__pycache__/superposed_generation.cpython-312.pyc +0 -0
  24. superposed/llama/__pycache__/superposed_model.cpython-312.pyc +0 -0
  25. superposed/llama/__pycache__/tokenizer.cpython-312.pyc +0 -0
  26. superposed/llama/__pycache__/utils.cpython-312.pyc +0 -0
  27. superposed/llama/generation.py +268 -0
  28. superposed/llama/metrics.py +109 -0
  29. superposed/llama/model.py +548 -0
  30. superposed/llama/superpose.py +328 -0
  31. superposed/llama/superposed_generation.py +198 -0
  32. superposed/llama/superposed_model.py +515 -0
  33. superposed/llama/tokenizer.py +68 -0
  34. superposed/llama/utils.py +70 -0
  35. superposed/ngrams/__pycache__/ngram_models.cpython-312.pyc +0 -0
  36. superposed/ngrams/make_corpus.py +268 -0
  37. superposed/ngrams/ngram_models.py +115 -0
  38. superposed/ngrams/test.json +8 -0
  39. superposed/notebooks/custom.ipynb +289 -0
  40. superposed/notebooks/nq.ipynb +417 -0
  41. superposed/notebooks/triviaqa.ipynb +404 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .env
+ weights
+ ckpts-200k
LICENSE ADDED
@@ -0,0 +1,126 @@
+ LLAMA 2 COMMUNITY LICENSE AGREEMENT
+ Llama 2 Version Release Date: July 18, 2023
+
+ "Agreement" means the terms and conditions for use, reproduction, distribution and
+ modification of the Llama Materials set forth herein.
+
+ "Documentation" means the specifications, manuals and documentation
+ accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
+ libraries/llama-downloads/.
+
+ "Licensee" or "you" means you, or your employer or any other person or entity (if
+ you are entering into this Agreement on such person or entity's behalf), of the age
+ required under applicable laws, rules or regulations to provide legal consent and that
+ has legal authority to bind your employer or such other person or entity if you are
+ entering in this Agreement on their behalf.
+
+ "Llama 2" means the foundational large language models and software and
+ algorithms, including machine-learning model code, trained model weights,
+ inference-enabling code, training-enabling code, fine-tuning enabling code and other
+ elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
+ libraries/llama-downloads/.
+
+ "Llama Materials" means, collectively, Meta's proprietary Llama 2 and
+ Documentation (and any portion thereof) made available under this Agreement.
+
+ "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
+ are an entity, your principal place of business is in the EEA or Switzerland) and Meta
+ Platforms, Inc. (if you are located outside of the EEA or Switzerland).
+
+ By clicking "I Accept" below or by using or distributing any portion or element of the
+ Llama Materials, you agree to be bound by this Agreement.
+
+ 1. License Rights and Redistribution.
+
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
+ transferable and royalty-free limited license under Meta's intellectual property or
+ other rights owned by Meta embodied in the Llama Materials to use, reproduce,
+ distribute, copy, create derivative works of, and make modifications to the Llama
+ Materials.
+
+ b. Redistribution and Use.
+
+ i. If you distribute or make the Llama Materials, or any derivative works
+ thereof, available to a third party, you shall provide a copy of this Agreement to such
+ third party.
+ ii. If you receive Llama Materials, or any derivative works thereof, from
+ a Licensee as part of an integrated end user product, then Section 2 of this
+ Agreement will not apply to you.
+
+ iii. You must retain in all copies of the Llama Materials that you
+ distribute the following attribution notice within a "Notice" text file distributed as a
+ part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
+ Copyright (c) Meta Platforms, Inc. All Rights Reserved."
+
+ iv. Your use of the Llama Materials must comply with applicable laws
+ and regulations (including trade compliance laws and regulations) and adhere to the
+ Acceptable Use Policy for the Llama Materials (available at
+ https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
+ this Agreement.
+
+ v. You will not use the Llama Materials or any output or results of the
+ Llama Materials to improve any other large language model (excluding Llama 2 or
+ derivative works thereof).
+
+ 2. Additional Commercial Terms. If, on the Llama 2 version release date, the
+ monthly active users of the products or services made available by or for Licensee,
+ or Licensee's affiliates, is greater than 700 million monthly active users in the
+ preceding calendar month, you must request a license from Meta, which Meta may
+ grant to you in its sole discretion, and you are not authorized to exercise any of the
+ rights under this Agreement unless or until Meta otherwise expressly grants you
+ such rights.
+
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
+ LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
+ PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+ EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
+ WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
+ FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
+ FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
+ THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
+ USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
+
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
+ LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
+ NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
+ AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
+ CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
+ IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
+ ANY OF THE FOREGOING.
+
+ 5. Intellectual Property.
+
+ a. No trademark licenses are granted under this Agreement, and in
+ connection with the Llama Materials, neither Meta nor Licensee may use any name
+ or mark owned by or associated with the other or any of its affiliates, except as
+ required for reasonable and customary use in describing and redistributing the
+ Llama Materials.
+
+ b. Subject to Meta's ownership of Llama Materials and derivatives made by or
+ for Meta, with respect to any derivative works and modifications of the Llama
+ Materials that are made by you, as between you and Meta, you are and will be the
+ owner of such derivative works and modifications.
+
+ c. If you institute litigation or other proceedings against Meta or any entity
+ (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
+ Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
+ constitutes an infringement of intellectual property or other rights owned or licensable
+ by you, then any licenses granted to you under this Agreement shall terminate as of
+ the date such litigation or claim is filed or instituted. You will indemnify and hold
+ harmless Meta from and against any claim by any third party arising out of or related
+ to your use or distribution of the Llama Materials.
+
+ 6. Term and Termination. The term of this Agreement will commence upon your
+ acceptance of this Agreement or access to the Llama Materials and will continue in
+ full force and effect until terminated in accordance with the terms and conditions
+ herein. Meta may terminate this Agreement if you are in breach of any term or
+ condition of this Agreement. Upon termination of this Agreement, you shall delete
+ and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
+ termination of this Agreement.
+
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and
+ construed under the laws of the State of California without regard to choice of law
+ principles, and the UN Convention on Contracts for the International Sale of Goods
+ does not apply to this Agreement. The courts of California shall have exclusive
+ jurisdiction of any dispute arising out of this Agreement.
+
app.py ADDED
@@ -0,0 +1,97 @@
+ import gradio as gr
+ import json
+ import os
+ import spaces
+ import torch
+
+ from dotenv import load_dotenv
+ from huggingface_hub import login, snapshot_download
+
+ from superposed.llama.superposed_generation import SuperposedLlama
+ from superposed.llama.tokenizer import Tokenizer
+ from superposed.ngrams.ngram_models import make_models
+
+ # load_dotenv()
+ # print(os.getenv("HF_ACCESS_TOKEN"))
+ login(os.getenv("HF_ACCESS_TOKEN"))
+ if not os.path.exists("./weights/"):
+     os.mkdir("./weights/")
+     snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir="./weights/")
+ weight_path = "./weights/"
+ # Load params
+ param_file = "params/p15_d3_mixed.json"
+ with open(param_file, "r") as f:
+     params = json.load(f)
+ alpha = params["alpha"]
+ temp = params["temp"]
+ n_drafts = params["n_drafts"]
+ prompt_len = params["prompt_len"]
+ n_token_sample = params["n_token_sample"]
+ i_weights = params["i_weights"]
+ i_length = params["i_length"]
+ # Load main model
+ model = SuperposedLlama.build(ckpt_dir=weight_path,
+                               tokenizer_path=f'{weight_path}/tokenizer.model',
+                               max_seq_len=100,
+                               max_batch_size=32,
+                               model_parallel_size=1)
+ tokenizer = Tokenizer(f'{weight_path}/tokenizer.model')
+ # Create ngram models
+ ngrams = make_models("ckpts-200k", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)
+
+ def decode(tokenizer, encoding):
+     """
+     Args:
+         tokenizer (Any): Tokenizer
+         encoding (torch.Tensor): Encoding
+     Returns:
+         decoding (str)
+     """
+     eos_locs = (encoding == tokenizer.eos_id).nonzero()
+     if len(eos_locs > 0):
+         encoding = encoding[:eos_locs[0]]
+     return tokenizer.decode(encoding.to(torch.int32).tolist())
+
+ @spaces.GPU
+ def update_options(input, num_tokens):
+     tokenized_prompts = tokenizer.encode([input], True, False)
+     alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts,
+                                        smoothing="geom",
+                                        max_gen_len=num_tokens,
+                                        n_token_sample=n_token_sample,
+                                        alpha=alpha,
+                                        temp=temp,
+                                        n_drafts=n_drafts,
+                                        i_weights=i_weights,
+                                        i_length=i_length,
+                                        ngrams=ngrams,
+                                        get_time=False,
+                                        penalty=200)
+     gens = alive_gens[0].reshape(n_drafts, -1)
+     return decode(tokenizer, gens[0]), decode(tokenizer, gens[1]), decode(tokenizer, gens[2])
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # Superposed Decoding
+         Start typing below to see suggestions.
+         """)
+     slider = gr.Slider(minimum=1, maximum=10, step=1, label="Generation length", value=10)
+     inp = gr.Textbox(placeholder="Type anything!", lines=3)
+     option1 = gr.Button(value="Option 1")
+     option2 = gr.Button(value="Option 2")
+     option3 = gr.Button(value="Option 3")
+     inp.change(update_options, inputs=[inp, slider], outputs=[option1, option2, option3])
+     # Button updates
+     @option1.click(inputs=[inp, option1], outputs=inp)
+     def option1_click(curr, txt):
+         return curr + txt
+     @option2.click(inputs=[inp, option2], outputs=inp)
+     def option2_click(curr, txt):
+         return curr + txt
+     @option3.click(inputs=[inp, option3], outputs=inp)
+     def option3_click(curr, txt):
+         return curr + txt
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)
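
For readers skimming the diff, here is a minimal self-contained sketch of the EOS-truncation step that decode() above performs; the eos_id value and token ids are made up for illustration only.

    import torch

    eos_id = 2  # hypothetical EOS token id, for illustration
    encoding = torch.tensor([5, 9, 11, 2, 0, 0])

    # Find the first EOS position and cut the draft there, as decode() does
    eos_locs = (encoding == eos_id).nonzero()
    if len(eos_locs) > 0:
        encoding = encoding[: int(eos_locs[0])]
    print(encoding.tolist())  # [5, 9, 11]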
params/g15_d3_mixed.json ADDED
@@ -0,0 +1,27 @@
+ {
+     "alpha": 0.48,
+     "temp": 0.06,
+     "n_drafts": 3,
+     "prompt_len": 15,
+     "n_token_sample": 15,
+     "max_gen_len": 15,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
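
A rough sketch of how a parameter file like this is consumed, mirroring the loading code in app.py above (the path and variable names follow that script; the inline comments only restate values from this file):

    import json

    with open("params/g15_d3_mixed.json") as f:
        params = json.load(f)

    # Decoding hyperparameters handed to sup_generate in app.py
    alpha = params["alpha"]                    # 0.48 in this file
    temp = params["temp"]                      # 0.06 in this file
    n_drafts = params["n_drafts"]              # number of superposed drafts (3)
    n_token_sample = params["n_token_sample"]
    i_weights = params["i_weights"]            # interpolation weights, paired with i_length
    i_length = params["i_length"]              # n-gram orders 1 through 5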
params/g20_d3_mixed.json ADDED
@@ -0,0 +1,27 @@
+ {
+     "alpha": 0.5,
+     "temp": 0.04,
+     "n_drafts": 3,
+     "prompt_len": 15,
+     "n_token_sample": 15,
+     "max_gen_len": 20,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/g5_d3_mixed.json ADDED
@@ -0,0 +1,27 @@
+ {
+     "alpha": 0.52,
+     "temp": 0.06,
+     "n_drafts": 3,
+     "prompt_len": 15,
+     "n_token_sample": 15,
+     "max_gen_len": 5,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p15_d10_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.54,
+     "temp": 0.12,
+     "n_drafts": 10,
+     "prompt_len": 15,
+     "n_token_sample": 30,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p15_d2_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.62,
+     "temp": 0.06,
+     "n_drafts": 2,
+     "prompt_len": 15,
+     "n_token_sample": 6,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p15_d3_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.54,
+     "temp": 0.06,
+     "n_drafts": 3,
+     "prompt_len": 15,
+     "n_token_sample": 9,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p15_d3_ngram4_mixed.json ADDED
@@ -0,0 +1,22 @@
+ {
+     "alpha": 0.55,
+     "temp": 0.1,
+     "n_drafts": 3,
+     "prompt_len": 15,
+     "n_token_sample": 9,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15
+     ],
+     "i_length": [
+         1,
+         2,
+         3
+     ]
+ }
params/p15_d4_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.52,
+     "temp": 0.06,
+     "n_drafts": 4,
+     "prompt_len": 15,
+     "n_token_sample": 12,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p15_d5_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.6,
+     "temp": 0.06,
+     "n_drafts": 5,
+     "prompt_len": 15,
+     "n_token_sample": 15,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p15_d6_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.52,
+     "temp": 0.06,
+     "n_drafts": 6,
+     "prompt_len": 15,
+     "n_token_sample": 18,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p25_d3_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.5,
+     "temp": 0.12,
+     "n_drafts": 3,
+     "prompt_len": 25,
+     "n_token_sample": 15,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
params/p40_d3_mixed.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "alpha": 0.55,
+     "temp": 0.1,
+     "prompt_len": 40,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [0.01, 0.04, 0.15, 0.18, 0.12],
+     "i_length": [1, 2, 3, 4, 5],
+     "ckpt_path": "../ckpts-200k"
+ }
params/p5_d3_mixed.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "alpha": 0.34,
+     "temp": 0.12,
+     "n_drafts": 3,
+     "prompt_len": 5,
+     "n_token_sample": 15,
+     "n_token_consider": 32000,
+     "mixing_method": "sample_new_weights_with_score",
+     "smoothing": "geom",
+     "sample_tokens": 0,
+     "sample_beams": 0,
+     "i_weights": [
+         0.01,
+         0.04,
+         0.15,
+         0.18,
+         0.12
+     ],
+     "i_length": [
+         1,
+         2,
+         3,
+         4,
+         5
+     ]
+ }
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ datasets==2.19.0
+ fairscale==0.4.13
+ loguru==0.7.2
+ nltk==3.8.1
+ numpy==1.26.4
+ Requests==2.32.2
+ sentencepiece==0.2.0
+ setuptools==58.2.0
+ torch==2.3.0
+ tqdm==4.66.4
+ transformers==4.37.2
superposed/llama/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+ from .generation import Llama, Dialog
+ from .model import ModelArgs, Transformer
+ from .tokenizer import Tokenizer
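
Given these re-exports, downstream code can import the core classes directly from the package, for example:

    from superposed.llama import Llama, ModelArgs, Tokenizer, Transformer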
superposed/llama/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (335 Bytes)
superposed/llama/__pycache__/generation.cpython-312.pyc ADDED
Binary file (13.9 kB)
superposed/llama/__pycache__/model.cpython-312.pyc ADDED
Binary file (26.7 kB)
superposed/llama/__pycache__/superpose.cpython-312.pyc ADDED
Binary file (19.1 kB)
superposed/llama/__pycache__/superposed_generation.cpython-312.pyc ADDED
Binary file (10.1 kB)
superposed/llama/__pycache__/superposed_model.cpython-312.pyc ADDED
Binary file (25.9 kB)
superposed/llama/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (3.26 kB)
superposed/llama/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.97 kB)
superposed/llama/generation.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import json
5
+ import os
6
+ import sys
7
+ import time
8
+ from pathlib import Path
9
+ from typing import List, Literal, Optional, Tuple, TypedDict
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from fairscale.nn.model_parallel.initialize import (
14
+ get_model_parallel_rank,
15
+ initialize_model_parallel,
16
+ model_parallel_is_initialized,
17
+ )
18
+
19
+ from superposed.llama.model import ModelArgs, Transformer
20
+ from superposed.llama.tokenizer import Tokenizer
21
+ from superposed.llama.utils import *
22
+
23
+ Role = Literal["system", "user", "assistant"]
24
+
25
+
26
+ class Message(TypedDict):
27
+ role: Role
28
+ content: str
29
+
30
+
31
+ class CompletionPrediction(TypedDict, total=False):
32
+ generation: str
33
+ tokens: List[str] # not required
34
+ logprobs: List[float] # not required
35
+
36
+
37
+ class ChatPrediction(TypedDict, total=False):
38
+ generation: Message
39
+ tokens: List[str] # not required
40
+ logprobs: List[float] # not required
41
+
42
+
43
+ Dialog = List[Message]
44
+
45
+ B_INST, E_INST = "[INST]", "[/INST]"
46
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
47
+
48
+ SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
49
+ UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
50
+
51
+
52
+ class Llama:
53
+ @staticmethod
54
+ def build(
55
+ ckpt_dir: str,
56
+ tokenizer_path: str,
57
+ max_seq_len: int,
58
+ max_batch_size: int,
59
+ device: None,
60
+ model_parallel_size: Optional[int] = None,
61
+ seed: int = 1,
62
+ ) -> "Llama":
63
+ """
64
+ Build a Llama instance by initializing and loading a pre-trained model.
65
+
66
+ Args:
67
+ ckpt_dir (str): Path to the directory containing checkpoint files.
68
+ tokenizer_path (str): Path to the tokenizer file.
69
+ max_seq_len (int): Maximum sequence length for input text.
70
+ max_batch_size (int): Maximum batch size for inference.
71
+ mixed (bool): Whether to mix embeddings or not
72
+ model_parallel_size (Optional[int], optional): Number of model parallel processes.
73
+ If not provided, it's determined from the environment. Defaults to None.
74
+
75
+ Returns:
76
+ Llama: An instance of the Llama class with the loaded model and tokenizer.
77
+
78
+ Raises:
79
+ AssertionError: If there are no checkpoint files in the specified directory,
80
+ or if the model parallel size does not match the number of checkpoint files.
81
+
82
+ Note:
83
+ This method initializes the distributed process group, sets the device to CUDA,
84
+ and loads the pre-trained model and tokenizer.
85
+
86
+ """
87
+ if not torch.distributed.is_initialized():
88
+ torch.distributed.init_process_group("nccl")
89
+ if not model_parallel_is_initialized():
90
+ if model_parallel_size is None:
91
+ model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
92
+ initialize_model_parallel(model_parallel_size)
93
+
94
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
95
+ print(local_rank)
96
+ # torch.cuda.set_device(local_rank)
97
+ if device == None:
98
+ torch.cuda.set_device(local_rank)
99
+ device = f"cuda:{local_rank}"
100
+ # seed must be the same in all processes
101
+ torch.manual_seed(seed)
102
+
103
+ if local_rank > 0:
104
+ sys.stdout = open(os.devnull, "w")
105
+
106
+ start_time = time.time()
107
+ checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
108
+ assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
109
+ assert model_parallel_size == len(
110
+ checkpoints
111
+ ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
112
+ ckpt_path = checkpoints[get_model_parallel_rank()]
113
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
114
+ with open(Path(ckpt_dir) / "params.json", "r") as f:
115
+ params = json.loads(f.read())
116
+
117
+ model_args: ModelArgs = ModelArgs(
118
+ max_seq_len=max_seq_len,
119
+ max_batch_size=max_batch_size,
120
+ **params,
121
+ )
122
+ tokenizer = Tokenizer(model_path=tokenizer_path)
123
+ model_args.vocab_size = tokenizer.n_words
124
+ torch.set_default_tensor_type(torch.cuda.HalfTensor)
125
+ model = Transformer(model_args)
126
+ model.load_state_dict(checkpoint, strict=False)
127
+ print(f"Loaded in {time.time() - start_time:.2f} seconds")
128
+ return Llama(model, tokenizer, device)
129
+
130
+ def __init__(self, model: Transformer, tokenizer: Tokenizer, device):
131
+ self.model = model.to(device).eval()
132
+ self.tokenizer = tokenizer
133
+ self.device = device
134
+
135
+ @torch.inference_mode()
136
+ def generate(
137
+ self,
138
+ prompt_tokens: List[List[int]],
139
+ max_gen_len: int,
140
+ temperature: float = 0.6,
141
+ top_p: float = 0.9,
142
+ logprobs: bool = True,
143
+ grade: bool = False
144
+ ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
145
+ """
146
+ Generate text sequences based on provided prompts using the language generation model.
147
+
148
+ Args:
149
+ prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
150
+ max_gen_len (int): Maximum length of the generated text sequence.
151
+ temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
152
+ top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
153
+ logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
154
+ echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
155
+
156
+ Returns:
157
+ Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
158
+
159
+ Note:
160
+ This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
161
+ If logprobs is True, token log probabilities are computed for each generated token.
162
+
163
+ """
164
+ params = self.model.params
165
+ bsz = len(prompt_tokens)
166
+ assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
167
+
168
+ min_prompt_len = min(len(t) for t in prompt_tokens)
169
+ max_prompt_len = max(len(t) for t in prompt_tokens)
170
+ # assert min_prompt_len == max_prompt_len
171
+ prompt_len = min_prompt_len
172
+ assert max_prompt_len <= params.max_seq_len
173
+ total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
174
+
175
+ pad_id = self.tokenizer.pad_id
176
+ tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
177
+ for k, t in enumerate(prompt_tokens):
178
+ tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
179
+ if logprobs:
180
+ token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
181
+ prev_pos = 0
182
+ eos_reached = torch.tensor([False] * bsz, device=self.device)
183
+ input_text_mask = tokens != pad_id
184
+ if grade:
185
+ pad_mask = tokens == pad_id
186
+ tokens = torch.where(tokens == pad_id, 0, tokens)
187
+ logits = self.model.forward(tokens, prev_pos, False)
188
+ tokens[pad_mask] = pad_id
189
+ token_logprobs = -F.cross_entropy(
190
+ input=logits[:, :-1, :].transpose(1, 2),
191
+ target=tokens[:, 1:],
192
+ reduction="none",
193
+ ignore_index=pad_id,
194
+ )
195
+ #if pad_id in tokens:
196
+ # print(pad_id)
197
+ # print(tokens)
198
+ # print(token_logprobs)
199
+ return token_logprobs
200
+
201
+ for cur_pos in range(min_prompt_len, total_len):
202
+ logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, False)
203
+ if temperature > 0:
204
+ probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
205
+ next_token = sample_top_p(probs, top_p)
206
+ else:
207
+ next_token = torch.argmax(logits[:, -1], dim=-1)
208
+
209
+ next_token = next_token.reshape(-1)
210
+ # only replace token if prompt has already been generated
211
+ next_token = torch.where(
212
+ input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
213
+ )
214
+ tokens[:, cur_pos] = next_token
215
+ if logprobs:
216
+ token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
217
+ input=logits.transpose(1, 2),
218
+ target=tokens[:, prev_pos + 1 : cur_pos + 1],
219
+ reduction="none",
220
+ ignore_index=pad_id,
221
+ )
222
+ eos_reached |= (~input_text_mask[:, cur_pos]) & (
223
+ next_token == self.tokenizer.eos_id
224
+ )
225
+ prev_pos = cur_pos
226
+ if all(eos_reached):
227
+ break
228
+
229
+ # seq_len = torch.sum(tokens != pad_id, dim=1)
230
+ # return tokens, torch.exp(-1 * torch.sum(logprobs, dim=1) / (seq_len - prompt_len)), torch.exp(-1 * torch.sum(custom_logprobs, dim=1) / )
231
+ if logprobs:
232
+ token_logprobs = token_logprobs.tolist()
233
+
234
+ out_ppl = []
235
+ for i, toks in enumerate(tokens.tolist()):
236
+ if logprobs:
237
+ probs = token_logprobs[i][prompt_len : len(prompt_tokens[i]) + max_gen_len]
238
+ # cut to eos tok if any
239
+ if self.tokenizer.eos_id in toks:
240
+ eos_idx = toks.index(self.tokenizer.eos_id)
241
+ probs = probs[:eos_idx] if logprobs else None
242
+ out_ppl.append(torch.exp(-1 * torch.sum(torch.tensor(probs)) / len(probs)))
243
+ return tokens, torch.tensor(out_ppl) if logprobs else None
244
+
245
+ def sample_top_p(probs, p, s=1):
246
+ """
247
+ Perform top-p (nucleus) sampling on a probability distribution.
248
+
249
+ Args:
250
+ probs (torch.Tensor): Probability distribution tensor.
251
+ p (float): Probability threshold for top-p sampling.
252
+
253
+ Returns:
254
+ torch.Tensor: Sampled token indices.
255
+
256
+ Note:
257
+ Top-p sampling selects the smallest set of tokens whose cumulative probability mass
258
+ exceeds the threshold p. The distribution is renormalized based on the selected tokens.
259
+
260
+ """
261
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
262
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
263
+ mask = probs_sum - probs_sort > p
264
+ probs_sort[mask] = 0.0
265
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
266
+ next_token = torch.multinomial(probs_sort, num_samples=s)
267
+ next_token = torch.gather(probs_idx, -1, next_token)
268
+ return next_token
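
The sample_top_p helper defined at the end of generation.py implements nucleus sampling. Below is a small self-contained sketch of the same steps on a toy distribution; the probability values and the threshold p = 0.8 are illustrative only.

    import torch

    probs = torch.tensor([[0.50, 0.30, 0.10, 0.05, 0.05]])
    p = 0.8

    # Keep the smallest set of tokens whose cumulative mass exceeds p, then renormalize
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[probs_sum - probs_sort > p] = 0.0          # zeroes out the two 0.05 tail tokens
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
    print(next_token)  # always one of the three highest-probability tokens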
superposed/llama/metrics.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ import nltk
+ from nltk.translate.bleu_score import SmoothingFunction
+ from tqdm import tqdm
+
+ def calculate_perplexity(model, tokens, prompt_len, bsz=1, marker=False):
+     """
+     Calculate perplexity of given tokens using provided model, ignoring padding tokens.
+     Args:
+         model: Llama model
+         tokens (List[List[int]] or torch.Tensor): Input tokens (n_prompt * n_draft, seqlen)
+         prompt_len (int): Prefix length
+         bsz (int): Batch size
+         marker (bool): Whether to show progress bar
+     Returns:
+         Perplexity across all generations (n_prompt * n_drafts)
+     """
+     it = range(0, len(tokens), bsz)
+     if marker:
+         it = tqdm(it)
+     start = 0
+     ppl = torch.zeros(len(tokens))
+     for start in it:
+         end = start + bsz
+         data = tokens[start : end]
+         if not isinstance(data, list):
+             data = data.tolist()
+         # Remove any padding tokens (-1) in generations
+         for d_idx in range(len(data)):
+             cur = data[d_idx]
+             if -1 in cur:
+                 data[d_idx] = cur[:cur.index(-1)]
+         # Calculate cross entropy loss on tokens
+         ce_loss = model.generate(data, max_gen_len=0, temperature=-1, top_p=-1, grade=True)
+         # Cut off everything past `prompt_len`
+         ce_loss = ce_loss[:, prompt_len-1:]  # Subtract 1 because the first token (start token) is removed
+         # Calculate perplexity
+         lengths = (ce_loss != 0).sum(dim=-1)
+         mean = ce_loss.sum(dim=-1) / lengths
+         ppl[start : end] = torch.exp(-1 * mean)
+     return ppl
+
+ def calculate_diversity(generations, k=4):
+     """
+     Calculate diversity of generations using SELF-BLEU.
+     Args:
+         generations (List[List[List[int]]]): Tokenized input
+         k (int, Optional): Number of n-grams to use for bleu
+     Returns:
+         Average diversity across all generations (float)
+     """
+     nltk.download('punkt')  # Can be deleted once downloaded
+     smooth = SmoothingFunction()
+     bleus = []
+
+     for drafts in generations:
+         tokenized_drafts = []
+         # Stringify tokens
+         for d in drafts:
+             if -1 in d:
+                 d = d[:d.index(-1)]
+             tokenized_drafts.append([str(n) for n in d])
+         # Calculate SELF-BLEU
+         minlength = min([len(g) for g in tokenized_drafts])
+         minlength = min(minlength, k)
+         weights = tuple((1. / minlength for _ in range(minlength)))
+         for i in range(len(drafts)):
+             # Create source and reference (all other drafts)
+             src = tokenized_drafts[i]
+             ref = tokenized_drafts[:i] + tokenized_drafts[i+1:]
+             tmp = nltk.translate.bleu_score.sentence_bleu(references=ref,
+                                                           hypothesis=src,
+                                                           weights=weights,
+                                                           smoothing_function=smooth.method1)
+             bleus.append(tmp)
+     bleus = torch.Tensor(bleus)
+     return torch.mean(bleus)
+
+
+ def calculate_ngram_repetition(sequences):
+     """
+     Calculate uniqueness scores of `sequences`.
+     Args:
+         sequences (List[List[int]]): Generated sequences
+     Returns:
+         (unigram_uniqueness, bigram_uniqueness, trigram_uniqueness)
+     """
+     u_total = 0
+     b_total = 0
+     t_total = 0
+     # Iterate through all sequences indiscriminately
+     for gen in sequences:
+         if -1 in gen:
+             gen = gen[:gen.index(-1)]
+         unigrams, bigrams, trigrams = [], [], []
+         o = [str(i) for i in gen]
+         # Create lists of n-grams for the generation
+         for i in range(len(o)):
+             unigrams.append(o[i])
+         for i in range(len(o) - 1):
+             bigrams.append(o[i] + '_' + o[i + 1])
+         for i in range(len(o) - 2):
+             trigrams.append(o[i] + '_' + o[i + 1] + '_' + o[i + 2])
+         # Calculate uniqueness of the generation
+         u, b, t = len(set(unigrams)) / len(unigrams), len(set(bigrams)) / len(bigrams), len(set(trigrams)) / len(trigrams)
+         u_total += u
+         b_total += b
+         t_total += t
+     return u_total / len(sequences), b_total / len(sequences), t_total / len(sequences)
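
As a quick worked example of the uniqueness scores computed by calculate_ngram_repetition, consider a single toy sequence (the token ids are made up):

    gen = [4, 7, 4, 7, 9]
    o = [str(i) for i in gen]
    unigrams = o
    bigrams = [o[i] + "_" + o[i + 1] for i in range(len(o) - 1)]
    trigrams = [o[i] + "_" + o[i + 1] + "_" + o[i + 2] for i in range(len(o) - 2)]
    # 3 unique of 5 unigrams, 3 unique of 4 bigrams, 3 unique of 3 trigrams
    print(len(set(unigrams)) / len(unigrams),   # 0.6
          len(set(bigrams)) / len(bigrams),     # 0.75
          len(set(trigrams)) / len(trigrams))   # 1.0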
superposed/llama/model.py ADDED
@@ -0,0 +1,548 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple
7
+
8
+ import fairscale.nn.model_parallel.initialize as fs_init
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairscale.nn.model_parallel.layers import (
12
+ ColumnParallelLinear,
13
+ ParallelEmbedding,
14
+ RowParallelLinear,
15
+ )
16
+ from torch import nn
17
+
18
+
19
+ @dataclass
20
+ class ModelArgs:
21
+ dim: int = 4096
22
+ n_layers: int = 32
23
+ n_heads: int = 32
24
+ n_kv_heads: Optional[int] = None
25
+ vocab_size: int = -1 # defined later by tokenizer
26
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27
+ ffn_dim_multiplier: Optional[float] = None
28
+ norm_eps: float = 1e-5
29
+
30
+ max_batch_size: int = 32
31
+ max_seq_len: int = 2048
32
+
33
+
34
+ class RMSNorm(torch.nn.Module):
35
+ def __init__(self, dim: int, eps: float = 1e-6):
36
+ """
37
+ Initialize the RMSNorm normalization layer.
38
+
39
+ Args:
40
+ dim (int): The dimension of the input tensor.
41
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
42
+
43
+ Attributes:
44
+ eps (float): A small value added to the denominator for numerical stability.
45
+ weight (nn.Parameter): Learnable scaling parameter.
46
+
47
+ """
48
+ super().__init__()
49
+ self.eps = eps
50
+ self.weight = nn.Parameter(torch.ones(dim))
51
+
52
+ def _norm(self, x):
53
+ """
54
+ Apply the RMSNorm normalization to the input tensor.
55
+
56
+ Args:
57
+ x (torch.Tensor): The input tensor.
58
+
59
+ Returns:
60
+ torch.Tensor: The normalized tensor.
61
+
62
+ """
63
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
64
+
65
+ def forward(self, x):
66
+ """
67
+ Forward pass through the RMSNorm layer.
68
+
69
+ Args:
70
+ x (torch.Tensor): The input tensor.
71
+
72
+ Returns:
73
+ torch.Tensor: The output tensor after applying RMSNorm.
74
+
75
+ """
76
+ output = self._norm(x.float()).type_as(x)
77
+ return output * self.weight
78
+
79
+
80
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
81
+ """
82
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
83
+
84
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
85
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
86
+ The returned tensor contains complex values in complex64 data type.
87
+
88
+ Args:
89
+ dim (int): Dimension of the frequency tensor.
90
+ end (int): End index for precomputing frequencies.
91
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
92
+
93
+ Returns:
94
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
95
+
96
+
97
+
98
+
99
+ """
100
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
101
+ t = torch.arange(end, device=freqs.device) # type: ignore
102
+ freqs = torch.outer(t, freqs).float() # type: ignore
103
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
104
+ return freqs_cis
105
+
106
+
107
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
108
+ """
109
+ Reshape frequency tensor for broadcasting it with another tensor.
110
+
111
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
112
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
113
+
114
+ Args:
115
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
116
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
117
+
118
+ Returns:
119
+ torch.Tensor: Reshaped frequency tensor.
120
+
121
+ Raises:
122
+ AssertionError: If the frequency tensor doesn't match the expected shape.
123
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
124
+ """
125
+ ndim = x.ndim
126
+ assert 0 <= 1 < ndim
127
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
128
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
129
+ return freqs_cis.view(*shape)
130
+
131
+
132
+ def apply_rotary_emb(
133
+ xq: torch.Tensor,
134
+ xk: torch.Tensor,
135
+ freqs_cis: torch.Tensor,
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """
138
+ Apply rotary embeddings to input tensors using the given frequency tensor.
139
+
140
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
141
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
142
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
143
+ returned as real tensors.
144
+
145
+ Args:
146
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
147
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
148
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
149
+
150
+ Returns:
151
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
152
+
153
+
154
+
155
+ """
156
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
157
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
158
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
159
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
160
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
161
+ return xq_out.type_as(xq), xk_out.type_as(xk)
162
+
163
+
164
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
165
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
166
+ bs, slen, n_kv_heads, head_dim = x.shape
167
+ if n_rep == 1:
168
+ return x
169
+ return (
170
+ x[:, :, :, None, :]
171
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
172
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
173
+ )
174
+
175
+
176
+ class Attention(nn.Module):
177
+ """Multi-head attention module."""
178
+ def __init__(self, args: ModelArgs):
179
+ """
180
+ Initialize the Attention module.
181
+
182
+ Args:
183
+ args (ModelArgs): Model configuration parameters.
184
+
185
+ Attributes:
186
+ n_kv_heads (int): Number of key and value heads.
187
+ n_local_heads (int): Number of local query heads.
188
+ n_local_kv_heads (int): Number of local key and value heads.
189
+ n_rep (int): Number of repetitions for local heads.
190
+ head_dim (int): Dimension size of each attention head.
191
+ wq (ColumnParallelLinear): Linear transformation for queries.
192
+ wk (ColumnParallelLinear): Linear transformation for keys.
193
+ wv (ColumnParallelLinear): Linear transformation for values.
194
+ wo (RowParallelLinear): Linear transformation for output.
195
+ cache_k (torch.Tensor): Cached keys for attention.
196
+ cache_v (torch.Tensor): Cached values for attention.
197
+
198
+ """
199
+ super().__init__()
200
+ self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
201
+ model_parallel_size = fs_init.get_model_parallel_world_size()
202
+ self.n_local_heads = args.n_heads // model_parallel_size
203
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
204
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
205
+ self.head_dim = args.dim // args.n_heads
206
+
207
+ self.wq = ColumnParallelLinear(
208
+ args.dim,
209
+ args.n_heads * self.head_dim,
210
+ bias=False,
211
+ gather_output=False,
212
+ init_method=lambda x: x,
213
+ )
214
+ self.wk = ColumnParallelLinear(
215
+ args.dim,
216
+ self.n_kv_heads * self.head_dim,
217
+ bias=False,
218
+ gather_output=False,
219
+ init_method=lambda x: x,
220
+ )
221
+ self.wv = ColumnParallelLinear(
222
+ args.dim,
223
+ self.n_kv_heads * self.head_dim,
224
+ bias=False,
225
+ gather_output=False,
226
+ init_method=lambda x: x,
227
+ )
228
+ self.wo = RowParallelLinear(
229
+ args.n_heads * self.head_dim,
230
+ args.dim,
231
+ bias=False,
232
+ input_is_parallel=True,
233
+ init_method=lambda x: x,
234
+ )
235
+
236
+ self.cache_k = torch.zeros(
237
+ (
238
+ args.max_batch_size,
239
+ args.max_seq_len,
240
+ self.n_local_kv_heads,
241
+ self.head_dim,
242
+ )
243
+ ).cuda()
244
+ self.cache_v = torch.zeros(
245
+ (
246
+ args.max_batch_size,
247
+ args.max_seq_len,
248
+ self.n_local_kv_heads,
249
+ self.head_dim,
250
+ )
251
+ ).cuda()
252
+
253
+ def forward(
254
+ self,
255
+ x: torch.Tensor,
256
+ start_pos: int,
257
+ freqs_cis: torch.Tensor,
258
+ mask: Optional[torch.Tensor],
259
+ beam: Optional[bool] = None,
260
+ n_beams: Optional[int] = None,
261
+ attention_change_ids: Optional[torch.Tensor] = None
262
+ ):
263
+ """
264
+ Forward pass of the attention module.
265
+
266
+ Args:
267
+ x (torch.Tensor): Input tensor.
268
+ start_pos (int): Starting position for caching.
269
+ freqs_cis (torch.Tensor): Precomputed frequency tensor.
270
+ mask (torch.Tensor, optional): Attention mask tensor.
271
+
272
+ Returns:
273
+ torch.Tensor: Output tensor after attention.
274
+
275
+ """
276
+ bsz, seqlen, _ = x.shape
277
+ _, max_seq_len, n_local_kv_heads, head_dim = self.cache_k.shape
278
+ # KV Cache updates for beam search
279
+ if beam:
280
+ # Extract used cache values
281
+ used_cache_k = self.cache_k[:bsz]
282
+ used_cache_v = self.cache_v[:bsz]
283
+ # Reshape to apply change ids
284
+ t_cache_k = used_cache_k.reshape(bsz // n_beams, n_beams, max_seq_len, n_local_kv_heads, head_dim)
285
+ t_cache_v = used_cache_v.reshape(bsz // n_beams, n_beams, max_seq_len, n_local_kv_heads, head_dim)
286
+ used_cache_k = torch.take_along_dim(t_cache_k, attention_change_ids.reshape(-1, n_beams, 1, 1, 1), 1)
287
+ used_cache_v = torch.take_along_dim(t_cache_v, attention_change_ids.reshape(-1, n_beams, 1, 1, 1), 1)
288
+ # Update cache
289
+ self.cache_k[:bsz] = used_cache_k.reshape(bsz, max_seq_len, n_local_kv_heads, head_dim)
290
+ self.cache_v[:bsz] = used_cache_v.reshape(bsz, max_seq_len, n_local_kv_heads, head_dim)
291
+
292
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
293
+
294
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
295
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
296
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
297
+
298
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
299
+
300
+ self.cache_k = self.cache_k.to(xq)
301
+ self.cache_v = self.cache_v.to(xq)
302
+
303
+ self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
304
+ self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
305
+
306
+ keys = self.cache_k[:bsz, : start_pos + seqlen]
307
+ values = self.cache_v[:bsz, : start_pos + seqlen]
308
+
309
+ # repeat k/v heads if n_kv_heads < n_heads
310
+ keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
311
+ values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
312
+
313
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
314
+ keys = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
315
+ values = values.transpose(1, 2)
316
+ scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) # (bs, n_local_heads, seqlen, seqlen)
317
+ if mask is not None:
318
+ scores = scores + mask # (bs, n_local_heads, seqlen, seqlen)
319
+ scores = F.softmax(scores.float(), dim=-1).type_as(xq) # (bs, n_local_heads, seqlen, seqlen)
320
+ output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
321
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
322
+ return self.wo(output)
323
+
324
+
325
+ class FeedForward(nn.Module):
326
+ def __init__(
327
+ self,
328
+ dim: int,
329
+ hidden_dim: int,
330
+ multiple_of: int,
331
+ ffn_dim_multiplier: Optional[float],
332
+ ):
333
+ """
334
+ Initialize the FeedForward module.
335
+
336
+ Args:
337
+ dim (int): Input dimension.
338
+ hidden_dim (int): Hidden dimension of the feedforward layer.
339
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
340
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
341
+
342
+ Attributes:
343
+ w1 (ColumnParallelLinear): Linear transformation for the first layer.
344
+ w2 (RowParallelLinear): Linear transformation for the second layer.
345
+ w3 (ColumnParallelLinear): Linear transformation for the third layer.
346
+
347
+ """
348
+ super().__init__()
349
+ hidden_dim = int(2 * hidden_dim / 3)
350
+ # custom dim factor multiplier
351
+ if ffn_dim_multiplier is not None:
352
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
353
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
354
+
355
+ self.w1 = ColumnParallelLinear(
356
+ dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
357
+ )
358
+ self.w2 = RowParallelLinear(
359
+ hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
360
+ )
361
+ self.w3 = ColumnParallelLinear(
362
+ dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
363
+ )
364
+
365
+ def forward(self, x):
366
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
367
+
368
+
369
+ class TransformerBlock(nn.Module):
370
+ def __init__(self, layer_id: int, args: ModelArgs):
371
+ """
372
+ Initialize a TransformerBlock.
373
+
374
+ Args:
375
+ layer_id (int): Identifier for the layer.
376
+ args (ModelArgs): Model configuration parameters.
377
+
378
+ Attributes:
379
+ n_heads (int): Number of attention heads.
380
+ dim (int): Dimension size of the model.
381
+ head_dim (int): Dimension size of each attention head.
382
+ attention (Attention): Attention module.
383
+ feed_forward (FeedForward): FeedForward module.
384
+ layer_id (int): Identifier for the layer.
385
+ attention_norm (RMSNorm): Layer normalization for attention output.
386
+ ffn_norm (RMSNorm): Layer normalization for feedforward output.
387
+
388
+ """
389
+ super().__init__()
390
+ self.n_heads = args.n_heads
391
+ self.dim = args.dim
392
+ self.head_dim = args.dim // args.n_heads
393
+ self.attention = Attention(args)
394
+ self.feed_forward = FeedForward(
395
+ dim=args.dim,
396
+ hidden_dim=4 * args.dim,
397
+ multiple_of=args.multiple_of,
398
+ ffn_dim_multiplier=args.ffn_dim_multiplier,
399
+ )
400
+ self.layer_id = layer_id
401
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
402
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
403
+
404
+ def forward(
405
+ self,
406
+ x: torch.Tensor,
407
+ start_pos: int,
408
+ freqs_cis: torch.Tensor,
409
+ mask: Optional[torch.Tensor],
410
+ beam: Optional[bool],
411
+ n_beams: Optional[int] = None,
412
+ attention_change_ids: Optional[torch.Tensor] = None
413
+ ):
414
+ """
415
+ Perform a forward pass through the TransformerBlock.
416
+
417
+ Args:
418
+ x (torch.Tensor): Input tensor.
419
+ start_pos (int): Starting position for attention caching.
420
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
421
+ mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
422
+
423
+ Returns:
424
+ torch.Tensor: Output tensor after applying attention and feedforward layers.
425
+
426
+ """
427
+ if beam:
428
+ h = x + self.attention.forward(
429
+ self.attention_norm(x), start_pos, freqs_cis, mask, beam, n_beams, attention_change_ids
430
+ )
431
+ else:
432
+ h = x + self.attention.forward(
433
+ self.attention_norm(x), start_pos, freqs_cis, mask
434
+ )
435
+ out = h + self.feed_forward.forward(self.ffn_norm(h))
436
+ return out
437
+
438
+
439
+ class Transformer(nn.Module):
440
+ def __init__(self, params: ModelArgs):
441
+ """
442
+ Initialize a Transformer model.
443
+
444
+ Args:
445
+ params (ModelArgs): Model configuration parameters.
446
+
447
+ Attributes:
448
+ params (ModelArgs): Model configuration parameters.
449
+ vocab_size (int): Vocabulary size.
450
+ n_layers (int): Number of layers in the model.
451
+ tok_embeddings (ParallelEmbedding): Token embeddings.
452
+ layers (torch.nn.ModuleList): List of Transformer blocks.
453
+ norm (RMSNorm): Layer normalization for the model output.
454
+ output (ColumnParallelLinear): Linear layer for final output.
455
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
456
+
457
+ """
458
+ super().__init__()
459
+ self.params = params
460
+ self.vocab_size = params.vocab_size
461
+ self.n_layers = params.n_layers
462
+
463
+ self.tok_embeddings = ParallelEmbedding(
464
+ params.vocab_size, params.dim, init_method=lambda x: x
465
+ )
466
+
467
+ self.layers = torch.nn.ModuleList()
468
+ for layer_id in range(params.n_layers):
469
+ self.layers.append(TransformerBlock(layer_id, params))
470
+
471
+ self.norm = RMSNorm(params.dim, eps=params.norm_eps)
472
+ self.output = ColumnParallelLinear(
473
+ params.dim, params.vocab_size, bias=False, init_method=lambda x: x
474
+ )
475
+
476
+ self.freqs_cis = precompute_freqs_cis(
477
+ # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
478
+ # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
479
+ self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
480
+ )
481
+
482
+
483
+ @torch.inference_mode()
484
+ def forward(self,
485
+ tokens: torch.Tensor,
486
+ start_pos: int,
487
+ beam: bool,
488
+ n_beams: Optional[int] = None,
489
+ attention_change_ids: Optional[torch.Tensor] = None,
490
+ verbose: Optional[bool] = False):
491
+ """
492
+ Perform a forward pass through the Transformer model.
493
+
494
+ Args:
495
+ tokens (torch.Tensor): Input token indices.
496
+ start_pos (int): Starting position for attention caching.
497
+ verbose (bool): Whether to return intermediate hidden layer states
498
+
499
+ Returns:
500
+ torch.Tensor or (torch.Tensor, Dict): output logits after applying the Transformer model.
501
+
502
+ """
503
+ ### ANALYSIS CODE ###
504
+ if verbose:
505
+ states = {"layers": [], "tokens": tokens}
506
+ #
507
+
508
+ _bsz, seqlen = tokens.shape
509
+ h = self.tok_embeddings(tokens)
510
+ self.freqs_cis = self.freqs_cis.to(h.device)
511
+ freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
512
+
513
+ ### ANALYSIS CODE ###
514
+ if verbose:
515
+ states["layers"].append(h)
516
+ #
517
+
518
+ mask = None
519
+ if seqlen > 1:
520
+ mask = torch.full(
521
+ (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
522
+ )
523
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
524
+
525
+ for layer in self.layers:
526
+ if not beam:
527
+ h = layer(h, start_pos, freqs_cis, mask, beam)
528
+ else:
529
+ h = layer(h, start_pos, freqs_cis, mask, beam, n_beams, attention_change_ids)
530
+ ### ANALYSIS CODE ###
531
+ if verbose:
532
+ states["layers"].append(h)
533
+ #
534
+ h = self.norm(h)
535
+ # if want differences, at end, subtract differences from [-1] position of embedding vectors each iteration
536
+
537
+ ### ANALYSIS CODE ###
538
+ if verbose:
539
+ states["layers"].append(h)
540
+ #
541
+
542
+ output = self.output(h).float()
543
+
544
+ if verbose:
545
+ return output, states
546
+ else:
547
+ return output
548
+
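# Illustrative sketch (not part of the repository): the causal mask built in the
# forward pass above is an upper-triangular matrix of -inf, shifted by start_pos,
# so each position can attend to the cache and to itself but not to future tokens.
# The sizes below are hypothetical example values.
import torch

seqlen, start_pos = 4, 0
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=start_pos + 1)
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])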
superposed/llama/superpose.py ADDED
@@ -0,0 +1,328 @@
1
+ # Implementation loosely based on https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L554
2
+ import requests
3
+ import time
4
+ from datetime import datetime, timedelta
5
+ from typing import Optional, Literal
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import LlamaTokenizer
10
+
11
+ from superposed.llama.utils import *
12
+ from superposed.ngrams.ngram_models import NGram
13
+
14
+ INF = 1. * 1e7
15
+
16
+ # TODO: test by scaling the number of beams and verifying the results
17
+ class Superpose(nn.Module):
18
+ def __init__(self,
19
+ initial_tokens,
20
+ tokenizer,
21
+ vocab_size,
22
+ smoothing: Optional[Literal["geom", "all"]] = None,
23
+ alpha = None,
24
+ verbose = False,
25
+ i_weights = None,
26
+ i_length = None,
27
+ ngrams = None,
28
+ sample_beams = False,
29
+ sample_tokens = False,
30
+ get_time = False,
31
+ penalty = 200): # default no effect
32
+ """
33
+ Initialize a beam search class.
34
+
35
+ Args:
36
+ initial_tokens (torch.Tensor): Initial tokens (n_prompts, n_drafts, seqlen)
37
38
+ tokenizer (Tokenizer): Llama tokenizer
39
+ vocab_size (int): Total vocab size
40
+ smoothing (str): Smoothing method ("geom" for default, "all" for only ngram, None for no ngram)
41
+ ngrams (Tuple): Ngram models, ordered from bigram upward
42
+ alpha (float): Alpha parameter
43
+ verbose (bool): Whether to print debugging information
+ i_weights (List[float]): Ngram interpolation weights
+ i_length (List[int]): Ngram orders to interpolate (1 for bigram, 2 for trigram, etc.)
+ penalty (float): Penalty applied to drafts with no ngram interpolation
44
+ """
45
+ super().__init__()
46
+ # primary parameters
47
+ self.n_prompts, self.n_drafts, _ = initial_tokens.shape
48
+ self.tokenizer = tokenizer
49
+ self.vocab_size = vocab_size
50
+ self.alive_seq = initial_tokens
51
+ self.fin_seq = initial_tokens
52
+ self.smoothing = smoothing
53
+ self.alive_log_probs = torch.zeros(self.n_prompts, self.n_drafts)
54
+ self.fin_log_probs = torch.full((self.n_prompts, self.n_drafts), float("-inf"))
55
+ self.alpha = alpha
56
+ self.verbose = verbose
57
+ self.penalty = penalty
58
+ # devices
59
+ self.cpu = torch.device('cpu')
60
+ self.gpu = torch.device('cuda')
61
+ # Interpolation length and weights
62
+ self.interpolation_weights = i_weights
63
+ self.i_length = i_length
64
+ # N-grams
65
+ self.bigram = ngrams[0] if len(ngrams) >= 1 else None
66
+ self.trigram = ngrams[1] if len(ngrams) >= 2 else None
67
+ self.fourgram = ngrams[2] if len(ngrams) >= 3 else None
68
+ self.fivegram = ngrams[3] if len(ngrams) >= 4 else None
69
+ self.sixgram = ngrams[4] if len(ngrams) >= 5 else None
70
+ self.sevengram = ngrams[5] if len(ngrams) >= 6 else None
71
+ # Timing
72
+ self.get_time = get_time
73
+ self.lookup_time = None
74
+
75
+ def forward(self, probs, still_prompt, is_first, cur_pos, n_token_sample):
76
+ """
77
+ Apply beam decoding to update generations.
78
+
79
+ Args:
80
+ probs (torch.Tensor): Next token probability distribution
81
+ still_prompt (torch.Tensor): Flags of prompts that should not generate yet (n_prompts, )
82
+ is_first (torch.Tensor): Flags of prompts that are on their first generation (n_prompts, )
83
+ cur_pos (int): Current generation position
84
+ n_token_sample (int): Number of tokens from model distribution to use
85
+
86
+ Return:
87
+ if standard beam search:
88
+ attention_change_ids (torch.Tensor): New indices in kv cache (n_prompts, n_drafts)
89
+ if mixed:
90
+ token_weights (torch.Tensor): Mixing weights (n_prompts, vocab_size)
91
+ """
92
+ # Adjust input probabilities
93
+ probs = self.get_top_k(probs, 32000, n_token_sample)
94
+ reshaped_probs = probs.reshape(self.n_prompts, 1, -1)
95
+ reshaped_probs = reshaped_probs.repeat(1, self.n_drafts, 1)
96
+ # Ngram smoothing
97
+ if self.smoothing is not None:
98
+ if self.smoothing == "geom":
99
+ ngram_probs = self.ngram_probs(self.alive_seq, cur_pos, probs=probs)
100
+ # Make mask and normalize
101
+ prob_mask = reshaped_probs != 0
102
+ ngram_probs *= prob_mask
103
+ # Calculate logprobs and interpolate distributions
104
+ llm_log_probs = torch.log(reshaped_probs)
105
+ ngram_log_probs = torch.log(ngram_probs)
106
+ log_probs = (1 - self.alpha) * llm_log_probs + self.alpha * ngram_log_probs
107
+ # Apply penalty to drafts where no interpolation occurred
108
+ is_all_inf = (log_probs != float("-inf")).sum(dim=-1, keepdims=True) == 0
109
+ log_probs = torch.where(is_all_inf, (1 - self.alpha) * llm_log_probs - self.penalty, log_probs)
110
+ elif self.smoothing == "all":
111
+ ngram_probs = self.ngram_probs(self.alive_seq, cur_pos, probs=None)
112
+ log_probs = torch.log(ngram_probs)
113
+ else:
114
+ log_probs = torch.log(reshaped_probs)
115
+ curr_log_probs = self.alive_log_probs.unsqueeze(dim=2) + log_probs # [n_prompts, n_drafts, vocab_size]
116
+ # Warning if nan
117
+ if (torch.any(torch.isnan(curr_log_probs)).item()):
118
+ raise RuntimeWarning("nan in sequence log probs")
119
+ # Potential Sequences
120
+ flat_curr_log_probs = curr_log_probs.reshape(-1, self.vocab_size*self.n_drafts)
121
+ topk_log_probs, topk_idx = torch.topk(flat_curr_log_probs, 2 * self.n_drafts, dim=-1)
122
+ topk_beam_id = topk_idx // self.vocab_size # [n_prompts, 2 * n_drafts]
123
+ topk_idx = topk_idx % self.vocab_size # [n_prompts, 2 * n_drafts]
124
+ # First timestep uses top-k next tokens
125
+ is_first_idx = is_first.nonzero(as_tuple=True)[0]
126
+ if len(is_first_idx) != 0:
127
+ first_time_log_probs = log_probs[is_first_idx][:, 0, :].squeeze(dim=1)
128
+ first_time_log_probs, first_time_topk_idx = torch.topk(first_time_log_probs, 2 * self.n_drafts, dim=1)
129
+ topk_idx[is_first_idx] = first_time_topk_idx
130
+ topk_log_probs[is_first_idx] = self.alive_log_probs[is_first_idx, 0].unsqueeze(dim=1) + first_time_log_probs
131
+ # New sequences
132
+ topk_seq = torch.take_along_dim(self.alive_seq, topk_beam_id.unsqueeze(2), dim=1) # [n_prompts, 2 * n_drafts, vocab_size]
133
+ topk_seq[:, :, cur_pos] = topk_idx
134
+ topk_finished = topk_idx == self.tokenizer.eos_id
135
+ # Only update sequences for those that have begun generating
136
+ new_alive_seq, new_alive_log_probs = self.grow_alive(topk_seq, topk_log_probs, topk_finished)
137
+ new_fin_seq, new_fin_log_probs = self.grow_fin(topk_seq, topk_log_probs, topk_finished)
138
+ still_prompt_probs = still_prompt.reshape(-1, 1)
139
+ still_prompt_seqs = still_prompt.reshape(-1, 1, 1)
140
+ self.alive_seq = torch.where(still_prompt_seqs, self.alive_seq, new_alive_seq)
141
+ self.alive_log_probs = torch.where(still_prompt_probs, self.alive_log_probs, new_alive_log_probs)
142
+ self.fin_seq = torch.where(still_prompt_seqs, self.fin_seq, new_fin_seq)
143
+ self.fin_log_probs = torch.where(still_prompt_probs, self.fin_log_probs, new_fin_log_probs)
144
+ # Create superposition matrix and return it
145
+ topk_idx = self.alive_seq[:, :, cur_pos].reshape(self.n_prompts, -1)
146
+ token_weights = self.superposition_matrix(topk_idx)
147
+ return token_weights
148
+
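# Illustrative sketch (not part of the repository): the "geom" smoothing above
# interpolates model and ngram distributions in log space,
# log p = (1 - alpha) * log p_llm + alpha * log p_ngram,
# i.e. a weighted geometric mean of the two probabilities up to normalization.
# The numbers below are hypothetical example values.
import torch

alpha = 0.3
llm_log_probs = torch.log(torch.tensor([0.5, 0.4, 0.1]))
ngram_log_probs = torch.log(torch.tensor([0.2, 0.7, 0.1]))
mixed_log_probs = (1 - alpha) * llm_log_probs + alpha * ngram_log_probs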
149
+ def grow_alive(self, topk_seq, topk_log_probs, topk_finished):
150
+ """
151
+ Extend running generations.
152
+ Args:
153
+ topk_seq (torch.Tensor): Top k sequences (n_prompts, 2 * n_drafts, vocab_size)
154
+ topk_log_probs (torch.Tensor): Log probabilities (n_prompts, 2 * n_drafts)
155
+ topk_finished (torch.Tensor): Whether a sequence is finished (n_prompts, 2 * n_drafts)
156
+ Returns:
157
+ new_alive_seq, new_alive_log_probs
158
+ """
159
+ topk_log_probs = topk_log_probs + topk_finished * -INF
160
+ new_alive_log_probs, new_alive_idx = torch.topk(topk_log_probs, self.n_drafts, dim=1)
161
+ new_alive_seq = torch.take_along_dim(topk_seq, new_alive_idx.unsqueeze(2), dim=1)
162
+ return new_alive_seq, new_alive_log_probs
163
+
164
+ def grow_fin(self, topk_seq, topk_log_probs, topk_finished):
165
+ """
166
+ Update stopped generations.
167
+ Args:
168
+ topk_seq (torch.Tensor): Top k sequences (n_prompts, 2 * n_drafts, vocab_size)
169
+ topk_log_probs (torch.Tensor): Log probabilities (n_prompts, 2 * n_drafts)
170
+ topk_finished (torch.Tensor): Whether a sequence is finished (n_prompts, 2 * n_drafts)
171
+
172
+ Returns:
173
+ new_fin_seq, new_fin_log_probs
174
+ """
175
+ topk_log_probs = topk_log_probs + ~topk_finished * -INF
176
+ new_fin_seq = torch.cat([self.fin_seq, topk_seq], dim=1)
177
+ new_fin_log_probs = torch.cat([self.fin_log_probs, topk_log_probs], dim=1)
178
+ new_fin_log_probs, new_fin_idx = torch.topk(new_fin_log_probs, self.n_drafts, dim=1)
179
+ new_fin_seq = torch.take_along_dim(new_fin_seq, new_fin_idx.unsqueeze(2), dim=1)
180
+ return new_fin_seq, new_fin_log_probs
181
+
182
+ def get_top_k(self, probs, m, k):
183
+ """
184
+ Zero out all but top-k tokens in a probability distribution.
185
+ Args:
186
+ probs (torch.Tensor): Probability distribution tensor.
187
+ m (int): Number of tokens to consider (only relevant when sampling).
188
+ k (int): Number of tokens to sample/keep.
189
+ Returns:
190
+ torch.Tensor: New probability distribution based on renormalized probabilities.
191
+ """
192
+ n_prompts, _ = probs.shape
193
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
194
+ top_k_mask = torch.arange(probs.shape[-1])
195
+ top_k_mask = top_k_mask.expand(probs.shape[0], -1)
196
+ top_k_mask = top_k_mask >= m # True for positions past the first m elements
197
+ probs_sort[top_k_mask] = 0.0 # Zero wherever mask = 1
198
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
199
+ next_token = torch.gather(probs_idx, -1, torch.topk(probs_sort, k, dim=-1)[1])
200
+ # Set all other probs to 0
201
+ new_probs_map = torch.zeros(probs.shape).bool()
202
+ new_probs_map[torch.repeat_interleave(torch.arange(n_prompts), k), torch.flatten(next_token)] = True
203
+ new_probs = torch.where(new_probs_map, probs, 0)
204
+ # Renormalize
205
+ new_probs.div_(new_probs.sum(dim=-1, keepdim=True))
206
+ return new_probs
207
+
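# Illustrative sketch (not part of the repository): for the non-sampling case
# (m at least the vocabulary size), get_top_k amounts to keeping the k most
# probable tokens per row and renormalizing. Example numbers only.
import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
k = 2
topk_vals, topk_idx = torch.topk(probs, k, dim=-1)
filtered = torch.zeros_like(probs).scatter_(-1, topk_idx, topk_vals)
filtered = filtered / filtered.sum(dim=-1, keepdim=True)
# filtered -> tensor([[0.6250, 0.3750, 0.0000, 0.0000]])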
208
+ def superposition_matrix(self, tokens):
209
+ """
210
+ Create superposition matrix based on provided tokens.
211
+ Args:
212
+ tokens (torch.Tensor): Tokens to mix (n_prompts, n_drafts)
213
+ Returns:
214
+ Superposition matrix (n_prompts, vocab_size)
215
+ """
216
+ # Create superposition matrix
217
+ mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size)
218
+ # Convert draft log probs to probabilities
219
+ weightings = log_prob_to_prob(self.alive_log_probs)
220
+ # Update probabilities in superposition matrix with draft probabilities
221
+ for p_idx in range(self.n_prompts):
222
+ for d_idx in range(self.n_drafts):
223
+ tok_idx = tokens[p_idx][d_idx]
224
+ mixing_matrix[p_idx][tok_idx] += weightings[p_idx][d_idx]
225
+ # Renormalize
226
+ mixing_matrix.div_(mixing_matrix.sum(dim=-1, keepdims=True))
227
+ return mixing_matrix
228
+
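# Illustrative sketch (not part of the repository): the nested Python loop above
# can also be expressed with scatter_add_, accumulating each draft's weight onto
# the id of the token it proposes; sizes below are hypothetical.
import torch

n_prompts, n_drafts, vocab_size = 2, 3, 10
tokens = torch.randint(0, vocab_size, (n_prompts, n_drafts))
weightings = torch.rand(n_prompts, n_drafts)
mixing_matrix = torch.zeros(n_prompts, vocab_size)
mixing_matrix.scatter_add_(1, tokens, weightings)
mixing_matrix.div_(mixing_matrix.sum(dim=-1, keepdim=True))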
229
+ def ngram_probs(self, alive_seq, cur_pos, probs):
230
+ """
231
+ Calculate and return next token distribution using ngram models.
232
+ Args:
233
+ alive_seq (torch.Tensor): Current drafts (n_prompts, n_drafts, seqlen)
234
+ cur_pos (int): Current timestep
235
+ probs (torch.Tensor): Current next probability distribution from model (n_prompts, vocab_size).
236
+ As described in the paper, only tokens w/nonzero probability in `prob` are considered for the
237
+ ngram distribution. However, passing in `None` as `probs` will consider all tokens.
238
+ Returns:
239
+ Next token distribution for each draft (n_prompts, n_drafts, vocab_size)
240
+ """
241
+ if self.get_time:
242
+ # Start timer
243
+ start_time = datetime.now()
244
+ # Create distribution matrix
245
+ next_token_probs = torch.zeros(self.n_prompts, self.n_drafts, 32000)
246
+ if probs is not None:
247
+ # Loop over all prefixes
248
+ for p_idx in range(len(alive_seq)):
249
+ # List of possible tokens for the prefix
250
+ nz = torch.nonzero(probs[p_idx, :], as_tuple=True)[0].tolist()
251
+ # Generate next token distribution
252
+ for draft_idx in range(self.n_drafts):
253
+ i_mask = torch.sum(torch.tensor(self.i_length) <= cur_pos)
254
+ new_i_weights = self.interpolation_weights[:i_mask]
255
+ new_i_length = self.i_length[:i_mask]
256
+ # For each next token
257
+ for nt in nz:
258
+ # Calculate probability using ngram interpolation
259
+ for i, weight in zip(new_i_length, new_i_weights):
260
+ if cur_pos - i >= 0:
261
+ key = tuple(alive_seq[p_idx, draft_idx, cur_pos-i:cur_pos].tolist())
262
+ if i == 1:
263
+ prob = self.bigram.prob(key, nt)
264
+ elif i == 2:
265
+ prob = self.trigram.prob(key, nt)
266
+ elif i == 3:
267
+ prob = self.fourgram.prob(key, nt)
268
+ elif i == 4:
269
+ prob = self.fivegram.prob(key, nt)
270
+ elif i == 5:
271
+ prob = self.sixgram.prob(key, nt)
272
+ elif i == 6:
273
+ prob = self.sevengram.prob(key, nt)
274
+ if prob >= 0:
275
+ next_token_probs[p_idx, draft_idx, nt] += weight * prob
276
+ else:
277
+ for p_idx in range(len(alive_seq)):
278
+ for draft_idx in range(self.n_drafts):
279
+ i_mask = torch.sum(torch.tensor(self.i_length) <= cur_pos)
280
+ new_i_weights = self.interpolation_weights[:i_mask]
281
+ new_i_length = self.i_length[:i_mask]
282
+ for i, weight in zip(new_i_length, new_i_weights):
283
+ if cur_pos - i >= 0:
284
+ key = tuple(alive_seq[p_idx, draft_idx, cur_pos-i:cur_pos].tolist())
285
+ if i == 1:
286
+ ntd = self.bigram.ntd(key)
287
+ elif i == 2:
288
+ ntd = self.trigram.ntd(key)
289
+ elif i == 3:
290
+ ntd = self.fourgram.ntd(key)
291
+ elif i == 4:
292
+ ntd = self.fivegram.ntd(key)
293
+ elif i == 5:
294
+ ntd = self.sixgram.ntd(key)
295
+ elif i == 6:
296
+ ntd = self.sevengram.ntd(key)
297
+ if ntd is not None:
298
+ next_token_probs[p_idx, draft_idx, :] += weight * ntd
299
+ if self.get_time:
300
+ total_time = datetime.now() - start_time
301
+ self.lookup_time = total_time if self.lookup_time is None else self.lookup_time + total_time
302
+ return next_token_probs
303
+
304
+ def return_results(self, prompt_len=None):
305
+ """
306
+ Return generations and perplexities
307
+
308
+ Args:
309
+ prompt_len (int): Length of prompt in tokens. If is None, then ppl is not calculated.
310
+ Returns:
311
+ (self.alive_seq, alive_ppl), (self.fin_seq, fin_ppl)
312
+ OR
313
+ (self.alive_seq, alive_ppl), (self.fin_seq, fin_ppl), self.lookup_time
314
+ """
315
+ # PPL
316
+ alive_ppl = 0
317
+ fin_ppl = 0
318
+ if prompt_len is not None:
319
+ alive_ppl = torch.exp(self.alive_log_probs / (-1 * (self.alive_seq.size(dim=-1)-prompt_len)))
320
+ # Fin ppl
321
+ fin_seq_lengths = (self.fin_seq != self.tokenizer.pad_id).sum(dim=-1)
322
+ fin_ppl = torch.exp(self.fin_log_probs / (-1 * (fin_seq_lengths - prompt_len)))
323
+ fin_ppl = torch.where(fin_ppl == 0, torch.full_like(fin_ppl, float("inf")), fin_ppl)
324
+ # Return results, including ngram lookup time if requested
325
+ if not self.get_time:
326
+ return (self.alive_seq.to(torch.long), alive_ppl), (self.fin_seq.to(torch.long), fin_ppl)
327
+ else:
328
+ return (self.alive_seq.to(torch.long), alive_ppl), (self.fin_seq.to(torch.long), fin_ppl), self.lookup_time
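# Illustrative sketch (not part of the repository): the perplexities returned by
# return_results are exp(-log_prob / generated_length). Example numbers only.
import math

seq_log_prob = -12.0   # summed log-probability of the generated continuation
n_generated = 8        # tokens generated after the prompt
ppl = math.exp(-seq_log_prob / n_generated)   # ~4.48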
superposed/llama/superposed_generation.py ADDED
@@ -0,0 +1,198 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import json
5
+ import os
6
+ import sys
7
+ import time
8
+ from pathlib import Path
9
+ from typing import List, Optional
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from fairscale.nn.model_parallel.initialize import (
14
+ get_model_parallel_rank,
15
+ initialize_model_parallel,
16
+ model_parallel_is_initialized,
17
+ )
18
+
19
+ from superposed.llama.model import ModelArgs
20
+ from superposed.llama.superposed_model import SuperposedTransformer
21
+ from superposed.llama.tokenizer import Tokenizer
22
+ from superposed.llama.superpose import Superpose
23
+ from superposed.llama.utils import *
24
+ from superposed.ngrams.ngram_models import make_models
25
+
26
+ class SuperposedLlama:
27
+ @staticmethod
28
+ def build(
29
+ ckpt_dir: str,
30
+ tokenizer_path: str,
31
+ max_seq_len: int,
32
+ max_batch_size: int,
33
+ device = None,
34
+ model_parallel_size: Optional[int] = None,
35
+ seed: int = 1,
36
+ ):
37
+ if not torch.distributed.is_initialized():
38
+ torch.distributed.init_process_group("nccl")
39
+ if not model_parallel_is_initialized():
40
+ if model_parallel_size is None:
41
+ model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
42
+ initialize_model_parallel(model_parallel_size)
43
+
44
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
45
+ if device is None:
46
+ torch.cuda.set_device(local_rank)
47
+ device = torch.cuda.current_device()
48
+ torch.manual_seed(seed)
49
+
50
+ if local_rank > 0:
51
+ sys.stdout = open(os.devnull, "w")
52
+
53
+ start_time = time.time()
54
+ checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
55
+ assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
56
+ assert model_parallel_size == len(
57
+ checkpoints
58
+ ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
59
+ ckpt_path = checkpoints[get_model_parallel_rank()]
60
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
61
+ with open(Path(ckpt_dir) / "params.json", "r") as f:
62
+ params = json.loads(f.read())
63
+
64
+ model_args: ModelArgs = ModelArgs(
65
+ max_seq_len=max_seq_len,
66
+ max_batch_size=max_batch_size,
67
+ **params,
68
+ )
69
+ tokenizer = Tokenizer(model_path=tokenizer_path)
70
+ model_args.vocab_size = tokenizer.n_words
71
+ torch.set_default_tensor_type(torch.cuda.HalfTensor)
72
+ # Set up superposed decoding
73
+ model = SuperposedTransformer(model_args)
74
+ model.load_state_dict(checkpoint, strict=False)
75
+ print(f"Loaded in {time.time() - start_time:.2f} seconds")
76
+ return SuperposedLlama(model, tokenizer, device)
77
+
78
+ def __init__(self, model: SuperposedTransformer, tokenizer: Tokenizer, device):
79
+ print(device)
80
+ self.model = model.to(device).eval()
81
+ self.tokenizer = tokenizer
82
+ self.device = device
83
+
84
+ @torch.inference_mode()
85
+ def sup_generate(
86
+ self,
87
+ prompt_tokens: List[List[int]],
88
+ smoothing,
89
+ max_gen_len: int,
90
+ n_token_sample: int,
91
+ alpha: float, # weight on ngram probs
92
+ temp: float,
93
+ n_drafts: int = 1, # number of beams
94
+ verbose: bool = False,
95
+ i_weights = None,
96
+ i_length = None,
97
+ ngrams = None,
98
+ get_time: bool = False,
99
+ penalty = 200
100
+ ):
101
+ """
102
+ Run multi-sequence generation using superposed embeddings.
103
+ Args:
104
+ prompt_tokens (List[List[int]]): Initial tokenized prompts
105
+ max_gen_len (int): Maximum numbers of tokens to generate
106
+ alpha (float): Alpha value
107
+ temp (float): Temperature
108
+ n_drafts (int): Number of drafts
109
+ verbose (bool): Whether to save intermediate embeddings for analysis
110
+ bsz (int): Batch size (default = 16)
111
+ i_weights (List[float]): Ngram interpolation weights
112
+ i_length (List[int]): Ngram models to interpolate (1 for bigram, 2 for trigram, etc.)
113
+ ngrams (Tuple): Ngram models
114
+ get_time (bool): Return information on time spent doing Ngram lookup
115
+ penalty (float): Penalty on uninterpolated drafts
116
+ Returns:
117
+ (alive_seq, alive_ppl), (fin_seq, fin_ppl): Tuple of (n_prompts, n_drafts, seqlen),
118
+ (n_prompts, n_drafts) for sequences still generating and sequences that have finished.
119
+ """
120
+ # Check batch size and prompt lengths
121
+ params = self.model.params
122
+ bsz = len(prompt_tokens)
123
+ assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
124
+
125
+ min_prompt_len = min(len(t) for t in prompt_tokens)
126
+ max_prompt_len = max(len(t) for t in prompt_tokens)
127
+ prompt_len = min_prompt_len
128
+ assert max_prompt_len <= params.max_seq_len
129
+ total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
130
+ pad_id = self.tokenizer.pad_id
131
+
132
+ # Initialize token tensor and pad where necessary
133
+ tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
134
+ for k, t in enumerate(prompt_tokens):
135
+ tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
136
+
137
+ # If no generation is possible
138
+ if min_prompt_len == total_len:
139
+ raise RuntimeError("no generation possible")
140
+
141
+ # Initialize decoding object
142
+ initial_tokens = tokens.unsqueeze(1).repeat(1, n_drafts, 1)
143
+ superpose = Superpose(initial_tokens,
144
+ tokenizer=self.tokenizer,
145
+ vocab_size=params.vocab_size,
146
+ smoothing=smoothing,
147
+ alpha=alpha,
148
+ i_weights=i_weights,
149
+ i_length=i_length,
150
+ ngrams=ngrams,
151
+ get_time=get_time,
152
+ penalty=penalty)
153
+ unseen_first = torch.ones(bsz)
154
+ # Superposition matrix
155
+ token_weights = torch.zeros(bsz, self.model.vocab_size)
156
+ if verbose:
157
+ state_list = []
158
+ prev_pos = 0
159
+ # Begin inference
160
+ for cur_pos in range(min_prompt_len, total_len):
161
+ input_text_mask = tokens != pad_id
162
+ # Take model step
163
+ if cur_pos == min_prompt_len:
164
+ token_weights = None
165
+ logits = self.model.forward(tokens[:, prev_pos:cur_pos],
166
+ start_pos=prev_pos,
167
+ token_weights=token_weights,
168
+ verbose=verbose)
169
+ if verbose:
170
+ logits, states = logits
171
+ # Softmax
172
+ if temp > 0:
173
+ probs = torch.softmax(logits[:, -1] / temp, dim=-1)
174
+ else:
175
+ raise RuntimeError("Temperature must be greater than 0 while mixing")
176
+ if verbose:
177
+ states["end_probs"] = probs
178
+ state_list.append(states)
179
+ # Flag prompts on first generation
180
+ is_first = torch.mul(tokens[:, cur_pos] == pad_id, unseen_first)
181
+ unseen_first[is_first.nonzero(as_tuple=True)[0]] = 0
182
+ # Flag prompts not yet generating
183
+ still_prompt = input_text_mask[:, cur_pos]
184
+ # Superposition pass
185
+ token_weights = superpose(probs, still_prompt, is_first, cur_pos, n_token_sample)
186
+ # Do not superpose for prompts not yet generating
187
+ keep_idx = input_text_mask[:, cur_pos].ravel().nonzero()
188
+ keep_token_weights = torch.zeros_like(token_weights)
189
+ keep_token_weights[keep_idx, tokens[keep_idx, cur_pos]] = 1
190
+ token_weights = torch.where(input_text_mask[:, cur_pos].unsqueeze(1).expand(-1, self.model.vocab_size),
191
+ keep_token_weights, token_weights)
192
+ prev_pos = cur_pos
193
+ results = superpose.return_results(prompt_len)
194
+ if verbose:
195
+ torch.save(state_list, "../embeddings.pt")
196
+ return results
197
+ else:
198
+ return results
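# Illustrative usage sketch (not part of the repository): paths, sizes, and
# hyperparameters below are hypothetical, and build() assumes the distributed /
# model-parallel environment it initializes is available.
from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer

model = SuperposedLlama.build(
    ckpt_dir="weights/llama-2-7b",             # hypothetical checkpoint dir
    tokenizer_path="weights/tokenizer.model",  # hypothetical tokenizer path
    max_seq_len=512,
    max_batch_size=4,
)
tokenizer = Tokenizer(model_path="weights/tokenizer.model")
prompts = [tokenizer.encode("The capital of France is", bos=True, eos=False)]
(alive_seq, alive_ppl), (fin_seq, fin_ppl) = model.sup_generate(
    prompt_tokens=prompts,
    smoothing=None,      # no ngram interpolation in this sketch
    max_gen_len=16,
    n_token_sample=9,
    alpha=0.0,
    temp=1.0,
    n_drafts=3,
    ngrams=[],           # no ngram models needed when smoothing is None
)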
superposed/llama/superposed_model.py ADDED
@@ -0,0 +1,515 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple
7
+
8
+ import fairscale.nn.model_parallel.initialize as fs_init
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairscale.nn.model_parallel.layers import (
12
+ ColumnParallelLinear,
13
+ ParallelEmbedding,
14
+ RowParallelLinear,
15
+ )
16
+ from torch import nn
17
+
18
+
19
+ @dataclass
20
+ class ModelArgs:
21
+ dim: int = 4096
22
+ n_layers: int = 32
23
+ n_heads: int = 32
24
+ n_kv_heads: Optional[int] = None
25
+ vocab_size: int = -1 # defined later by tokenizer
26
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27
+ ffn_dim_multiplier: Optional[float] = None
28
+ norm_eps: float = 1e-5
29
+
30
+ max_batch_size: int = 32
31
+ max_seq_len: int = 2048
32
+
33
+
34
+ class RMSNorm(torch.nn.Module):
35
+ def __init__(self, dim: int, eps: float = 1e-6):
36
+ """
37
+ Initialize the RMSNorm normalization layer.
38
+
39
+ Args:
40
+ dim (int): The dimension of the input tensor.
41
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
42
+
43
+ Attributes:
44
+ eps (float): A small value added to the denominator for numerical stability.
45
+ weight (nn.Parameter): Learnable scaling parameter.
46
+
47
+ """
48
+ super().__init__()
49
+ self.eps = eps
50
+ self.weight = nn.Parameter(torch.ones(dim))
51
+
52
+ def _norm(self, x):
53
+ """
54
+ Apply the RMSNorm normalization to the input tensor.
55
+
56
+ Args:
57
+ x (torch.Tensor): The input tensor.
58
+
59
+ Returns:
60
+ torch.Tensor: The normalized tensor.
61
+
62
+ """
63
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
64
+
65
+ def forward(self, x):
66
+ """
67
+ Forward pass through the RMSNorm layer.
68
+
69
+ Args:
70
+ x (torch.Tensor): The input tensor.
71
+
72
+ Returns:
73
+ torch.Tensor: The output tensor after applying RMSNorm.
74
+
75
+ """
76
+ output = self._norm(x.float()).type_as(x)
77
+ k = output * self.weight
78
+ return k
79
+
80
+
81
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
82
+ """
83
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
84
+
85
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
86
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
87
+ The returned tensor contains complex values in complex64 data type.
88
+
89
+ Args:
90
+ dim (int): Dimension of the frequency tensor.
91
+ end (int): End index for precomputing frequencies.
92
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
93
+
94
+ Returns:
95
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
96
+
97
+
98
+
99
+
100
+ """
101
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
102
+ t = torch.arange(end, device=freqs.device) # type: ignore
103
+ freqs = torch.outer(t, freqs).float() # type: ignore
104
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
105
+ return freqs_cis
106
+
107
+
108
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
109
+ """
110
+ Reshape frequency tensor for broadcasting it with another tensor.
111
+
112
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
113
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
114
+
115
+ Args:
116
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
117
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
118
+
119
+ Returns:
120
+ torch.Tensor: Reshaped frequency tensor.
121
+
122
+ Raises:
123
+ AssertionError: If the frequency tensor doesn't match the expected shape.
124
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
125
+ """
126
+ ndim = x.ndim
127
+ assert 0 <= 1 < ndim
128
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
129
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130
+ return freqs_cis.view(*shape)
131
+
132
+
133
+ def apply_rotary_emb(
134
+ xq: torch.Tensor,
135
+ xk: torch.Tensor,
136
+ freqs_cis: torch.Tensor,
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """
139
+ Apply rotary embeddings to input tensors using the given frequency tensor.
140
+
141
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
142
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
143
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
144
+ returned as real tensors.
145
+
146
+ Args:
147
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
148
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
149
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
150
+
151
+ Returns:
152
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
153
+
154
+
155
+
156
+ """
157
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
158
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
159
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
160
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
161
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
162
+ return xq_out.type_as(xq), xk_out.type_as(xk)
163
+
164
+
165
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
166
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
167
+ bs, slen, n_kv_heads, head_dim = x.shape
168
+ if n_rep == 1:
169
+ return x
170
+ return (
171
+ x[:, :, :, None, :]
172
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
173
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
174
+ )
175
+
176
+
177
+ class Attention(nn.Module):
178
+ """Multi-head attention module."""
179
+ def __init__(self, args: ModelArgs):
180
+ """
181
+ Initialize the Attention module.
182
+
183
+ Args:
184
+ args (ModelArgs): Model configuration parameters.
185
+
186
+ Attributes:
187
+ n_kv_heads (int): Number of key and value heads.
188
+ n_local_heads (int): Number of local query heads.
189
+ n_local_kv_heads (int): Number of local key and value heads.
190
+ n_rep (int): Number of repetitions for local heads.
191
+ head_dim (int): Dimension size of each attention head.
192
+ wq (ColumnParallelLinear): Linear transformation for queries.
193
+ wk (ColumnParallelLinear): Linear transformation for keys.
194
+ wv (ColumnParallelLinear): Linear transformation for values.
195
+ wo (RowParallelLinear): Linear transformation for output.
196
+ cache_k (torch.Tensor): Cached keys for attention.
197
+ cache_v (torch.Tensor): Cached values for attention.
198
+
199
+ """
200
+ super().__init__()
201
+ self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
202
+ model_parallel_size = fs_init.get_model_parallel_world_size()
203
+ self.n_local_heads = args.n_heads // model_parallel_size
204
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
205
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
206
+ self.head_dim = args.dim // args.n_heads
207
+
208
+ self.wq = ColumnParallelLinear(
209
+ args.dim,
210
+ args.n_heads * self.head_dim,
211
+ bias=False,
212
+ gather_output=False,
213
+ init_method=lambda x: x,
214
+ )
215
+ self.wk = ColumnParallelLinear(
216
+ args.dim,
217
+ self.n_kv_heads * self.head_dim,
218
+ bias=False,
219
+ gather_output=False,
220
+ init_method=lambda x: x,
221
+ )
222
+ self.wv = ColumnParallelLinear(
223
+ args.dim,
224
+ self.n_kv_heads * self.head_dim,
225
+ bias=False,
226
+ gather_output=False,
227
+ init_method=lambda x: x,
228
+ )
229
+ self.wo = RowParallelLinear(
230
+ args.n_heads * self.head_dim,
231
+ args.dim,
232
+ bias=False,
233
+ input_is_parallel=True,
234
+ init_method=lambda x: x,
235
+ )
236
+
237
+ self.cache_k = torch.zeros(
238
+ (
239
+ args.max_batch_size,
240
+ args.max_seq_len,
241
+ self.n_local_kv_heads,
242
+ self.head_dim,
243
+ )
244
+ ).cuda()
245
+ self.cache_v = torch.zeros(
246
+ (
247
+ args.max_batch_size,
248
+ args.max_seq_len,
249
+ self.n_local_kv_heads,
250
+ self.head_dim,
251
+ )
252
+ ).cuda()
253
+
254
+ def forward(
255
+ self,
256
+ x: torch.Tensor,
257
+ start_pos: int,
258
+ freqs_cis: torch.Tensor,
259
+ mask: Optional[torch.Tensor]
260
+ ):
261
+ """
262
+ Forward pass of the attention module.
263
+
264
+ Args:
265
+ x (torch.Tensor): Input tensor.
266
+ start_pos (int): Starting position for caching.
267
+ freqs_cis (torch.Tensor): Precomputed frequency tensor.
268
+ mask (torch.Tensor, optional): Attention mask tensor.
269
+
270
+ Returns:
271
+ torch.Tensor: Output tensor after attention.
272
+
273
+ """
274
+ bsz, seqlen, _ = x.shape
275
+
276
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
277
+
278
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
279
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
280
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
281
+
282
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
283
+
284
+ self.cache_k = self.cache_k.to(xq)
285
+ self.cache_v = self.cache_v.to(xq)
286
+
287
+ self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
288
+ self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
289
+
290
+ keys = self.cache_k[:bsz, : start_pos + seqlen]
291
+ values = self.cache_v[:bsz, : start_pos + seqlen]
292
+
293
+ # repeat k/v heads if n_kv_heads < n_heads
294
+ keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
295
+ values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
296
+
297
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
298
+ keys = keys.transpose(1, 2)
299
+ values = values.transpose(1, 2)
300
+ scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
301
+ if mask is not None:
302
+ scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
303
+ scores = F.softmax(scores.float(), dim=-1).type_as(xq)
304
+ output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
305
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
306
+ return self.wo(output)
307
+
308
+
309
+ class FeedForward(nn.Module):
310
+ def __init__(
311
+ self,
312
+ dim: int,
313
+ hidden_dim: int,
314
+ multiple_of: int,
315
+ ffn_dim_multiplier: Optional[float],
316
+ ):
317
+ """
318
+ Initialize the FeedForward module.
319
+
320
+ Args:
321
+ dim (int): Input dimension.
322
+ hidden_dim (int): Hidden dimension of the feedforward layer.
323
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
324
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
325
+
326
+ Attributes:
327
+ w1 (ColumnParallelLinear): Linear transformation for the first layer.
328
+ w2 (RowParallelLinear): Linear transformation for the second layer.
329
+ w3 (ColumnParallelLinear): Linear transformation for the third layer.
330
+
331
+ """
332
+ super().__init__()
333
+ hidden_dim = int(2 * hidden_dim / 3)
334
+ # custom dim factor multiplier
335
+ if ffn_dim_multiplier is not None:
336
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
337
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
338
+
339
+ self.w1 = ColumnParallelLinear(
340
+ dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
341
+ )
342
+ self.w2 = RowParallelLinear(
343
+ hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
344
+ )
345
+ self.w3 = ColumnParallelLinear(
346
+ dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
347
+ )
348
+
349
+ def forward(self, x):
350
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
351
+
352
+
353
+ class MixedTransformerBlock(nn.Module):
354
+ def __init__(self, layer_id: int, args: ModelArgs):
355
+ """
356
+ Initialize a TransformerBlock.
357
+
358
+ Args:
359
+ layer_id (int): Identifier for the layer.
360
+ args (ModelArgs): Model configuration parameters.
361
+
362
+ Attributes:
363
+ n_heads (int): Number of attention heads.
364
+ dim (int): Dimension size of the model.
365
+ head_dim (int): Dimension size of each attention head.
366
+ attention (Attention): Attention module.
367
+ feed_forward (FeedForward): FeedForward module.
368
+ layer_id (int): Identifier for the layer.
369
+ attention_norm (RMSNorm): Layer normalization for attention output.
370
+ ffn_norm (RMSNorm): Layer normalization for feedforward output.
371
+
372
+ """
373
+ super().__init__()
374
+ self.n_heads = args.n_heads
375
+ self.dim = args.dim
376
+ self.head_dim = args.dim // args.n_heads
377
+ self.attention = Attention(args)
378
+ self.feed_forward = FeedForward(
379
+ dim=args.dim,
380
+ hidden_dim=4 * args.dim,
381
+ multiple_of=args.multiple_of,
382
+ ffn_dim_multiplier=args.ffn_dim_multiplier,
383
+ )
384
+ self.layer_id = layer_id
385
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
386
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
387
+
388
+ def forward(
389
+ self,
390
+ x: torch.Tensor,
391
+ start_pos: int,
392
+ freqs_cis: torch.Tensor,
393
+ mask: Optional[torch.Tensor]
394
+ ):
395
+ """
396
+ Perform a forward pass through the TransformerBlock.
397
+
398
+ Args:
399
+ x (torch.Tensor): Input tensor.
400
+ start_pos (int): Starting position for attention caching.
401
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
402
+ mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
403
+
404
+ Returns:
405
+ torch.Tensor: Output tensor after applying attention and feedforward layers.
406
+
407
+ """
408
+ h = x + self.attention.forward(
409
+ self.attention_norm(x), start_pos, freqs_cis, mask
410
+ )
411
+ out = h + self.feed_forward.forward(self.ffn_norm(h))
412
+ return out
413
+
414
+ class SuperposedTransformer(nn.Module):
415
+ def __init__(self, params: ModelArgs):
416
+ """
417
+ Initialize a Transformer model.
418
+
419
+ Args:
420
+ params (ModelArgs): Model configuration parameters.
421
+
422
+ Attributes:
423
+ params (ModelArgs): Model configuration parameters.
424
+ vocab_size (int): Vocabulary size.
425
+ n_layers (int): Number of layers in the model.
426
+ tok_embeddings (ParallelEmbedding): Token embeddings.
427
+ layers (torch.nn.ModuleList): List of Transformer blocks.
428
+ norm (RMSNorm): Layer normalization for the model output.
429
+ output (ColumnParallelLinear): Linear layer for final output.
430
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
431
+
432
+ """
433
+ super().__init__()
434
+ self.params = params
435
+ self.vocab_size = params.vocab_size
436
+ self.n_layers = params.n_layers
437
+
438
+ self.tok_embeddings = ParallelEmbedding(
439
+ params.vocab_size, params.dim, init_method=lambda x: x
440
+ )
441
+
442
+ self.tok_mixing_embeddings = ColumnParallelLinear(
443
+ params.vocab_size, params.dim, bias=False, init_method=lambda x: x
444
+ ) # dims here are a formality; the weight is overwritten below
445
+ self.tok_mixing_embeddings.weight = nn.Parameter(self.tok_embeddings.weight.T)
446
+
447
+ self.layers = torch.nn.ModuleList()
448
+ for layer_id in range(params.n_layers):
449
+ self.layers.append(MixedTransformerBlock(layer_id, params))
450
+
451
+ self.norm = RMSNorm(params.dim, eps=params.norm_eps)
452
+ self.output = ColumnParallelLinear(
453
+ params.dim, params.vocab_size, bias=False, init_method=lambda x: x
454
+ )
455
+
456
+ self.freqs_cis = precompute_freqs_cis(
457
+ # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
458
+ # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
459
+ self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
460
+ )
461
+
462
+ @torch.inference_mode()
463
+ def forward(self,
464
+ tokens: torch.Tensor,
465
+ start_pos: int,
466
+ token_weights: Optional[torch.Tensor],
467
+ verbose: Optional[bool] = False):
468
+ """
469
+ Perform a forward pass through the Transformer model.
470
+
471
+ Args:
472
+ tokens (torch.Tensor): Input token indices.
473
+ start_pos (int): Starting position for attention caching.
474
+ token_weights (torch.Tensor): Superposition matrix.
475
+ verbose (bool): Whether to return intermediate hidden layer states
476
+
477
+ Returns:
478
+ torch.Tensor or (torch.Tensor, Dict): Output logits after applying the Transformer model.
479
+
480
+ """
481
+ if verbose:
482
+ states = {"layers": [], "weights": None}
483
+ _bsz, seqlen = tokens.shape
484
+ if token_weights is not None:
485
+ h = self.tok_mixing_embeddings(token_weights.half()).unsqueeze(1)
486
+ else:
487
+ h = self.tok_embeddings(tokens)
488
+ self.freqs_cis = self.freqs_cis.to(h.device)
489
+ freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
490
+ if verbose:
491
+ states["layers"].append(h)
492
+ states["weights"] = token_weights
493
+
494
+ mask = None
495
+ if seqlen > 1:
496
+ mask = torch.full(
497
+ (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
498
+ )
499
+ mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
500
+
501
+ for layer in self.layers:
502
+ h = layer(h, start_pos, freqs_cis, mask)
503
+ if verbose:
504
+ states["layers"].append(h)
505
+
506
+ h = self.norm(h)
507
+ if verbose:
508
+ states["layers"].append(h)
509
+
510
+ output = self.output(h).float()
511
+
512
+ if verbose:
513
+ return output, states
514
+ else:
515
+ return output
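# Illustrative sketch (not part of the repository): the superposed input
# embedding produced by tok_mixing_embeddings is a probability-weighted mixture
# of ordinary token embeddings, h = token_weights @ E, with E the
# (vocab_size, dim) embedding table. Sizes below are hypothetical.
import torch

vocab_size, dim = 8, 4
E = torch.randn(vocab_size, dim)          # stand-in for tok_embeddings.weight
token_weights = torch.zeros(1, vocab_size)
token_weights[0, 2], token_weights[0, 5] = 0.7, 0.3
h = token_weights @ E                     # (1, dim) superposed embedding
assert torch.allclose(h, 0.7 * E[2] + 0.3 * E[5])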
superposed/llama/tokenizer.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import os
5
+ from logging import getLogger
6
+ from typing import List
7
+
8
+ from sentencepiece import SentencePieceProcessor
9
+
10
+
11
+ logger = getLogger()
12
+
13
+
14
+ class Tokenizer:
15
+ """tokenizing and encoding/decoding text using SentencePiece."""
16
+ def __init__(self, model_path: str):
17
+ """
18
+ Initializes the Tokenizer with a SentencePiece model.
19
+
20
+ Args:
21
+ model_path (str): The path to the SentencePiece model file.
22
+ """
23
+ # reload tokenizer
24
+ assert os.path.isfile(model_path), model_path
25
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
26
+ logger.info(f"Reloaded SentencePiece model from {model_path}")
27
+
28
+ # BOS / EOS token IDs
29
+ self.n_words: int = self.sp_model.vocab_size()
30
+ self.bos_id: int = self.sp_model.bos_id()
31
+ self.eos_id: int = self.sp_model.eos_id()
32
+ self.pad_id: int = self.sp_model.pad_id()
33
+ logger.info(
34
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
35
+ )
36
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
37
+
38
+ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
39
+ """
40
+ Encodes a string into a list of token IDs.
41
+
42
+ Args:
43
+ s (str): The input string to be encoded.
44
+ bos (bool): Whether to prepend the beginning-of-sequence token.
45
+ eos (bool): Whether to append the end-of-sequence token.
46
+
47
+ Returns:
48
+ List[int]: A list of token IDs.
49
+ """
50
+ assert type(s) is str
51
+ t = self.sp_model.encode(s)
52
+ if bos:
53
+ t = [self.bos_id] + t
54
+ if eos:
55
+ t = t + [self.eos_id]
56
+ return t
57
+
58
+ def decode(self, t: List[int]) -> str:
59
+ """
60
+ Decodes a list of token IDs into a string.
61
+
62
+ Args:
63
+ t (List[int]): The list of token IDs to be decoded.
64
+
65
+ Returns:
66
+ str: The decoded string.
67
+ """
68
+ return self.sp_model.decode(t)
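# Illustrative usage sketch (not part of the repository); the model path is
# hypothetical.
from superposed.llama.tokenizer import Tokenizer

tok = Tokenizer(model_path="weights/tokenizer.model")   # hypothetical path
ids = tok.encode("Hello world", bos=True, eos=False)    # [bos_id, ...]
text = tok.decode(ids)                                  # "Hello world"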
superposed/llama/utils.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+
3
+ def log_prob_to_prob(log_probs, temp=1):
4
+ """
5
+ Convert log probabilities to probability distribution and normalize.
6
+ Args:
7
+ log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size)
8
+ Returns:
9
+ Probability distribution (n_prompts, n_drafts, vocab_size)
10
+ """
11
+ # stability constant
12
+ log_probs = log_probs - torch.max(log_probs, dim=-1, keepdim=True)[0]
13
+ probs = torch.softmax(log_probs / temp, dim=-1)
14
+ return probs
15
+
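# Illustrative sketch (not part of the repository): log_prob_to_prob turns draft
# log-probabilities into normalized weights via a (temperature-scaled) softmax.
# Example numbers only.
import torch

draft_log_probs = torch.tensor([[-1.0, -2.0, -3.0]])
weights = torch.softmax(draft_log_probs, dim=-1)
# tensor([[0.6652, 0.2447, 0.0900]]) -- one weight per draft, summing to 1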
16
+ def decode(tokenizer, encoding):
17
+ """
18
+ Decode a list of tokens to a string
19
+ Args:
20
+ tokenizer (Any): Tokenizer
21
+ encoding (torch.Tensor): Encoding
22
+ Returns:
23
+ decoding (str)
24
+ """
25
+ pad_locs = (encoding == -1).nonzero()
26
+ if len(pad_locs) > 0:
27
+ encoding = encoding[:pad_locs[0].item()]
28
+ return tokenizer.decode(encoding.to(torch.int32).tolist())
29
+
30
+ def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
31
+ """
32
+ Print out generations for debugging.
33
+ Args:
34
+ gens (n_prompts * n_drafts, seq_len): Generations to print
35
+ logprobs (n_prompts * n_drafts): Log probs of each generation
36
+ tokenizer (any): Tokenizer
37
+ n_drafts (int): Number of drafts per prompt
38
+ prompt_len (int): Number of tokens in prompt
39
+ """
40
+ n_prompts, n_drafts, seq_len = gens.shape
41
+ gens = gens.reshape(-1, seq_len)
42
+ logprobs = logprobs.flatten()
43
+ count = 0
44
+ for i in range(len(gens)):
45
+ d = decode(tokenizer, gens[i])
46
+ # first draft of this prompt
47
+ if i % n_drafts == 0:
48
+ count = 0
49
+ print("---------------", file=output_file)
50
+ prompt = decode(tokenizer, gens[i][:prompt_len])
51
+ print(f"prompt: {prompt}", file=output_file)
52
+ print(f"logprob: {logprobs[i]} {count}: {d}", file=output_file)
53
+ count += 1
54
+
55
+ def print_probs(next_probs, tokenizer, output_file):
56
+ """
57
+ Print out next token options and probabilities for debugging
58
+ Args:
59
+ next_probs (torch.Tensor): Next token probabilities (n_prompts, n_drafts, vocab_size)
60
+ tokenizer (any): Tokenizer
61
+ """
62
+ print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
63
+ n_prompts, n_drafts, vocab_size = next_probs.shape
64
+ for p_idx in range(n_prompts):
65
+ print(f"\tPrompt {p_idx}:", file=output_file)
66
+ for d_idx in range(n_drafts):
67
+ next_token_probs, next_token_idx = next_probs[p_idx, d_idx].topk(n_drafts+2, dim=-1)
68
+ print(f"\t\tTokens: {[tokenizer.decode([i.item()]) for i in next_token_idx]}", file=output_file)
69
+ print(f"\t\tLog Probs: {torch.log(next_token_probs)}", file=output_file)
70
+ print(f"\t\tProbs: {next_token_probs}", file=output_file)
superposed/ngrams/__pycache__/ngram_models.cpython-312.pyc ADDED
Binary file (5.53 kB). View file
 
superposed/ngrams/make_corpus.py ADDED
@@ -0,0 +1,268 @@
1
+ import multiprocessing
2
+ import argparse
3
+ import os
4
+ import pickle
5
+ import glob
6
+ import json
7
+ from datasets import load_dataset
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, LlamaTokenizer
10
+ from loguru import logger
11
+
12
+
13
+ def create_corpuses(
14
+ ckpt_path,
15
+ start_doc,
16
+ end_doc,
17
+ dataset,
18
+ tokenizer,
19
+ train_bigram: bool,
20
+ train_trigram: bool,
21
+ train_fourgram: bool,
22
+ train_fivegram: bool,
23
+ train_sixgram: bool,
24
+ train_sevengram: bool
25
+ ):
26
+ bigram_corpus = {}
27
+ trigram_corpus = {}
28
+ fourgram_corpus = {}
29
+ fivegram_corpus = {}
30
+ sixgram_corpus = {}
31
+ sevengram_corpus = {}
32
+
33
+ bigram_corpus_counts = {}
34
+ trigram_corpus_counts = {}
35
+ fourgram_corpus_counts = {}
36
+ fivegram_corpus_counts = {}
37
+ sixgram_corpus_counts = {}
38
+ sevengram_corpus_counts = {}
39
+
40
+ iterations = end_doc - start_doc
41
+ for i in tqdm(range(iterations)):
42
+ t = dataset[start_doc + i]["text"]
43
+ encoded_text = tokenizer.encode(t)
44
+ for start_idx in range(1, len(encoded_text)): # count from first real to eos
45
+ pOne = encoded_text[start_idx-1] if start_idx >= 1 else None
46
+ pTwo = encoded_text[start_idx-2] if start_idx >= 2 else None
47
+ pThree = encoded_text[start_idx-3] if start_idx >= 3 else None
48
+ pFour = encoded_text[start_idx-4] if start_idx >= 4 else None
49
+ pFive = encoded_text[start_idx-5] if start_idx >= 5 else None
50
+ pSix = encoded_text[start_idx-6] if start_idx >= 6 else None
51
+
52
+ token = encoded_text[start_idx]
53
+ # bigram
54
+ if train_bigram and start_idx >= 1:
55
+ prior = pOne
56
+ if prior not in bigram_corpus:
57
+ bigram_corpus[prior] = {}
58
+ bigram_corpus_counts[prior] = 0
59
+ bigram_corpus[prior][token] = bigram_corpus[prior].get(token, 0) + 1
60
+ bigram_corpus_counts[prior] += 1
61
+ # trigram
62
+ if train_trigram and start_idx >= 2:
63
+ prior = (pTwo, pOne)
64
+ if prior not in trigram_corpus:
65
+ trigram_corpus[prior] = {}
66
+ trigram_corpus_counts[prior] = 0
67
+ trigram_corpus[prior][token] = trigram_corpus[prior].get(token, 0) + 1
68
+ trigram_corpus_counts[prior] += 1
69
+ # fourgram
70
+ if train_fourgram and start_idx >= 3:
71
+ prior = (pThree, pTwo, pOne)
72
+ if prior not in fourgram_corpus:
73
+ fourgram_corpus[prior] = {}
74
+ fourgram_corpus_counts[prior] = 0
75
+ fourgram_corpus[prior][token] = fourgram_corpus[prior].get(token, 0) + 1
76
+ fourgram_corpus_counts[prior] += 1
77
+ # fivegram
78
+ if train_fivegram and start_idx >= 4:
79
+ prior = (pFour, pThree, pTwo, pOne)
80
+ if prior not in fivegram_corpus:
81
+ fivegram_corpus[prior] = {}
82
+ fivegram_corpus_counts[prior] = 0
83
+ fivegram_corpus[prior][token] = fivegram_corpus[prior].get(token, 0) + 1
84
+ fivegram_corpus_counts[prior] += 1
85
+ # sixgram
86
+ if train_sixgram and start_idx >= 5:
87
+ prior = (pFive, pFour, pThree, pTwo, pOne)
88
+ if prior not in sixgram_corpus:
89
+ sixgram_corpus[prior] = {}
90
+ sixgram_corpus_counts[prior] = 0
91
+ sixgram_corpus[prior][token] = sixgram_corpus[prior].get(token, 0) + 1
92
+ sixgram_corpus_counts[prior] += 1
93
+ # sevengram
94
+ if train_sevengram and start_idx >= 6:
95
+ prior = (pSix, pFive, pFour, pThree, pTwo, pOne)
96
+ if prior not in sevengram_corpus:
97
+ sevengram_corpus[prior] = {}
98
+ sevengram_corpus_counts[prior] = 0
99
+ sevengram_corpus[prior][token] = sevengram_corpus[prior].get(token, 0) + 1
100
+ sevengram_corpus_counts[prior] += 1
101
+ save_corpus(ckpt_path, bigram_corpus, trigram_corpus, fourgram_corpus, fivegram_corpus, sixgram_corpus, sevengram_corpus, start_doc, end_doc)
102
+ save_counts(ckpt_path, bigram_corpus_counts, trigram_corpus_counts, fourgram_corpus_counts, fivegram_corpus_counts, sixgram_corpus_counts, sevengram_corpus_counts, start_doc, end_doc)
103
+
104
+ def merge_corpus_helper(c1, c2):
105
+ """
106
+ Merge the corpuses c1 and c2, returning the merged result.
107
+ """
108
+ for prior in c2:
109
+ # if share prior
110
+ if prior in c1:
111
+ c1_prior = c1[prior]
112
+ c2_prior = c2[prior]
113
+ for token in c2_prior:
114
+ # if share token
115
+ if token in c1_prior:
116
+ c1_prior[token] += c2_prior[token]
117
+ # else just use c2's
118
+ else:
119
+ c1_prior[token] = c2_prior[token]
120
+ else:
121
+ # else just use c2's
122
+ c1[prior] = c2[prior]
123
+ return c1
124
+
125
+ def merge_counts_helper(c1, c2):
126
+ """
127
+ Merge the count corpuses c1 and c2, returning the merged result.
128
+ """
129
+ for prior in c2:
130
+ if prior in c1:
131
+ c1[prior] += c2[prior]
132
+ else:
133
+ c1[prior] = c2[prior]
134
+ return c1
135
+
136
+ def save_corpus(save_dir, b_d, t_d, fo_d, fi_d, si_d, se_d, start_doc, end_doc):
137
+ """
138
+ Save corpuses b_d (bigram) to se_d (sevengram), where the corpus contains mappings
139
+ {prefix : {next_token1: ct, next_token2: ct, ...}}.
140
+ """
141
+ prefixes = ["b_d", "t_d", "fo_d", "fi_d", "si_d", "se_d"]
142
+ for p, corpus in zip(prefixes, [b_d, t_d, fo_d, fi_d, si_d, se_d]):
143
+ with open(f"{save_dir}/{p}{start_doc}-{end_doc}.pkl", "wb") as f:
144
+ pickle.dump(corpus, f)
145
+
146
+ def save_counts(save_dir, b_ct, t_ct, fo_ct, fi_ct, si_ct, se_ct, start_doc, end_doc):
147
+ """
148
+ Save count corpuses b_ct (bigram) to se_ct (sevengram), where each count
149
+ corpus contains mappings {prefix : total}.
150
+ """
151
+ prefixes = ["b_ct", "t_ct", "fo_ct", "fi_ct", "si_ct", "se_ct"]
152
+ for p, corpus in zip(prefixes, [b_ct, t_ct, fo_ct, fi_ct, si_ct, se_ct]):
153
+ with open(f"{save_dir}/{p}{start_doc}-{end_doc}.pkl", "wb") as f:
154
+ pickle.dump(corpus, f)
155
+
156
+ def merge_corpuses(ckpt_path):
157
+ """
158
+ Helper to merge corpuses in `ckpt_path`, where `ckpt_path` might contain
159
+ multiple bigram, trigram, etc. corpuses from each process.
160
+ """
161
+ prefixes = ["b_d", "t_d", "fo_d", "fi_d", "si_d", "se_d"]
162
+ for prefix in prefixes:
163
+ if os.path.exists(f"{ckpt_path}/{prefix}_final.pkl"):
164
+ os.remove(f"{ckpt_path}/{prefix}_final.pkl")
165
+ corpus = None
166
+ for filepath in glob.glob(f"{ckpt_path}/{prefix}*"):
167
+ with open(filepath, "rb") as f:
168
+ current = pickle.load(f)
169
+ if corpus is None:
170
+ corpus = current
171
+ else:
172
+ corpus = merge_corpus_helper(corpus, current)
173
+ os.remove(filepath)
174
+ with open(f"{ckpt_path}/{prefix}_final.pkl", "wb") as f:
175
+ pickle.dump(corpus, f)
176
+
177
+ def merge_counts(ckpt_path):
178
+ """
179
+ Helper to merge count corpuses in `ckpt_path`, where `ckpt_path` might contain
180
+ multiple bigram, trigram, etc. count corpuses from each process.
181
+ """
182
+ prefixes = ["b_ct", "t_ct", "fo_ct", "fi_ct", "si_ct", "se_ct"]
183
+ for prefix in prefixes:
184
+ if os.path.exists(f"{ckpt_path}/{prefix}_final.pkl"):
185
+ os.remove(f"{ckpt_path}/{prefix}_final.pkl")
186
+
187
+ counts = None
188
+ for filepath in glob.glob(f"{ckpt_path}/{prefix}*"):
189
+ with open(filepath, "rb") as f:
190
+ current = pickle.load(f)
191
+ if counts is None:
192
+ counts = current
193
+ else:
194
+ counts = merge_counts_helper(counts, current)
195
+ os.remove(filepath)
196
+ with open(f"{ckpt_path}/{prefix}_final.pkl", "wb") as f:
197
+ pickle.dump(counts, f)
198
+
199
+
200
+ if __name__ == "__main__":
201
+ # Input arguments
202
+ parser = argparse.ArgumentParser()
203
+ parser.add_argument("ckpt_path", type=str, help="Path to store ngram models")
204
+ parser.add_argument("start_doc", type=str, help="# of first document")
205
+ parser.add_argument("end_doc", type=str, help="# of last document")
206
+ parser.add_argument("c", type=int, help="number of processes")
207
+ parser.add_argument("--tok_name", type=str, help="name of HF tokenizer, or llama", default="llama")
208
+ for arg_name in ["--bigram", "--trigram", "--fourgram", "--fivegram", "--sixgram", "--sevengram"]:
209
+ parser.add_argument(arg_name, type=str, help=f"Whether to make a {arg_name} model")
210
+ parser.add_argument("--dset_name", type=str, help="name of HF dataset")
211
+ parser.add_argument("--dset_path", type=str, help="path to dataset")
212
+ # Parse arguments
213
+ args = parser.parse_args()
214
+ start_doc_ovr = int(args.start_doc)
215
+ end_doc_ovr = int(args.end_doc)
216
+ n_cores = args.c
217
+ tok_name = args.tok_name
218
+ ckpt_path = args.ckpt_path
219
+ dset_name = args.dset_name
220
+ dset_path = args.dset_path
221
+ if not dset_name and not dset_path:
222
+ raise RuntimeError("Please provide a dataset")
223
+ if not os.path.exists(ckpt_path):
224
+ os.makedirs(ckpt_path)
225
+ logger.info(f"{start_doc_ovr} {end_doc_ovr} {n_cores}")
226
+
227
+ # Load dataset and tokenizer
228
+ if dset_name:
229
+ ds = load_dataset(dset_name, cache_dir="../../../datasets/")["train"].shuffle(seed=42)
230
+ else:
231
+ with open(dset_path, "r") as f:
232
+ ds = json.load(f)["train"]
233
+ if tok_name == "llama":
234
+ # REPLACE WITH YOUR OWN PATH
235
+ tokenizer = LlamaTokenizer.from_pretrained("../../7B_HF", add_bos_token=False)
236
+ else:
237
+ tokenizer = AutoTokenizer.from_pretrained(tok_name)
238
+
239
+ # Start running
240
+ num_processes = n_cores
241
+ total_docs = end_doc_ovr - start_doc_ovr
242
+ docs_per_c = (total_docs) // num_processes
243
+ processes = []
244
+ for core in range(n_cores):
245
+ start_doc = core * docs_per_c # relative start doc
246
+ end_doc = (core + 1) * docs_per_c if core < n_cores - 1 else total_docs # relative end doc
247
+ logger.info(f"Starting core {core} from document {start_doc} to {end_doc}")
248
+ process = multiprocessing.Process(target=create_corpuses,
249
+ args=(ckpt_path,
250
+ start_doc_ovr + start_doc,
251
+ start_doc_ovr + end_doc,
252
+ ds, tokenizer,
253
+ args.bigram,
254
+ args.trigram,
255
+ args.fourgram,
256
+ args.fivegram,
257
+ args.sixgram,
258
+ args.sevengram))
259
+ processes.append(process)
260
+ process.start()
261
+ for process in processes:
262
+ process.join()
263
+ logger.info("Finished Saving")
264
+ logger.info("Merging...")
265
+ merge_corpuses(ckpt_path)
266
+ merge_counts(ckpt_path)
267
+ logger.info("Merged.")
268
+
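A sketch of the assumed corpus layout (the token IDs are illustrative, not from the repo): based on the save helpers above and on NGram.prob/ntd in ngram_models.py below, each n-gram corpus appears to map a prefix of token IDs to a dict of next-token counts, while the matching count corpus maps the same prefix to its total.

    # Assumed trigram corpus entry: prefix (two token IDs) -> {next token ID: count}
    trigram_corpus = {
        (306, 626): {263: 12, 278: 7},
    }
    # Assumed matching count-corpus entry: prefix -> total occurrences of the prefix
    trigram_counts = {
        (306, 626): 19,
    }
    # Bigram corpuses appear to be keyed by a single token ID rather than a tuple
    # (NGram.prob unwraps one-element keys before lookup).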
superposed/ngrams/ngram_models.py ADDED
@@ -0,0 +1,115 @@
1
+ import pickle
2
+ import sys
3
+
4
+ import torch
5
+
6
+ class NGram():
7
+ def __init__(self, corpus, corpus_counts, type):
8
+ self.corpus = corpus
9
+ self.counts = corpus_counts
10
+ self.type = type
11
+
12
+ def prob(self, key, next):
13
+ """
14
+ Args:
15
+ key (tuple): tuple of token IDs forming the prior
16
+ next (int): ID of the candidate next token
17
+ """
18
+ l = len(key)
19
+ if self.type == "bigram":
20
+ assert l == 1
21
+ key = key[0]
22
+ elif self.type == "trigram":
23
+ assert l == 2
24
+ elif self.type == "fourgram":
25
+ assert l == 3
26
+ elif self.type == "fivegram":
27
+ assert l == 4
28
+ elif self.type == "sixgram":
29
+ assert l == 5
30
+ elif self.type == "sevengram":
31
+ assert l == 6
32
+
33
+ count = 0
34
+ if key in self.corpus:
35
+ count = self.corpus[key].get(next, 0)
36
+ total = sum(self.corpus[key].values())
37
+ return count / total
38
+ else:
39
+ return -1
40
+
41
+ def ntd(self, key, vocab_size=32000):
42
+ """
43
+ Args:
44
+ key (tuple): tuple of token IDs forming the prior
45
+ Returns:
46
+ prob_tensor (torch.Tensor): (vocab_size, ) of full next token probabilities
47
+ """
48
+ if key in self.corpus:
49
+ prob_tensor = torch.zeros(vocab_size)
50
+ total = sum(self.corpus[key].values())
51
+ for next_token in self.corpus[key]:
52
+ prob_tensor[next_token] = self.corpus[key][next_token] / total
53
+ return prob_tensor
54
+ else:
55
+ return None
56
+
57
+ def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
58
+ """
59
+ Loads and returns a list of n-gram models (bigram through sevengram), containing
60
+ only the models whose parameters are `True`. See below for expected corpus names.
61
+ Args:
62
+ ckpt_path (str): Location of ngram models
63
+ bigram-sevengram: Which models to load
64
+ Returns:
65
+ List of n-gram models
66
+ """
67
+ models = []
68
+ if bigram:
69
+ print("Making bigram...")
70
+ with open(f"{ckpt_path}/b_d_final.pkl", "rb") as f:
71
+ bigram = pickle.load(f)
72
+ bigram_model = NGram(bigram, None, "bigram")
73
+ models.append(bigram_model)
74
+ print(sys.getsizeof(bigram))
75
+
76
+ if trigram:
77
+ print("Making trigram...")
78
+ with open(f"{ckpt_path}/t_d_final.pkl", "rb") as f:
79
+ trigram = pickle.load(f)
80
+ trigram_model = NGram(trigram, None, "trigram")
81
+ models.append(trigram_model)
82
+ print(sys.getsizeof(trigram))
83
+
84
+ if fourgram:
85
+ print("Making fourgram...")
86
+ with open(f"{ckpt_path}/fo_d_final.pkl", "rb") as f:
87
+ fourgram = pickle.load(f)
88
+ fourgram_model = NGram(fourgram, None, "fourgram")
89
+ models.append(fourgram_model)
90
+ print(sys.getsizeof(fourgram))
91
+
92
+ if fivegram:
93
+ print("Making fivegram...")
94
+ with open(f"{ckpt_path}/fi_d_final.pkl", "rb") as f:
95
+ fivegram = pickle.load(f)
96
+ fivegram_model = NGram(fivegram, None, "fivegram")
97
+ models.append(fivegram_model)
98
+ print(sys.getsizeof(fivegram))
99
+
100
+ if sixgram:
101
+ print("Making sixgram...")
102
+ with open(f"{ckpt_path}/si_d_final.pkl", "rb") as f:
103
+ sixgram = pickle.load(f)
104
+ sixgram_model = NGram(sixgram, None, "sixgram")
105
+ models.append(sixgram_model)
106
+ print(sys.getsizeof(sixgram))
107
+
108
+ if sevengram:
109
+ print("Making sevengram...")
110
+ with open(f"{ckpt_path}/se_d_final.pkl", "rb") as f:
111
+ sevengram = pickle.load(f)
112
+ sevengram_model = NGram(sevengram, None, "sevengram")
113
+ models.append(sevengram_model)
114
+
115
+ return models
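A minimal usage sketch for make_models, assuming the *_final.pkl corpuses produced by make_corpus.py are present (the checkpoint directory name and token IDs are placeholders):

    from superposed.ngrams.ngram_models import make_models

    # Load only the bigram and trigram models from a checkpoint directory.
    bigram_model, trigram_model = make_models(
        "ckpts-200k", bigram=True, trigram=True,
        fourgram=False, fivegram=False, sixgram=False, sevengram=False)

    p = trigram_model.prob((306, 626), 263)  # P(263 | prefix), or -1 if the prefix is unseen
    dist = trigram_model.ntd((306, 626))     # (vocab_size,) next-token distribution, or None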
superposed/ngrams/test.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "train": [
3
+ {"text": "Hi my name is"},
4
+ {"text": "This is a story of"},
5
+ {"text": "In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying"},
6
+ {"text": "There is one class of AutoModel for each task, and for each backend (PyTorch, TensorFlow, or Flax)."}
7
+ ]
8
+ }
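test.json mirrors the layout make_corpus.py reads when --dset_path is given: a top-level "train" list of records, each carrying a "text" field. A quick sanity check, assuming it is run from the repo root:

    import json

    with open("superposed/ngrams/test.json", "r") as f:
        ds = json.load(f)["train"]
    print(len(ds), ds[0]["text"])  # 4 Hi my name is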
superposed/notebooks/custom.ipynb ADDED
@@ -0,0 +1,289 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "119805f4-8589-4379-ad87-a7bad4c0e658",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n",
15
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcWriteOptions size changed, may indicate binary incompatibility. Expected 72 from C header, got 88 from PyObject\n",
16
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcReadOptions size changed, may indicate binary incompatibility. Expected 96 from C header, got 104 from PyObject\n",
17
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileInfo size changed, may indicate binary incompatibility. Expected 64 from C header, got 88 from PyObject\n",
18
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileSelector size changed, may indicate binary incompatibility. Expected 48 from C header, got 72 from PyObject\n",
19
+ "2024-05-30 03:09:58.230601: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
20
+ "2024-05-30 03:09:58.280835: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
21
+ "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
22
+ "2024-05-30 03:10:03.250651: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
23
+ ]
24
+ }
25
+ ],
26
+ "source": [
27
+ "%load_ext autoreload\n",
28
+ "%autoreload 2\n",
29
+ "\n",
30
+ "import json\n",
31
+ "import os\n",
32
+ "import pickle\n",
33
+ "from datetime import datetime\n",
34
+ "\n",
35
+ "import evaluate\n",
36
+ "import torch\n",
37
+ "from tqdm import tqdm\n",
38
+ "\n",
39
+ "from eval import *\n",
40
+ "from superposed.llama.metrics import *\n",
41
+ "from superposed.llama.generation import Llama\n",
42
+ "from superposed.llama.superposed_generation import SuperposedLlama\n",
43
+ "from superposed.llama.tokenizer import Tokenizer\n",
44
+ "from superposed.ngrams.ngram_models import make_models"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 4,
50
+ "id": "51c15900-c8b8-46d9-a884-6842a391ef48",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "sup_device = torch.device(\"cuda:0\")\n",
55
+ "tokenizer = Tokenizer('../../7B/tokenizer.model')"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 5,
61
+ "id": "9817d9a4-ad64-41c6-b87b-b1e422b836a9",
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "# Params\n",
74
+ "param_file = \"../../params/p15_d3_mixed.json\"\n",
75
+ "with open(param_file, \"r\") as f:\n",
76
+ " params = json.load(f)\n",
77
+ " print(f\"Parameters: {params}\")\n",
78
+ "alpha = params[\"alpha\"]\n",
79
+ "temp = params[\"temp\"]\n",
80
+ "n_drafts = params[\"n_drafts\"]\n",
81
+ "prompt_len = params[\"prompt_len\"]\n",
82
+ "n_token_sample = params[\"n_token_sample\"]\n",
83
+ "i_weights = params[\"i_weights\"]\n",
84
+ "i_length = params[\"i_length\"]"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 6,
90
+ "id": "9c99098e-a38b-4c78-a0e9-8c80309830bb",
91
+ "metadata": {},
92
+ "outputs": [
93
+ {
94
+ "name": "stdout",
95
+ "output_type": "stream",
96
+ "text": [
97
+ "Making bigram...\n",
98
+ "1310800\n",
99
+ "Making trigram...\n",
100
+ "671088728\n",
101
+ "Making fourgram...\n",
102
+ "2684354648\n",
103
+ "Making fivegram...\n",
104
+ "5368709200\n",
105
+ "Making sixgram...\n",
106
+ "5368709200\n"
107
+ ]
108
+ }
109
+ ],
110
+ "source": [
111
+ "# Create ngram models\n",
112
+ "ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 7,
118
+ "id": "c3331332-242c-4e98-9f11-58c6dc0ef581",
119
+ "metadata": {},
120
+ "outputs": [
121
+ {
122
+ "name": "stdout",
123
+ "output_type": "stream",
124
+ "text": [
125
+ "> initializing model parallel with size 1\n",
126
+ "> initializing ddp with size 1\n",
127
+ "> initializing pipeline with size 1\n"
128
+ ]
129
+ },
130
+ {
131
+ "name": "stderr",
132
+ "output_type": "stream",
133
+ "text": [
134
+ "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
135
+ " _C._set_default_tensor_type(t)\n"
136
+ ]
137
+ },
138
+ {
139
+ "name": "stdout",
140
+ "output_type": "stream",
141
+ "text": [
142
+ "Loaded in 25.15 seconds\n",
143
+ "cuda:0\n"
144
+ ]
145
+ }
146
+ ],
147
+ "source": [
148
+ "weight_path = \"../../7B/\"\n",
149
+ "model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
150
+ " tokenizer_path=f'{weight_path}/tokenizer.model', \n",
151
+ " max_seq_len=100, \n",
152
+ " max_batch_size=32,\n",
153
+ " device=sup_device,\n",
154
+ " model_parallel_size=1)"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "markdown",
159
+ "id": "e2b48c23-d6a3-43b1-ad4c-54524aacfda6",
160
+ "metadata": {},
161
+ "source": [
162
+ "# Inference"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": 11,
168
+ "id": "5093373b-bf76-47e3-8f99-1045b60f29c3",
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "def decode(tokenizer, encoding):\n",
173
+ " \"\"\"\n",
174
+ " Args:\n",
175
+ " tokenizer (Any): Tokenizer\n",
176
+ " encoding (torch.Tensor): Encoding\n",
177
+ " Returns:\n",
178
+ " decoding (str)\n",
179
+ " \"\"\"\n",
180
+ " eos_locs = (encoding == tokenizer.eos_id).nonzero()\n",
181
+ " if len(eos_locs > 0):\n",
182
+ " encoding = encoding[:eos_locs[0]]\n",
183
+ " return tokenizer.decode(encoding.to(torch.int32).tolist())"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": 22,
189
+ "id": "18703b19-f3e9-46e4-ab1c-c6d3b403c6d2",
190
+ "metadata": {},
191
+ "outputs": [],
192
+ "source": [
193
+ "prompts = [\n",
194
+ " \"Hi my name is\",\n",
195
+ " \"The Seattle Seahawks were Super Bowl\",\n",
196
+ " \"Penguins are birds native to\"\n",
197
+ "]\n",
198
+ "tokenized_prompts = tokenizer.encode(prompts, True, False)"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 23,
204
+ "id": "d39cd735-9480-4979-ac92-bbd470f75570",
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts, \n",
209
+ " smoothing=\"geom\",\n",
210
+ " max_gen_len=10, \n",
211
+ " n_token_sample=n_token_sample,\n",
212
+ " alpha=alpha, \n",
213
+ " temp=temp,\n",
214
+ " n_drafts=n_drafts,\n",
215
+ " i_weights=i_weights,\n",
216
+ " i_length=i_length,\n",
217
+ " ngrams=ngrams,\n",
218
+ " get_time=False,\n",
219
+ " penalty=200)"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 24,
225
+ "id": "cfefa793-e49e-483a-a504-5cc9e23f619d",
226
+ "metadata": {},
227
+ "outputs": [],
228
+ "source": [
229
+ "gens = alive_gens[0].reshape(len(prompts) * n_drafts, -1)"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 25,
235
+ "id": "5abf87ab-2ee0-4204-868b-1215abf0c8aa",
236
+ "metadata": {},
237
+ "outputs": [
238
+ {
239
+ "name": "stdout",
240
+ "output_type": "stream",
241
+ "text": [
242
+ "Hi\n",
243
+ "my name\n",
244
+ "is L\n",
245
+ "inda,\n",
246
+ "I am\n",
247
+ "a \n",
248
+ "40\n",
249
+ "year old\n",
250
+ "woman who\n"
251
+ ]
252
+ }
253
+ ],
254
+ "source": [
255
+ "for i in gens:\n",
256
+ " print(decode(tokenizer, i))"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "e73dc3cc-baa5-468d-bdd1-827465bdeb62",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": []
266
+ }
267
+ ],
268
+ "metadata": {
269
+ "kernelspec": {
270
+ "display_name": "Python 3 (ipykernel)",
271
+ "language": "python",
272
+ "name": "python3"
273
+ },
274
+ "language_info": {
275
+ "codemirror_mode": {
276
+ "name": "ipython",
277
+ "version": 3
278
+ },
279
+ "file_extension": ".py",
280
+ "mimetype": "text/x-python",
281
+ "name": "python",
282
+ "nbconvert_exporter": "python",
283
+ "pygments_lexer": "ipython3",
284
+ "version": "3.11.5"
285
+ }
286
+ },
287
+ "nbformat": 4,
288
+ "nbformat_minor": 5
289
+ }
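The cells above flatten all drafts before decoding. A small follow-on sketch that keeps the prompt/draft grouping visible, assuming (as the reshape in the notebook implies) that alive_gens[0] stores n_drafts consecutive drafts per prompt:

    per_prompt = alive_gens[0].reshape(len(prompts), n_drafts, -1)
    for p_idx, prompt in enumerate(prompts):
        print(prompt)
        for d_idx in range(n_drafts):
            print("   ", decode(tokenizer, per_prompt[p_idx, d_idx]))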
superposed/notebooks/nq.ipynb ADDED
@@ -0,0 +1,417 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "The autoreload extension is already loaded. To reload it, use:\n",
13
+ " %reload_ext autoreload\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "%load_ext autoreload\n",
19
+ "%autoreload 2\n",
20
+ "\n",
21
+ "import json\n",
22
+ "import os\n",
23
+ "import re\n",
24
+ "from datetime import datetime\n",
25
+ "\n",
26
+ "import torch\n",
27
+ "from datasets import load_dataset\n",
28
+ "from tqdm import tqdm\n",
29
+ "\n",
30
+ "from eval import *\n",
31
+ "from superposed.llama.metrics import *\n",
32
+ "from superposed.llama.generation import Llama\n",
33
+ "from superposed.llama.superposed_generation import SuperposedLlama\n",
34
+ "from superposed.llama.tokenizer import Tokenizer\n",
35
+ "from superposed.ngrams.ngram_models import make_models"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {},
41
+ "source": [
42
+ "# Setup"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 3,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "nq = load_dataset(\"nq_open\")[\"validation\"]"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 6,
57
+ "metadata": {},
58
+ "outputs": [
59
+ {
60
+ "name": "stdout",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
64
+ ]
65
+ }
66
+ ],
67
+ "source": [
68
+ "# Params\n",
69
+ "param_file = \"../../params/p15_d3_mixed.json\"\n",
70
+ "with open(param_file, \"r\") as f:\n",
71
+ " params = json.load(f)\n",
72
+ " print(f\"Parameters: {params}\")\n",
73
+ "alpha = params[\"alpha\"]\n",
74
+ "temp = params[\"temp\"]\n",
75
+ "n_drafts = params[\"n_drafts\"]\n",
76
+ "prompt_len = params[\"prompt_len\"]\n",
77
+ "n_token_sample = params[\"n_token_sample\"]\n",
78
+ "i_weights = params[\"i_weights\"]\n",
79
+ "i_length = params[\"i_length\"]"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "metadata": {},
85
+ "source": [
86
+ "# Create Models"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 7,
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "name": "stdout",
96
+ "output_type": "stream",
97
+ "text": [
98
+ "Making bigram...\n",
99
+ "1310800\n",
100
+ "Making trigram...\n",
101
+ "671088728\n",
102
+ "Making fourgram...\n",
103
+ "2684354648\n",
104
+ "Making fivegram...\n",
105
+ "5368709200\n",
106
+ "Making sixgram...\n",
107
+ "5368709200\n"
108
+ ]
109
+ }
110
+ ],
111
+ "source": [
112
+ "ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 9,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "sup_device = torch.device(\"cuda:0\")\n",
122
+ "reg_device = torch.device(\"cuda:1\")"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 11,
128
+ "metadata": {},
129
+ "outputs": [
130
+ {
131
+ "name": "stdout",
132
+ "output_type": "stream",
133
+ "text": [
134
+ "> initializing model parallel with size 1\n",
135
+ "> initializing ddp with size 1\n",
136
+ "> initializing pipeline with size 1\n"
137
+ ]
138
+ },
139
+ {
140
+ "name": "stderr",
141
+ "output_type": "stream",
142
+ "text": [
143
+ "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
144
+ " _C._set_default_tensor_type(t)\n"
145
+ ]
146
+ },
147
+ {
148
+ "name": "stdout",
149
+ "output_type": "stream",
150
+ "text": [
151
+ "Loaded in 33.68 seconds\n",
152
+ "cuda:0\n"
153
+ ]
154
+ }
155
+ ],
156
+ "source": [
157
+ "# load superposed\n",
158
+ "weight_path = \"../../7B/\"\n",
159
+ "sup_model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
160
+ " tokenizer_path=f'{weight_path}/tokenizer.model', \n",
161
+ " max_seq_len=1000, \n",
162
+ " max_batch_size=16,\n",
163
+ " device=sup_device,\n",
164
+ " model_parallel_size=1)"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 12,
170
+ "metadata": {},
171
+ "outputs": [
172
+ {
173
+ "name": "stdout",
174
+ "output_type": "stream",
175
+ "text": [
176
+ "0\n",
177
+ "Loaded in 22.47 seconds\n"
178
+ ]
179
+ }
180
+ ],
181
+ "source": [
182
+ "# load regular\n",
183
+ "reg_model = Llama.build(ckpt_dir=weight_path, \n",
184
+ " tokenizer_path=f'{weight_path}/tokenizer.model', \n",
185
+ " max_seq_len=1000, \n",
186
+ " max_batch_size=16,\n",
187
+ " device=reg_device, # reg_device,\n",
188
+ " model_parallel_size=1)"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 13,
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "tokenizer = Tokenizer(f\"{weight_path}/tokenizer.model\")"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "markdown",
202
+ "metadata": {},
203
+ "source": [
204
+ "# Evaluation"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 14,
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "model_types = [\"greedy\", \"superposed\", \"regular\"]\n",
214
+ "model_type = model_types[1]"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": 17,
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "def evaluate_nq(model_type, question, max_gen_len):\n",
224
+ " question = \"Answer these questions:\\n\\nQ: \" + question + \"?\\nA:\"\n",
225
+ " text_len = len(question) # for truncating\n",
226
+ " prompt_len = len(tokenizer.encode([question], True, False)[0]) # for model\n",
227
+ " if model_type == \"regular\" or model_type == \"greedy\":\n",
228
+ " if model_type == \"regular\":\n",
229
+ " input = [question for _ in range(n_drafts)]\n",
230
+ " print(input)\n",
231
+ " sequences, _ = evaluate_nucleus_losses(data=input,\n",
232
+ " model=reg_model,\n",
233
+ " tokenizer=tokenizer,\n",
234
+ " prompt_len=prompt_len,\n",
235
+ " max_gen_len=max_gen_len,\n",
236
+ " temp=0.6,\n",
237
+ " bsz=8,\n",
238
+ " marker=False)\n",
239
+ " else:\n",
240
+ " sequences, _ = evaluate_nucleus_losses(data=[question],\n",
241
+ " model=reg_model,\n",
242
+ " tokenizer=tokenizer,\n",
243
+ " prompt_len=prompt_len,\n",
244
+ " max_gen_len=max_gen_len,\n",
245
+ " temp=0,\n",
246
+ " bsz=8,\n",
247
+ " marker=False)\n",
248
+ " n_pd, seq_len = sequences.shape\n",
249
+ " elif model_type == \"superposed\":\n",
250
+ " sequences, _ = evaluate_mixed_losses(data=[question],\n",
251
+ " model=sup_model,\n",
252
+ " tokenizer=tokenizer,\n",
253
+ " prompt_len=prompt_len,\n",
254
+ " max_gen_len=max_gen_len,\n",
255
+ " alpha=alpha,\n",
256
+ " temp=temp,\n",
257
+ " n_drafts=n_drafts,\n",
258
+ " n_token_sample=n_token_sample,\n",
259
+ " smoothing=None, # Use greedy\n",
260
+ " bsz=8,\n",
261
+ " i_weights=i_weights,\n",
262
+ " i_length=i_length,\n",
263
+ " ngrams=ngrams,\n",
264
+ " marker=False)\n",
265
+ " n_p, n_d, seq_len = sequences.shape\n",
266
+ " # Process results\n",
267
+ " sequences = sequences.reshape(-1, seq_len).tolist()\n",
268
+ " for d_idx in range(len(sequences)):\n",
269
+ " draft = sequences[d_idx]\n",
270
+ " if -1 in draft:\n",
271
+ " draft = draft[:draft.index(-1)]\n",
272
+ " sequences[d_idx] = draft\n",
273
+ " decoded_seq = tokenizer.decode(sequences)\n",
274
+ " answers = []\n",
275
+ " for s in decoded_seq:\n",
276
+ " answers.append(re.split(\"[,.\\n]\", s[text_len:].strip())[0])\n",
277
+ " return answers\n",
278
+ " "
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "# Run evaluation\n",
288
+ "predictions = []\n",
289
+ "print(f\"Precision from 1 to {n_drafts}\")\n",
290
+ "for sample in tqdm(nq):\n",
291
+ " # Adaptively determine max generation length\n",
292
+ " longest = 0\n",
293
+ " shortest = 1000\n",
294
+ " for answer in sample[\"answer\"]:\n",
295
+ " tmp = tokenizer.encode([answer], False, False)[0]\n",
296
+ " if len(tmp) > longest:\n",
297
+ " longest = len(tmp)\n",
298
+ " if len(tmp) < shortest:\n",
299
+ " shortest = len(tmp)\n",
300
+ " question = sample[\"question\"]\n",
301
+ " answer = evaluate_nq(model_type, question, max_gen_len=shortest+3)\n",
302
+ " predictions.append({\"question\": question, \"answer\": answer})"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": 52,
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "# Separate results into precisions\n",
312
+ "precisions = {}\n",
313
+ "for i in range(1, n_drafts+1):\n",
314
+ " prec = str(i)\n",
315
+ " responses = []\n",
316
+ " for result in predictions:\n",
317
+ " responses.append({\"question\": result[\"question\"], \"answer\": result[\"answer\"][:i]})\n",
318
+ " precisions[prec] = responses"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": 53,
324
+ "metadata": {},
325
+ "outputs": [
326
+ {
327
+ "name": "stdout",
328
+ "output_type": "stream",
329
+ "text": [
330
+ "{'question': 'when was the last time anyone was on the moon', 'answer': ['2019', '2019', '2019-', '2019-', '1019']}\n",
331
+ "================\n",
332
+ "{'question': \"who wrote he ain't heavy he's my brother lyrics\", 'answer': ['The song was written by', 'The lyr was written by', 'The Hol was written by', 'Neil song was written by', 'Neil lyr was written by']}\n",
333
+ "================\n",
334
+ "{'question': 'how many seasons of the bastard executioner are there', 'answer': ['1', 'There1', 'there1', '1', 'There1']}\n",
335
+ "================\n",
336
+ "{'question': 'when did the eagles win last super bowl', 'answer': ['2018', 'The2018', '1018', '2017', 'the2018']}\n",
337
+ "================\n",
338
+ "{'question': \"who won last year's ncaa women's basketball\", 'answer': ['the university of connecticut', 'The university of connecticut', 'university of connecticut', 'the University of connecticut', 'The University of connecticut']}\n",
339
+ "================\n",
340
+ "{'question': 'when did the isle of wight become an island', 'answer': ['1207', 'when1207', '1287', '1277', 'when1287']}\n",
341
+ "================\n",
342
+ "{'question': 'love yourself by justin bieber is about who', 'answer': ['love yourself by justin b', 'love yourself is justin b', 'Justin yourself by justin b', 'Justin yourself is justin b', 'It yourself by justin b']}\n",
343
+ "================\n",
344
+ "{'question': 'who was the ruler of england in 1616', 'answer': ['James I', 'James I of', 'King I', 'j I', 'James I']}\n",
345
+ "================\n",
346
+ "{'question': 'what is the hot coffee mod in san andreas', 'answer': ['The Hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a mod for Grand', 'The hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a modification that Grand', 'It Hot Coffee mod is a modification for Grand']}\n",
347
+ "================\n",
348
+ "{'question': 'what is the maximum data rate for the 802.11a standard select one', 'answer': ['54 Mbps', '54Mbps', '54 mbps', '54 Mbps', '54 Mbps']}\n",
349
+ "================\n"
350
+ ]
351
+ }
352
+ ],
353
+ "source": [
354
+ "# Print some results\n",
355
+ "counter = 0\n",
356
+ "for k in predictions:\n",
357
+ " if counter >= 10:\n",
358
+ " break\n",
359
+ " print(k)\n",
360
+ " counter += 1\n",
361
+ " print(\"================\")"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "markdown",
366
+ "metadata": {},
367
+ "source": [
368
+ "# Saving"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "execution_count": 54,
374
+ "metadata": {},
375
+ "outputs": [
376
+ {
377
+ "name": "stdout",
378
+ "output_type": "stream",
379
+ "text": [
380
+ "dict_keys(['1', '2', '3', '4', '5'])\n"
381
+ ]
382
+ }
383
+ ],
384
+ "source": [
385
+ "# Save results\n",
386
+ "os.makedirs(\"../../nq/\", exist_ok=True)\n",
387
+ "print(precisions.keys())\n",
388
+ "for prec in range(1, n_drafts+1):\n",
389
+ " out_path = f\"../nq/eval_{model_type}_{prec}_test.jsonl\"\n",
390
+ " with open(out_path, \"w\") as f:\n",
391
+ " for obj in precisions[str(prec)]: \n",
392
+ " f.write(json.dumps(obj) + \"\\n\")"
393
+ ]
394
+ }
395
+ ],
396
+ "metadata": {
397
+ "kernelspec": {
398
+ "display_name": "Python 3 (ipykernel)",
399
+ "language": "python",
400
+ "name": "python3"
401
+ },
402
+ "language_info": {
403
+ "codemirror_mode": {
404
+ "name": "ipython",
405
+ "version": 3
406
+ },
407
+ "file_extension": ".py",
408
+ "mimetype": "text/x-python",
409
+ "name": "python",
410
+ "nbconvert_exporter": "python",
411
+ "pygments_lexer": "ipython3",
412
+ "version": "3.11.5"
413
+ }
414
+ },
415
+ "nbformat": 4,
416
+ "nbformat_minor": 4
417
+ }
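The notebook above only writes per-precision prediction files; it does not score them. A hypothetical scoring pass over those files could use simple exact match; the normalize/score_file helpers and the gold-answer dict below are illustrative only and not part of the commit:

    import json
    import re
    import string

    def normalize(s):
        # Lowercase, drop punctuation, and collapse whitespace before comparing.
        s = "".join(ch for ch in s.lower() if ch not in string.punctuation)
        return re.sub(r"\s+", " ", s).strip()

    def score_file(path, gold):
        # gold: {question: [acceptable answer strings]}
        hits = total = 0
        with open(path) as f:
            for line in f:
                rec = json.loads(line)
                preds = {normalize(a) for a in rec["answer"]}
                refs = {normalize(a) for a in gold.get(rec["question"], [])}
                hits += bool(preds & refs)
                total += 1
        return hits / total if total else 0.0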
superposed/notebooks/triviaqa.ipynb ADDED
@@ -0,0 +1,404 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n",
14
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcWriteOptions size changed, may indicate binary incompatibility. Expected 72 from C header, got 88 from PyObject\n",
15
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcReadOptions size changed, may indicate binary incompatibility. Expected 96 from C header, got 104 from PyObject\n",
16
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileInfo size changed, may indicate binary incompatibility. Expected 64 from C header, got 88 from PyObject\n",
17
+ "<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileSelector size changed, may indicate binary incompatibility. Expected 48 from C header, got 72 from PyObject\n",
18
+ "2024-05-30 01:35:17.813978: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
19
+ "2024-05-30 01:35:20.452213: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
20
+ "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
21
+ "2024-05-30 01:35:41.833487: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "%load_ext autoreload\n",
27
+ "%autoreload 2\n",
28
+ "\n",
29
+ "import copy\n",
30
+ "import json\n",
31
+ "import pickle\n",
32
+ "import os\n",
33
+ "import random\n",
34
+ "import re\n",
35
+ "import string\n",
36
+ "import math\n",
37
+ "from datetime import datetime\n",
38
+ "\n",
39
+ "import evaluate\n",
40
+ "import torch\n",
41
+ "import numpy as np\n",
42
+ "from datasets import load_dataset\n",
43
+ "from transformers import LlamaTokenizer\n",
44
+ "from tqdm import tqdm\n",
45
+ "\n",
46
+ "from eval import *\n",
47
+ "from superposed.llama.metrics import *\n",
48
+ "from superposed.llama.generation import Llama\n",
49
+ "from superposed.llama.superposed_generation import SuperposedLlama\n",
50
+ "from superposed.llama.tokenizer import Tokenizer\n",
51
+ "from superposed.ngrams.ngram_models import make_models"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "markdown",
56
+ "metadata": {},
57
+ "source": [
58
+ "# Setup"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 3,
64
+ "metadata": {},
65
+ "outputs": [
66
+ {
67
+ "name": "stdout",
68
+ "output_type": "stream",
69
+ "text": [
70
+ "Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
71
+ ]
72
+ }
73
+ ],
74
+ "source": [
75
+ "# Params\n",
76
+ "param_file = \"../../params/p15_d3_mixed.json\"\n",
77
+ "with open(param_file, \"r\") as f:\n",
78
+ " params = json.load(f)\n",
79
+ " print(f\"Parameters: {params}\")\n",
80
+ "alpha = params[\"alpha\"]\n",
81
+ "temp = params[\"temp\"]\n",
82
+ "n_drafts = params[\"n_drafts\"]\n",
83
+ "prompt_len = params[\"prompt_len\"]\n",
84
+ "n_token_sample = params[\"n_token_sample\"]\n",
85
+ "i_weights = params[\"i_weights\"]\n",
86
+ "i_length = params[\"i_length\"]"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 5,
92
+ "metadata": {
93
+ "scrolled": true
94
+ },
95
+ "outputs": [
96
+ {
97
+ "name": "stdout",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "Making bigram...\n",
101
+ "1310800\n",
102
+ "Making trigram...\n",
103
+ "671088728\n",
104
+ "Making fourgram...\n",
105
+ "2684354648\n",
106
+ "Making fivegram...\n",
107
+ "5368709200\n",
108
+ "Making sixgram...\n",
109
+ "5368709200\n"
110
+ ]
111
+ }
112
+ ],
113
+ "source": [
114
+ "ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 10,
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "sup_device = torch.device(\"cuda:0\")\n",
124
+ "reg_device = torch.device(\"cuda:1\")"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": 11,
130
+ "metadata": {},
131
+ "outputs": [
132
+ {
133
+ "name": "stdout",
134
+ "output_type": "stream",
135
+ "text": [
136
+ "> initializing model parallel with size 1\n",
137
+ "> initializing ddp with size 1\n",
138
+ "> initializing pipeline with size 1\n"
139
+ ]
140
+ },
141
+ {
142
+ "name": "stderr",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
146
+ " _C._set_default_tensor_type(t)\n"
147
+ ]
148
+ },
149
+ {
150
+ "name": "stdout",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "Loaded in 22.07 seconds\n",
154
+ "cuda:0\n"
155
+ ]
156
+ }
157
+ ],
158
+ "source": [
159
+ "weight_path = \"../../7B/\"\n",
160
+ "sup_model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
161
+ " tokenizer_path=f'{weight_path}/tokenizer.model', \n",
162
+ " max_seq_len=1000, \n",
163
+ " max_batch_size=16,\n",
164
+ " device=sup_device,\n",
165
+ " model_parallel_size=1)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 12,
171
+ "metadata": {},
172
+ "outputs": [
173
+ {
174
+ "name": "stdout",
175
+ "output_type": "stream",
176
+ "text": [
177
+ "0\n",
178
+ "Loaded in 22.76 seconds\n"
179
+ ]
180
+ }
181
+ ],
182
+ "source": [
183
+ "reg_model = Llama.build(ckpt_dir=weight_path, \n",
184
+ " tokenizer_path=f'{weight_path}/tokenizer.model', \n",
185
+ " max_seq_len=1000, \n",
186
+ " max_batch_size=16,\n",
187
+ " device=reg_device,\n",
188
+ " model_parallel_size=1)"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 18,
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "tokenizer = Tokenizer(f\"{weight_path}/tokenizer.model\")"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "markdown",
202
+ "metadata": {},
203
+ "source": [
204
+ "# Evaluation"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 13,
210
+ "metadata": {},
211
+ "outputs": [
212
+ {
213
+ "name": "stdout",
214
+ "output_type": "stream",
215
+ "text": [
216
+ "Length: 7993\n"
217
+ ]
218
+ }
219
+ ],
220
+ "source": [
221
+ "trivia_path = \"../../../datasets/qa/wikipedia-dev.json\"\n",
222
+ "with open(trivia_path, \"r\") as f:\n",
223
+ " triviaqa = json.load(f)[\"Data\"]\n",
224
+ "print(f\"Length: {len(triviaqa)}\")"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": 14,
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "torch.set_default_dtype(torch.float32)"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 15,
239
+ "metadata": {},
240
+ "outputs": [],
241
+ "source": [
242
+ "model_types = [\"superposed\", \"regular\"]\n",
243
+ "model_type = model_types[0]"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": 16,
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/triviaqa/default.yaml\n",
253
+ "def evaluate_trivia(model_type, question, max_gen_len):\n",
254
+ " question = \"Question: \" + question + \"\\nAnswer:\"\n",
255
+ " text_len = len(question) # for truncating\n",
256
+ " prompt_len = len(tokenizer.encode([question], True, False)[0]) # for model\n",
257
+ " if model_type == \"regular\":\n",
258
+ " input = [question for _ in range(n_drafts)]\n",
259
+ " sequences, _ = evaluate_nucleus_losses(data=input,\n",
260
+ " model=reg_model,\n",
261
+ " tokenizer=tokenizer,\n",
262
+ " prompt_len=prompt_len,\n",
263
+ " max_gen_len=max_gen_len,\n",
264
+ " temp=0.6, # Set to 0 for greedy\n",
265
+ " bsz=8,\n",
266
+ " marker=False)\n",
267
+ " n_pd, seq_len = sequences.shape\n",
268
+ " elif model_type == \"superposed\":\n",
269
+ " sequences, _ = evaluate_mixed_losses(data=[question],\n",
270
+ " model=sup_model,\n",
271
+ " tokenizer=tokenizer,\n",
272
+ " prompt_len=prompt_len,\n",
273
+ " max_gen_len=max_gen_len,\n",
274
+ " alpha=alpha,\n",
275
+ " temp=temp,\n",
276
+ " n_drafts=n_drafts,\n",
277
+ " n_token_sample=n_token_sample,\n",
278
+ " smoothing=None, # greedy\n",
279
+ " bsz=8,\n",
280
+ " i_weights=i_weights,\n",
281
+ " i_length=i_length,\n",
282
+ " ngrams=ngrams,\n",
283
+ " marker=False)\n",
284
+ " n_p, n_d, seq_len = sequences.shape\n",
285
+ " # Process results\n",
286
+ " sequences = sequences.reshape(-1, seq_len).tolist()\n",
287
+ " for d_idx in range(len(sequences)):\n",
288
+ " draft = sequences[d_idx]\n",
289
+ " if -1 in draft:\n",
290
+ " draft = draft[:draft.index(-1)]\n",
291
+ " sequences[d_idx] = draft\n",
292
+ " decoded_seq = tokenizer.decode(sequences)\n",
293
+ " answers = []\n",
294
+ " for s in decoded_seq:\n",
295
+ " # print(s)\n",
296
+ " answers.append(re.split(\"[,.\\n]\", s[text_len:].strip())[0])\n",
297
+ " return answers\n",
298
+ " "
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "questions = {}\n",
308
+ "predictions = {}\n",
309
+ "print(f\"Precision from 1 to {n_drafts}\")\n",
310
+ "for sample in tqdm(triviaqa):\n",
311
+ " # Adaptively select generation length\n",
312
+ " longest = 0\n",
313
+ " shortest = 1000\n",
314
+ " total = 0\n",
315
+ " for answer in sample[\"Answer\"][\"Aliases\"]:\n",
316
+ " tmp = tokenizer.encode([answer], False, False)[0]\n",
317
+ " if len(tmp) > longest:\n",
318
+ " longest = len(tmp)\n",
319
+ " if len(tmp) < shortest:\n",
320
+ " shortest = len(tmp)\n",
321
+ " total += len(tmp)\n",
322
+ " # Evaluation code\n",
323
+ " id = sample[\"QuestionId\"]\n",
324
+ " question = sample[\"Question\"]\n",
325
+ " answer = evaluate_trivia(model_type, question, max_gen_len=longest + 3)\n",
326
+ " predictions[id] = answer\n",
327
+ " questions[id] = question"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": null,
333
+ "metadata": {},
334
+ "outputs": [],
335
+ "source": [
336
+ "# Save precisions\n",
337
+ "precisions = {}\n",
338
+ "for i in range(1, n_drafts+1):\n",
339
+ " prec = str(i)\n",
340
+ " responses = {k: v[:i] for k, v in predictions.items()}\n",
341
+ " precisions[prec] = responses"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "# Print some results\n",
351
+ "counter = 0\n",
352
+ "for k in predictions:\n",
353
+ " if counter >= 10:\n",
354
+ " break\n",
355
+ " print(questions[k])\n",
356
+ " print(predictions[k])\n",
357
+ " counter += 1\n",
358
+ " print(\"================\")"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "# Save results\n",
368
+ "os.makedirs(\"../../trivia/\", exist_ok=True)\n",
369
+ "for prec in range(1, n_drafts+1):\n",
370
+ " out_path = f\"../nucleus_extra/trivia_extra/ngram_4trivia_{model_type}_{prec}_4.json\"\n",
371
+ " with open(out_path, \"w\") as f:\n",
372
+ " json.dump(precisions[str(prec)], f, indent=4)"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "code",
377
+ "execution_count": null,
378
+ "metadata": {},
379
+ "outputs": [],
380
+ "source": []
381
+ }
382
+ ],
383
+ "metadata": {
384
+ "kernelspec": {
385
+ "display_name": "Python 3 (ipykernel)",
386
+ "language": "python",
387
+ "name": "python3"
388
+ },
389
+ "language_info": {
390
+ "codemirror_mode": {
391
+ "name": "ipython",
392
+ "version": 3
393
+ },
394
+ "file_extension": ".py",
395
+ "mimetype": "text/x-python",
396
+ "name": "python",
397
+ "nbconvert_exporter": "python",
398
+ "pygments_lexer": "ipython3",
399
+ "version": "3.11.5"
400
+ }
401
+ },
402
+ "nbformat": 4,
403
+ "nbformat_minor": 4
404
+ }