DeepMind LMI Team committed on
Commit
9bdaa77
1 Parent(s): b4ed985

Internal change


PiperOrigin-RevId: 495328606

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. CONTRIBUTING.md +28 -0
  2. LICENSE +202 -0
  3. README.md +208 -1
  4. compiler/__init__.py +19 -0
  5. compiler/assemble.py +335 -0
  6. compiler/assemble_test.py +120 -0
  7. compiler/basis_inference.py +106 -0
  8. compiler/basis_inference_test.py +140 -0
  9. compiler/compiling.py +92 -0
  10. compiler/craft_graph_to_model.py +238 -0
  11. compiler/craft_graph_to_model_test.py +194 -0
  12. compiler/craft_model_to_transformer.py +76 -0
  13. compiler/expr_to_craft_graph.py +277 -0
  14. compiler/expr_to_craft_graph_test.py +121 -0
  15. compiler/lib.py +371 -0
  16. compiler/lib_test.py +40 -0
  17. compiler/nodes.py +32 -0
  18. compiler/rasp_to_craft_integration_test.py +254 -0
  19. compiler/rasp_to_graph.py +67 -0
  20. compiler/rasp_to_graph_test.py +71 -0
  21. compiler/rasp_to_transformer_integration_test.py +214 -0
  22. compiler/test_cases.py +357 -0
  23. craft/bases.py +247 -0
  24. craft/bases_test.py +158 -0
  25. craft/chamber/categorical_attn.py +167 -0
  26. craft/chamber/categorical_attn_test.py +229 -0
  27. craft/chamber/categorical_mlp.py +168 -0
  28. craft/chamber/categorical_mlp_test.py +164 -0
  29. craft/chamber/numerical_mlp.py +334 -0
  30. craft/chamber/numerical_mlp_test.py +233 -0
  31. craft/chamber/selector_width.py +144 -0
  32. craft/chamber/selector_width_test.py +155 -0
  33. craft/tests_common.py +33 -0
  34. craft/transformers.py +197 -0
  35. craft/transformers_test.py +160 -0
  36. craft/vectorspace_fns.py +162 -0
  37. craft/vectorspace_fns_test.py +166 -0
  38. examples/Visualize_Tracr_Models.ipynb +262 -0
  39. rasp/causal_eval.py +39 -0
  40. rasp/causal_eval_test.py +61 -0
  41. rasp/rasp.py +932 -0
  42. rasp/rasp_test.py +580 -0
  43. transformer/attention.py +160 -0
  44. transformer/compressed_model.py +185 -0
  45. transformer/compressed_model_test.py +318 -0
  46. transformer/encoder.py +135 -0
  47. transformer/encoder_test.py +123 -0
  48. transformer/model.py +199 -0
  49. transformer/model_test.py +275 -0
  50. utils/debugging.py +28 -0
CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
1
+ # How to Contribute
2
+
3
+ We welcome your contributions to this project. Please read the guidance below
4
+ first.
5
+
6
+ ## Contributor License Agreement
7
+
8
+ Contributions to this project must be accompanied by a Contributor License
9
+ Agreement. You (or your employer) retain the copyright to your contribution,
10
+ this simply gives us permission to use and redistribute your contributions as
11
+ part of the project. Head over to <https://cla.developers.google.com/> to see
12
+ your current agreements on file or to sign a new one.
13
+
14
+ You generally only need to submit a CLA once, so if you've already submitted one
15
+ (even if it was for a different project), you probably don't need to do it
16
+ again.
17
+
18
+ ## Code reviews
19
+
20
+ All submissions, including submissions by project members, require review. We
21
+ use GitHub pull requests for this purpose. Consult
22
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23
+ information on using pull requests.
24
+
25
+ ## Community Guidelines
26
+
27
+ This project follows [Google's Open Source Community
28
+ Guidelines](https://opensource.google/conduct/).
LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -1 +1,208 @@
1
- # tracr
1
+ # Tracr: TRAnsformer Compiler for RASP.
2
+
3
+ Tracr is a compiler for converting RASP programs
4
+ ([Weiss et al. 2021](https://arxiv.org/abs/2106.06981))
5
+ into transformer weights.
6
+
7
+ Directory structure:
8
+
9
+ * `rasp` contains an implementation of RASP embedded in Python.
10
+ * `compiler` contains the compiler itself.
11
+ * `transformer` contains the implementation of the transformer.
12
+ * `craft` contains the intermediate representation used by the compiler:
13
+ essentially a small linear algebra-based library with named dimensions.
14
+
15
+ This is not an officially supported Google product.
16
+
17
+
18
+ ## Installation
19
+
20
+ Installation is currently a bit manual. First, install dependencies:
21
+
22
+ ```
23
+ pip3 install chex einops dm-haiku networkx
24
+ ```
25
+
26
+ Second, clone the repo:
27
+
28
+ ```
29
+ git clone https://github.com/deepmind/tracr
30
+ ```
31
+
32
+ Third, put the resulting folder somewhere in your `PYTHONPATH`
33
+ (e.g. by placing the `tracr` checkout in the root of your project folder).
34
+
35
+ This will be made easier in the future.
36
+
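If you prefer to check the path setup from Python instead of setting the environment variable, here is a minimal sketch (the path below is a placeholder for wherever you cloned the repo, not something defined by Tracr):

```python
import sys

# Placeholder path: the directory that contains the cloned `tracr/` folder.
sys.path.append("/path/to/your/project")

from tracr.rasp import rasp  # should now import without error
```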
37
+
38
+ ## Usage example: RASP `reverse` program
39
+
40
+ Consider the RASP `reverse` program:
41
+
42
+ ```
43
+ opp_index = length - indices - 1;
44
+ flip = select(indices, opp_index, ==);
45
+ reverse = aggregate(flip, tokens);
46
+ ```
47
+
48
+ To compile this with Tracr, we would first implement the program using Tracr's
49
+ RASP library:
50
+
51
+ ```python
52
+ from tracr.rasp import rasp
53
+
54
+ length = make_length() # `length` is not a primitive in our implementation.
55
+ opp_index = length - rasp.indices - 1
56
+ flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
57
+ reverse = rasp.Aggregate(flip, rasp.tokens)
58
+ ```
59
+
60
+ Where:
61
+
62
+ ```python
63
+ def make_length():
64
+ all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
65
+ return rasp.SelectorWidth(all_true_selector)
66
+ ```
67
+
68
+ We can then compile the RASP program to a transformer with:
69
+
70
+ ```python
71
+ from tracr.compiler import compiling
72
+
73
+ bos = "BOS"
74
+ model = compiling.compile_rasp_to_model(
75
+ reverse,
76
+ vocab={1, 2, 3},
77
+ max_seq_len=5,
78
+ compiler_bos=bos,
79
+ )
80
+ ```
81
+
82
+ This yields a transformer as a [Haiku](https://github.com/deepmind/dm-haiku) model.
83
+ This model isn't intended to provide _everything_ you might need, but rather serves
84
+ as a kind of "documentation-in-code" for the semantics of the generated parameters.
85
+ The expectation is that the user can then write or contribute an adapter that converts
86
+ parameters from this reference model to another transformer implementation.
87
+
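As a starting point for such an adapter, note that `model.params` is a plain Haiku parameter mapping keyed by module name (e.g. `transformer/layer_0/attn/query`, following the naming set up in `compiler/assemble.py` in this change). A minimal sketch that just walks the parameters, assuming `model` is the object returned above:

```python
# Iterate over the compiled Haiku parameters and report their shapes.
for module_name, module_params in model.params.items():
    for param_name, value in module_params.items():
        print(f"{module_name}/{param_name}: {value.shape}")
```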
88
+ Using this model we can perform a forward pass:
89
+
90
+ ```python
91
+ >>> out = model.apply([bos, 1, 2, 3])
92
+ >>> out.decoded
93
+ ["BOS", 3, 2, 1]
94
+ ```
95
+
96
+ Success! We have a transformer that reverses its input tokens.
97
+
98
+ Note: compiled models always expect a BOS token in order to support
99
+ selectors which don't attend to any of the input tokens. This is necessary to
100
+ preserve intuitive RASP semantics; the alternative would have been to treat
101
+ all-False selector rows as equivalent to all-True (which is what softmax in an
102
+ attention layer would naturally do). For more details, see our paper.
103
+
104
+ You can also inspect some of the intermediate activations of the model, using
105
+ `out.residuals`, `out.layer_outputs`, and `out.attn_logits`.
106
+
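For example, a small sketch using the `out` object from the forward pass above (shapes follow the `[batch, seq_len, ...]` layout noted in `compiler/assemble.py`):

```python
# Residual stream recorded at each layer: arrays of shape [batch, seq_len, d_model].
for i, residual in enumerate(out.residuals):
    print(f"residual at layer {i}: {residual.shape}")

# Attention logits per layer have shape [batch, seq_len, seq_len, num_heads].
print(out.attn_logits[0].shape)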
107
+ For more examples of RASP programs we can compile, check out
108
+ [compiler/lib.py](compiler/lib.py).
109
+
110
+ For an interactive example of compiling a model and visualizing its computation,
111
+ check out the notebook at
112
+ [examples/Visualize\_Tracr\_Models.ipynb](examples/Visualize_Tracr_Models.ipynb).
113
+
114
+
115
+ ## Developer README
116
+
117
+ If you'd like to extend Tracr to fit your purposes, here's some information on
118
+ how Tracr works under the hood.
119
+
120
+
121
+ ### How Tracr works conceptually
122
+
123
+ To compile a program, Tracr does the following.
124
+
125
+ 1. **Trace RASP program into a graph representation.** This involves creating
126
+ a graph node for each RASP expression and inferring dependencies between
127
+ these graph nodes.
128
+
129
+ 2. **Infer bases.** Tracr is designed to have each node output to a separate
130
+ subspace of the residual stream. To do this, we first infer the set of all
131
+ possible token values that each node can take, then using that information,
132
+ decide on a subspace for each node, and augment each node in the graph
133
+ with the basis vectors for that node's subspace.
134
+
135
+ 3. **Convert nodes to Craft components.** Craft is the name of our internal
136
+ intermediate representation that does linear algebra on named subspaces. In
137
+ this stage, each expression node is converted to a Craft component that
138
+ actually performs the linear algebra operations necessary to implement the
139
+ expression. This includes converting _sequence operators_ to MLP weights,
140
+ and _selectors_ to weights of attention heads. (We compute the appropriate
141
+ weights directly using the theory of universal approximation for MLPs - no
142
+ gradient descent required!)
143
+
144
+ 4. **Convert Craft graph to Craft model.** In this stage, we convert from
145
+ a graph representation to a layout that looks more like an actual
146
+ transformer. At this stage, we essentially have a working model, but
147
+ with the linear algebra done using Craft rather than JAX + Haiku.
148
+
149
+ 5. **Convert Craft model to Haiku model.** Finally, we convert our
150
+ intermediate representation of the model to a full Haiku model.
151
+
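In code, these five stages map onto the calls made by `compile_rasp_to_model`; the following is a condensed sketch of `compiler/compiling.py` (included in this change), with error handling omitted:

```python
# 1. Trace the RASP program into a graph.
extracted = rasp_to_graph.extract_rasp_graph(program)

# 2. Infer value sets and bases for each node (in place).
basis_inference.infer_bases(
    extracted.graph, extracted.sink, vocab, max_seq_len)

# 3. Convert graph nodes into craft components.
expr_to_craft_graph.add_craft_components_to_rasp_graph(
    extracted.graph,
    bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos),
    mlp_exactness=mlp_exactness)

# 4. Lay out the craft graph as a craft transformer model.
craft_model = craft_graph_to_model.craft_graph_to_model(
    extracted.graph, extracted.sources)

# 5. Assemble the craft model into a Haiku model with weights.
model = craft_model_to_transformer.craft_model_to_transformer(
    craft_model=craft_model, graph=extracted.graph, sink=extracted.sink,
    max_seq_len=max_seq_len, causal=causal,
    compiler_bos=compiler_bos, compiler_pad=compiler_pad)
```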
152
+ Two details worth expanding on here are subspaces and corresponding bases.
153
+ Each node writes to a separate subspace of the residual stream,
154
+ where each subspace is simply a unique chunk of the residual stream vector.
155
+ For example, the first node might write to the first 5 components of
156
+ the residual stream; the second node the next 5; and so on. In terms of
157
+ the embeddings actually associated with each node, Tracr employs two
158
+ different kinds of bases:
159
+
160
+ * **Categorical representation** - in which each unique token value is
161
+ represented as a unique one-hot vector in that node's subspace. This
162
+ is the representation used by default.
163
+ * **Numerical representation** - in which each unique token value is
164
+ mapped to a unique scalar value. This is necessary for some uses
165
+ of the `aggregate` operation - essentially, ones which involve taking
166
+ a mean - and some other operations are also encoded more efficiently
167
+ in this representation.
168
+
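The representation is chosen per S-Op with the `rasp.categorical` and `rasp.numerical` annotations; a short sketch mirroring the usage in the tests included in this change:

```python
# Categorical (the default): each output value gets its own one-hot direction.
shifted = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))

# Numerical: the output occupies a single scalar dimension. Needed, e.g., when an
# aggregate takes the mean of a binary variable.
is_one = rasp.numerical(rasp.Map(lambda x: 1 if x == 1 else 0, rasp.tokens))
frac_ones = rasp.numerical(rasp.Aggregate(
    rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE), is_one))
```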
169
+ A final detail is BOS tokens. The compiler relies on beginning-of-sequence
170
+ tokens in order to implement a number of operations. This is why token
171
+ sequences fed into the final model _must_ start with a BOS token.
172
+
173
+
174
+ ### How Tracr works in practice
175
+
176
+ The flow of compilation execution begins in
177
+ [`compiler/compiling.py`](compiler/compiling.py), in the
178
+ `compile_rasp_to_model` function. This function is fairly short and maps
179
+ directly to the stages outlined above, so don't be afraid to read the source!
180
+
181
+
182
+ ## Running tests
183
+
184
+ We use [`absltest`](https://abseil.io/docs/python/guides/testing), which is
185
+ `unittest`-compatible, and is therefore in turn `pytest`-compatible.
186
+
187
+ First, install test dependencies:
188
+
189
+ ```
190
+ pip3 install absl-py pytest
191
+ ```
192
+
+ Then run the tests:
+
193
+ ```
194
+ # We use `python3 -m pytest` instead of just `pytest` so that the working directory is
195
+ # added to PYTHONPATH.
196
+ # -ra: Report names of tests that failed, were skipped, etc.
197
+ python3 -m pytest -ra
198
+ ```
199
+
200
+ This should take about 60 seconds. If you install `pytest-xdist`, you can run them in
201
+ parallel with:
202
+
203
+ ```
204
+ python3 -m pytest -ra -n auto
205
+ ```
206
+
207
+ However, currently this only shaves off about 10 seconds, since it's bottlenecked by a
208
+ single long-running test.
compiler/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provides the main compiler function as a public import."""
16
+
17
+ from tracr.compiler.compiling import compile_rasp_to_model
18
+
19
+ __all__ = ["compile_rasp_to_model"]
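Usage note (not part of the diff): because of this re-export, the compiler entry point can be imported directly from the package:

```python
# Same function as tracr.compiler.compiling.compile_rasp_to_model.
from tracr.compiler import compile_rasp_to_model
```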
compiler/assemble.py ADDED
@@ -0,0 +1,335 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Assemble weights of a transformer model from a craft residual stack."""
16
+
17
+ import dataclasses
18
+ from typing import Any, Callable, Optional, Protocol
19
+
20
+ import chex
21
+ import einops
22
+ import haiku as hk
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+ from tracr.craft import bases
27
+ from tracr.craft import transformers
28
+ from tracr.craft import vectorspace_fns
29
+ from tracr.transformer import encoder
30
+ from tracr.transformer import model
31
+
32
+
33
+ @chex.dataclass
34
+ class AssembledTransformerModelOutput:
35
+ decoded: list[Any] # length T.
36
+ unembedded: jax.Array # [B, T] B = 1 always.
37
+ layer_outputs: list[jax.Array] # [B, T, D]
38
+ residuals: list[jax.Array] # [B, T, D]
39
+ attn_logits: list[jax.Array] # [B, T, T, H]
40
+ transformer_output: jax.Array # [B, T, D]
41
+ input_embeddings: jax.Array
42
+
43
+
44
+ class ModelForward(Protocol):
45
+
46
+ def __call__(
47
+ self,
48
+ params: hk.Params,
49
+ emb: jax.Array,
50
+ ) -> model.CompiledTransformerModelOutput:
51
+ """A hk-transformed forward pass through the compiled model."""
52
+
53
+
54
+ @dataclasses.dataclass
55
+ class AssembledTransformerModel:
56
+ """Model architecture and parameters from assembling a model."""
57
+ forward: ModelForward
58
+ get_compiled_model: Callable[[], model.CompiledTransformerModel]
59
+ params: hk.Params
60
+ model_config: model.TransformerConfig
61
+ residual_labels: list[str]
62
+ input_encoder: Optional[encoder.Encoder] = None
63
+ output_encoder: Optional[encoder.Encoder] = None
64
+
65
+ def apply(self, tokens: list[bases.Value]) -> AssembledTransformerModelOutput:
66
+ """Returns output from running the model on a set of input tokens."""
67
+ if self.input_encoder:
68
+ tokens = self.input_encoder.encode(tokens)
69
+ tokens = jnp.array([tokens])
70
+ output = self.forward(self.params, tokens)
71
+ decoded = output.unembedded_output[0].tolist()
72
+ if self.output_encoder:
73
+ decoded = self.output_encoder.decode(decoded)
74
+
75
+ if self.input_encoder.bos_token:
76
+ # Special case for decoding the bos token position, for which the output
77
+ # decoder might have unspecified behavior.
78
+ decoded = [self.input_encoder.bos_token] + decoded[1:]
79
+
80
+ return AssembledTransformerModelOutput(
81
+ decoded=decoded,
82
+ unembedded=output.unembedded_output,
83
+ layer_outputs=output.transformer_output.layer_outputs,
84
+ residuals=output.transformer_output.residuals,
85
+ attn_logits=output.transformer_output.attn_logits,
86
+ transformer_output=output.transformer_output.output,
87
+ input_embeddings=output.transformer_output.input_embeddings)
88
+
89
+
90
+ @dataclasses.dataclass
91
+ class EmbeddingModules:
92
+ """Modules for embedding and tokens and positions and unembedding results."""
93
+ token_embed: model.CallableHaikuModule
94
+ pos_embed: model.CallableHaikuModule
95
+ unembed: model.CallableHaikuModule
96
+
97
+
98
+ def _get_model_config_and_module_names(
99
+ craft_model: transformers.SeriesWithResiduals
100
+ ) -> tuple[model.TransformerConfig, list[str]]:
101
+ """Returns model config and locations (in params) for halflayers."""
102
+
103
+ multi_attn_heads: list[list[transformers.AttentionHead]] = []
104
+ mlps: list[transformers.MLP] = []
105
+ module_names: list[str] = []
106
+
107
+ candidate_module_names = []
108
+ for layer in range(len(craft_model.blocks)):
109
+ candidate_module_names.append(f"transformer/layer_{layer}/attn")
110
+ candidate_module_names.append(f"transformer/layer_{layer}/mlp")
111
+ candidate_module_names = iter(candidate_module_names)
112
+
113
+ for module in craft_model.blocks:
114
+ if isinstance(module, transformers.MLP):
115
+ mlps.append(module)
116
+ layer_type = "mlp"
117
+ else:
118
+ multi_attn_heads.append(list(module.as_multi().heads()))
119
+ layer_type = "attn"
120
+ # Find the next layer with the necessary type. Modules in between that are not
121
+ # added to module_names will be disabled later by setting all weights to 0.
122
+ module_name = next(candidate_module_names)
123
+ while layer_type not in module_name:
124
+ module_name = next(candidate_module_names)
125
+ module_names.append(module_name)
126
+
127
+ num_layers = int(module_names[-1].split("_")[1].split("/")[0]) + 1
128
+ heads = sum(multi_attn_heads, [])
129
+
130
+ if multi_attn_heads:
131
+ num_heads = max(len(heads) for heads in multi_attn_heads)
132
+ key_size = max(max(head.w_qk.matrix.shape) for head in heads)
133
+ else:
134
+ num_heads, key_size = 1, 1
135
+
136
+ if mlps:
137
+ mlp_hidden_size = max(mlp.fst.output_space.num_dims for mlp in mlps)
138
+ else:
139
+ mlp_hidden_size = 1
140
+
141
+ model_config = model.TransformerConfig(
142
+ num_heads=num_heads,
143
+ num_layers=num_layers,
144
+ key_size=key_size,
145
+ mlp_hidden_size=mlp_hidden_size,
146
+ dropout_rate=0.,
147
+ activation_function=jax.nn.relu,
148
+ layer_norm=False,
149
+ causal=False,
150
+ )
151
+
152
+ return model_config, module_names
153
+
154
+
155
+ def _make_embedding_modules(
156
+ residual_space: bases.VectorSpaceWithBasis,
157
+ tokens_space: bases.VectorSpaceWithBasis,
158
+ indices_space: bases.VectorSpaceWithBasis,
159
+ output_space: bases.VectorSpaceWithBasis) -> EmbeddingModules:
160
+ """Creates embedding and unembedding modules from vector spaces.
161
+
162
+ Args:
163
+ residual_space: Full residual space of the model.
164
+ tokens_space: Subspace to embed tokens to.
165
+ indices_space: Subspace to embed indices/position embeddings to.
166
+ output_space: Subspace to unembed outputs from.
167
+
168
+ Returns:
169
+ EmbeddingModules containing modules for token embeddings, position
170
+ embeddings and unembeddings.
171
+ """
172
+ tokens_to_res = vectorspace_fns.project(tokens_space, residual_space)
173
+
174
+ # If we use the 'one' direction, make sure all inputs have a 1 here
175
+ one_dir = bases.BasisDirection("one")
176
+ if one_dir in residual_space:
177
+ one_to_res = vectorspace_fns.Linear.from_action(
178
+ tokens_space, residual_space,
179
+ lambda x: residual_space.vector_from_basis_direction(one_dir))
180
+ tokens_to_res = vectorspace_fns.Linear.combine_in_parallel(
181
+ [tokens_to_res, one_to_res])
182
+
183
+ # Token embeddings.
184
+ res_to_out = vectorspace_fns.project(residual_space, output_space)
185
+ token_embed = hk.Embed(
186
+ embedding_matrix=tokens_to_res.matrix, name="token_embed")
187
+
188
+ # Positional embeddings.
189
+ index_to_res = vectorspace_fns.project(indices_space, residual_space)
190
+ # The zeroth position should not have any positional embeddings,
191
+ # so we add one line of padding at the zeroth position.
192
+ pos_matrix = np.concatenate(
193
+ [np.zeros((1, residual_space.num_dims)), index_to_res.matrix], axis=0)
194
+ pos_embed = hk.Embed(embedding_matrix=pos_matrix, name="pos_embed")
195
+
196
+ def unembed(x, use_unembed_argmax):
197
+ out = x @ res_to_out.matrix
198
+ if use_unembed_argmax:
199
+ return jnp.argmax(out, axis=-1)
200
+ elif out.shape[-1] == 1:
201
+ return out.squeeze(-1)
202
+ return out
203
+
204
+ unembed_mod = hk.to_module(unembed)()
205
+ return EmbeddingModules(
206
+ token_embed=token_embed, pos_embed=pos_embed, unembed=unembed_mod)
207
+
208
+
209
+ def assemble_craft_model(
210
+ craft_model: transformers.SeriesWithResiduals,
211
+ tokens_space: bases.VectorSpaceWithBasis,
212
+ indices_space: bases.VectorSpaceWithBasis,
213
+ output_space: bases.VectorSpaceWithBasis,
214
+ categorical_output: bool,
215
+ causal: bool = False,
216
+ ) -> AssembledTransformerModel:
217
+ """Assembles the given components into a Haiku model with parameters.
218
+
219
+ Args:
220
+ craft_model: Model to assemble weights for.
221
+ tokens_space: Vectorspace to embed the input tokens to.
222
+ indices_space: Vectorspace to embed the indices to (position encodings).
223
+ output_space: Vectorspace that the model will write outputs to that should
224
+ be unembedded.
225
+ categorical_output: Whether the output is categorical. If True, we take an
226
+ argmax when unembedding.
227
+ causal: Whether to output a causally-masked model.
228
+
229
+ Returns:
230
+ An AssembledTransformerModel that contains the model and parameters of the
231
+ assembled transformer.
232
+ """
233
+ # TODO(b/255936413): Make embeddings only retain the tokens and indices that
234
+ # are actually used.
235
+ # TODO(b/255936496): Think about enabling layer norm and reversing it somehow
236
+
237
+ model_config, module_names = _get_model_config_and_module_names(craft_model)
238
+ model_config.causal = causal
239
+
240
+ residual_space = bases.join_vector_spaces(craft_model.residual_space,
241
+ tokens_space, indices_space,
242
+ output_space)
243
+ residual_labels = [str(basis_dir) for basis_dir in residual_space.basis]
244
+
245
+ # Build model with embedding and unembedding layers
246
+ def get_compiled_model():
247
+ transformer = model.Transformer(model_config)
248
+ embed_modules = _make_embedding_modules(
249
+ residual_space=residual_space,
250
+ tokens_space=tokens_space,
251
+ indices_space=indices_space,
252
+ output_space=output_space)
253
+ return model.CompiledTransformerModel(
254
+ transformer=transformer,
255
+ token_embed=embed_modules.token_embed,
256
+ position_embed=embed_modules.pos_embed,
257
+ unembed=embed_modules.unembed,
258
+ use_unembed_argmax=categorical_output)
259
+
260
+ @hk.without_apply_rng
261
+ @hk.transform
262
+ def forward(emb):
263
+ compiled_model = get_compiled_model()
264
+ return compiled_model(emb, use_dropout=False)
265
+
266
+ params = forward.init(jax.random.PRNGKey(0), jnp.array([[1, 2, 3]]))
267
+
268
+ for key in params:
269
+ if "transformer" in key:
270
+ for par in params[key]:
271
+ params[key][par] = np.zeros_like(params[key][par])
272
+
273
+ # Assemble attention and MLP weights.
274
+ project = lambda space: vectorspace_fns.project(residual_space, space).matrix
275
+
276
+ for module_name, module in zip(module_names, craft_model.blocks):
277
+ if isinstance(module, transformers.MLP):
278
+ hidden_size = module.fst.output_space.num_dims
279
+ residual_to_fst_input = project(module.fst.input_space)
280
+ snd_output_to_residual = project(module.snd.output_space).T
281
+ params[f"{module_name}/linear_1"]["w"][:, :hidden_size] = (
282
+ residual_to_fst_input @ module.fst.matrix)
283
+ params[f"{module_name}/linear_2"]["w"][:hidden_size, :] = (
284
+ module.snd.matrix @ snd_output_to_residual)
285
+ else: # Attention module
286
+ query, key, value, linear = [], [], [], []
287
+ for head in module.as_multi().heads():
288
+ key_size = head.w_qk.matrix.shape[1]
289
+ query_mat = np.zeros((residual_space.num_dims, model_config.key_size))
290
+ residual_to_query = project(head.w_qk.left_space)
291
+ query_mat[:, :key_size] = residual_to_query @ head.w_qk.matrix
292
+ query.append(query_mat)
293
+
294
+ key_mat = np.zeros((residual_space.num_dims, model_config.key_size))
295
+ key_mat[:, :key_size] = project(head.w_qk.right_space)
296
+ key.append(key_mat)
297
+
298
+ value_size = head.w_ov.matrix.shape[1]
299
+ value_mat = np.zeros((residual_space.num_dims, model_config.key_size))
300
+ residual_to_ov_input = project(head.w_ov.input_space)
301
+ value_mat[:, :value_size] = residual_to_ov_input @ head.w_ov.matrix
302
+ value.append(value_mat)
303
+
304
+ linear_mat = np.zeros((model_config.key_size, residual_space.num_dims))
305
+ linear_mat[:value_size, :] = project(head.w_ov.output_space).T
306
+ linear.append(linear_mat)
307
+
308
+ # Fill up heads that are not used with zero weights
309
+ for _ in range(model_config.num_heads - module.as_multi().num_heads):
310
+ query.append(np.zeros_like(query[0]))
311
+ key.append(np.zeros_like(key[0]))
312
+ value.append(np.zeros_like(value[0]))
313
+ linear.append(np.zeros_like(linear[0]))
314
+
315
+ query = einops.rearrange(query,
316
+ "heads input output -> input (heads output)")
317
+ key = einops.rearrange(key, "heads input output -> input (heads output)")
318
+ value = einops.rearrange(value,
319
+ "heads input output -> input (heads output)")
320
+ linear = einops.rearrange(linear,
321
+ "heads input output -> (heads input) output")
322
+
323
+ params[f"{module_name}/query"]["w"][:, :] = query
324
+ params[f"{module_name}/key"]["w"][:, :] = key
325
+ params[f"{module_name}/value"]["w"][:, :] = value
326
+ params[f"{module_name}/linear"]["w"][:, :] = linear
327
+
328
+ params = jax.tree_util.tree_map(jnp.array, params)
329
+ return AssembledTransformerModel(
330
+ forward=forward.apply,
331
+ get_compiled_model=get_compiled_model,
332
+ params=params,
333
+ model_config=model_config,
334
+ residual_labels=residual_labels,
335
+ )
compiler/assemble_test.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for transformer.assemble."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import haiku as hk
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from tracr.compiler import assemble
24
+ from tracr.craft import bases
25
+
26
+
27
+ class AssembleTest(parameterized.TestCase):
28
+
29
+ def test_token_embedding_produces_correct_embedding(self):
30
+ # Token embeddings should be one-hot embeddings of the input integers
31
+ # into the token subspace of residual_space
32
+ input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
33
+ indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
34
+ output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
35
+ residual_space = bases.join_vector_spaces(input_space, indices_space,
36
+ output_space)
37
+
38
+ @hk.without_apply_rng
39
+ @hk.transform
40
+ def token_pos_embed(tokens):
41
+ embed_modules = assemble._make_embedding_modules(
42
+ residual_space=residual_space,
43
+ tokens_space=input_space,
44
+ indices_space=indices_space,
45
+ output_space=output_space)
46
+ return embed_modules.token_embed(tokens)
47
+
48
+ tokens = jnp.array([0, 0, 1])
49
+ expected_token_embeddings = jnp.array([[1, 0, 0, 0, 0, 0, 0],
50
+ [1, 0, 0, 0, 0, 0, 0],
51
+ [0, 1, 0, 0, 0, 0, 0]])
52
+
53
+ params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
54
+ embeddings = token_pos_embed.apply(params, tokens)
55
+ np.testing.assert_allclose(embeddings, expected_token_embeddings)
56
+
57
+ def test_position_embedding_produces_correct_embedding(self):
58
+ # Position embeddings should be one-hot embeddings of the input integers
59
+ # (representing indices) into the indices subspace of residual_space
60
+ input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
61
+ indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
62
+ output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
63
+ residual_space = bases.join_vector_spaces(input_space, indices_space,
64
+ output_space)
65
+
66
+ @hk.without_apply_rng
67
+ @hk.transform
68
+ def token_pos_embed(tokens):
69
+ embed_modules = assemble._make_embedding_modules(
70
+ residual_space=residual_space,
71
+ tokens_space=input_space,
72
+ indices_space=indices_space,
73
+ output_space=output_space)
74
+ return embed_modules.pos_embed(jnp.indices(tokens.shape)[-1])
75
+
76
+ tokens = jnp.array([3, 0, 0, 1])
77
+ expected_pos_embeddings = jnp.array([[0, 0, 0, 0, 0, 0, 0],
78
+ [0, 0, 1, 0, 0, 0, 0],
79
+ [0, 0, 0, 1, 0, 0, 0],
80
+ [0, 0, 0, 0, 1, 0, 0]])
81
+
82
+ params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
83
+ embeddings = token_pos_embed.apply(params, tokens)
84
+ np.testing.assert_allclose(embeddings, expected_pos_embeddings)
85
+
86
+ def test_unembedding(self):
87
+ # Prepend numbers to preserve basis order [input, index, output]
88
+ input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
89
+ indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
90
+ output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
91
+ residual_space = bases.join_vector_spaces(input_space, indices_space,
92
+ output_space)
93
+
94
+ @hk.without_apply_rng
95
+ @hk.transform
96
+ def unembed(embeddings):
97
+ embed_modules = assemble._make_embedding_modules(
98
+ residual_space=residual_space,
99
+ tokens_space=input_space,
100
+ indices_space=indices_space,
101
+ output_space=output_space)
102
+ return embed_modules.unembed(embeddings, use_unembed_argmax=True)
103
+
104
+ embeddings = jnp.array([
105
+ # pylint: disable=g-no-space-after-comment
106
+ #inp| indices| out | < spaces
107
+ #0 1 0 1 2 0 1 < values in spaces
108
+ [0, 0, 0, 0, 0, 0, 1],
109
+ [0, 0, 0, 0, 0, 1, 0],
110
+ [0, 0, 0, 0, 0, 0, 1]
111
+ ])
112
+ expected_tokens = jnp.array([1, 0, 1])
113
+
114
+ params = unembed.init(jax.random.PRNGKey(0), embeddings)
115
+ tokens = unembed.apply(params, embeddings)
116
+ np.testing.assert_allclose(tokens, expected_tokens)
117
+
118
+
119
+ if __name__ == "__main__":
120
+ absltest.main()
compiler/basis_inference.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Inferring the vector spaces taken on by certain operations."""
16
+
17
+ import dataclasses
18
+ import itertools
19
+
20
+ import networkx as nx
21
+ from tracr.compiler import nodes
22
+ from tracr.craft import bases
23
+ from tracr.rasp import rasp
24
+ from tracr.utils import errors
25
+
26
+ Node = nodes.Node
27
+
28
+
29
+ @dataclasses.dataclass
30
+ class InferBasesOutput:
31
+ graph: nx.DiGraph
32
+
33
+
34
+ def infer_bases(
35
+ graph: nx.DiGraph,
36
+ sink: Node,
37
+ vocab: set[rasp.Value],
38
+ max_seq_len: int,
39
+ ) -> None:
40
+ """Infers in-place the possible output values and vector bases of the SOps."""
41
+
42
+ def compute_value_set(sop: rasp.SOp) -> set[rasp.Value]:
43
+ """Computes value set using already-computed predecessor value sets."""
44
+ if sop is rasp.tokens:
45
+ return vocab
46
+ elif sop is rasp.indices:
47
+ return set(range(max_seq_len))
48
+ elif isinstance(sop, rasp.SelectorWidth):
49
+ return set(range(0, max_seq_len + 1))
50
+ elif isinstance(sop, rasp.Full):
51
+ return {sop.fill}
52
+ elif isinstance(sop, rasp.Map):
53
+ inner_value_set = graph.nodes[sop.inner.label][nodes.VALUE_SET]
54
+ out = set()
55
+ for x in inner_value_set:
56
+ res = errors.ignoring_arithmetic_errors(sop.f)(x)
57
+ if res is not None:
58
+ out.add(res)
59
+ return out
60
+ elif isinstance(sop, rasp.SequenceMap):
61
+ f_ignore_error = errors.ignoring_arithmetic_errors(sop.f)
62
+ fst_value_set = graph.nodes[sop.fst.label][nodes.VALUE_SET]
63
+ snd_value_set = graph.nodes[sop.snd.label][nodes.VALUE_SET]
64
+ out = set()
65
+ for l, r in itertools.product(fst_value_set, snd_value_set):
66
+ res = f_ignore_error(l, r)
67
+ if res is not None:
68
+ out.add(res)
69
+ return out
70
+ elif isinstance(sop, rasp.Aggregate):
71
+ if rasp.is_categorical(sop):
72
+ # Simply pass on the value set of the underlying S-Op.
73
+ return graph.nodes[sop.sop.label][nodes.VALUE_SET]
74
+ elif rasp.is_numerical(sop):
75
+ # TODO(b/255936408): This doesn't work if we average arbitrary values.
76
+ # But most examples only average binary variables.
77
+ sop_value_set = graph.nodes[sop.sop.label][nodes.VALUE_SET]
78
+ if {int(x) for x in sop_value_set} != {0, 1}:
79
+ raise NotImplementedError(
80
+ "Attention patterns can currently only "
81
+ "average binary variables. Not:", sop_value_set)
82
+
83
+ value_set = set()
84
+ for value in sop_value_set:
85
+ for length in range(1, max_seq_len + 1):
86
+ value_set.add(value / length)
87
+ return value_set
88
+ raise ValueError(f"Unsupported S-Op: {sop}")
89
+
90
+ for node_id in nx.dfs_postorder_nodes(graph.reverse(), sink[nodes.ID]):
91
+ expr = graph.nodes[node_id][nodes.EXPR]
92
+
93
+ if not isinstance(expr, rasp.SOp):
94
+ # Only S-Ops have output vector spaces.
95
+ continue
96
+
97
+ value_set = compute_value_set(expr)
98
+ graph.nodes[node_id][nodes.VALUE_SET] = value_set
99
+
100
+ if rasp.is_categorical(expr):
101
+ out_space = bases.VectorSpaceWithBasis.from_values(expr.label, value_set)
102
+ elif rasp.is_numerical(expr):
103
+ out_space = bases.VectorSpaceWithBasis.from_names([expr.label])
104
+ else:
105
+ raise ValueError(f"Unsupported S-Op type: {expr.type}")
106
+ graph.nodes[node_id][nodes.OUTPUT_BASIS] = out_space.basis
compiler/basis_inference_test.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for compiler.basis_inference."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ from tracr.compiler import basis_inference
20
+ from tracr.compiler import nodes
21
+ from tracr.compiler import rasp_to_graph
22
+ from tracr.rasp import rasp
23
+
24
+
25
+ class InferBasesTest(parameterized.TestCase):
26
+
27
+ def test_arithmetic_error_logs_warning(self):
28
+ program = rasp.numerical(rasp.Map(lambda x: 1 / x, rasp.tokens))
29
+ extracted = rasp_to_graph.extract_rasp_graph(program)
30
+ vocab = {0, 1, 2}
31
+ with self.assertLogs(level="WARNING"):
32
+ basis_inference.infer_bases(
33
+ extracted.graph,
34
+ extracted.sink,
35
+ vocab,
36
+ max_seq_len=1,
37
+ )
38
+
39
+ @parameterized.parameters(({1, 2, 3}, {2, 3, 4}), ({0, 5}, {1, 6}))
40
+ def test_one_edge(self, vocab, expected_value_set):
41
+ program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
42
+ extracted = rasp_to_graph.extract_rasp_graph(program)
43
+
44
+ basis_inference.infer_bases(
45
+ extracted.graph,
46
+ extracted.sink,
47
+ vocab,
48
+ max_seq_len=1,
49
+ )
50
+
51
+ self.assertSetEqual(
52
+ extracted.graph.nodes[program.label][nodes.VALUE_SET],
53
+ expected_value_set,
54
+ )
55
+
56
+ def test_primitive_close_to_tip(self):
57
+ intermediate = rasp.categorical(rasp.tokens + 1)
58
+ intermediate = rasp.categorical(intermediate + intermediate)
59
+ program = rasp.categorical(intermediate + rasp.indices)
60
+ extracted = rasp_to_graph.extract_rasp_graph(program)
61
+
62
+ basis_inference.infer_bases(
63
+ extracted.graph,
64
+ extracted.sink,
65
+ {0, 1},
66
+ max_seq_len=2,
67
+ )
68
+
69
+ self.assertSetEqual(
70
+ extracted.graph.nodes[program.label][nodes.VALUE_SET],
71
+ {2, 3, 4, 5},
72
+ )
73
+ self.assertSetEqual(
74
+ extracted.graph.nodes[intermediate.label][nodes.VALUE_SET],
75
+ {2, 3, 4},
76
+ )
77
+
78
+ def test_categorical_aggregate(self):
79
+ program = rasp.categorical(
80
+ rasp.Aggregate(
81
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
82
+ rasp.indices,
83
+ ))
84
+
85
+ extracted = rasp_to_graph.extract_rasp_graph(program)
86
+
87
+ basis_inference.infer_bases(
88
+ extracted.graph,
89
+ extracted.sink,
90
+ {0, 1},
91
+ max_seq_len=3,
92
+ )
93
+
94
+ self.assertSetEqual(
95
+ extracted.graph.nodes[program.label][nodes.VALUE_SET],
96
+ {0, 1, 2},
97
+ )
98
+
99
+ def test_numerical_aggregate(self):
100
+ program = rasp.numerical(
101
+ rasp.Aggregate(
102
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
103
+ rasp.indices,
104
+ ))
105
+
106
+ extracted = rasp_to_graph.extract_rasp_graph(program)
107
+
108
+ basis_inference.infer_bases(
109
+ extracted.graph,
110
+ extracted.sink,
111
+ {0, 1},
112
+ max_seq_len=2,
113
+ )
114
+
115
+ self.assertSetEqual(
116
+ extracted.graph.nodes[program.label][nodes.VALUE_SET],
117
+ {0, 1, 1 / 2},
118
+ )
119
+
120
+ def test_selector_width(self):
121
+ program = rasp.SelectorWidth(
122
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ))
123
+
124
+ extracted = rasp_to_graph.extract_rasp_graph(program)
125
+
126
+ basis_inference.infer_bases(
127
+ extracted.graph,
128
+ extracted.sink,
129
+ {0, 1},
130
+ max_seq_len=2,
131
+ )
132
+
133
+ self.assertSetEqual(
134
+ extracted.graph.nodes[program.label][nodes.VALUE_SET],
135
+ {0, 1, 2},
136
+ )
137
+
138
+
139
+ if __name__ == "__main__":
140
+ absltest.main()
compiler/compiling.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Combines all steps of compiling a RASP program."""
16
+
17
+ from tracr.compiler import assemble
18
+ from tracr.compiler import basis_inference
19
+ from tracr.compiler import craft_graph_to_model
20
+ from tracr.compiler import craft_model_to_transformer
21
+ from tracr.compiler import expr_to_craft_graph
22
+ from tracr.compiler import rasp_to_graph
23
+ from tracr.craft import bases
24
+ from tracr.rasp import rasp
25
+
26
+ COMPILER_BOS = "compiler_bos"
27
+ COMPILER_PAD = "compiler_pad"
28
+
29
+
30
+ def compile_rasp_to_model(
31
+ program: rasp.SOp,
32
+ vocab: set[rasp.Value],
33
+ max_seq_len: int,
34
+ causal: bool = False,
35
+ compiler_bos: str = COMPILER_BOS,
36
+ compiler_pad: str = COMPILER_PAD,
37
+ mlp_exactness: int = 100) -> assemble.AssembledTransformerModel:
38
+ """Compile a RASP program to transformer weights.
39
+
40
+ Args:
41
+ program: the RASP program to compile.
42
+ vocab: the set of vocab tokens expected by RASP.
43
+ max_seq_len: the maximum sequence length for the compiled model.
44
+ causal: if True, outputs a model with causal masking.
45
+ compiler_bos: the name of the special BOS token that will be added by the
46
+ compiler. Must not be present in the vocab.
47
+ compiler_pad: the name of the special PAD token that will be added by the
48
+ compiler. Must not be present in the vocab.
49
+ mlp_exactness: Controls the approximation of the MLP layers. In theory,
50
+ larger values yield a better approximation. But too large values can cause
51
+ numerical issues due to large parameter norms. Reasonable values are
52
+ between 1 and 100.
53
+
54
+ Returns:
55
+ The compiled model.
56
+ """
57
+
58
+ if compiler_bos in vocab:
59
+ raise ValueError("Compiler BOS token must not be present in the vocab. "
60
+ f"Found '{compiler_bos}' in {vocab}")
61
+
62
+ if compiler_pad in vocab:
63
+ raise ValueError("Compiler PAD token must not be present in the vocab. "
64
+ f"Found '{compiler_pad}' in {vocab}")
65
+
66
+ extracted = rasp_to_graph.extract_rasp_graph(program)
67
+ graph, sources, sink = extracted.graph, extracted.sources, extracted.sink
68
+
69
+ basis_inference.infer_bases(
70
+ graph,
71
+ sink,
72
+ vocab,
73
+ max_seq_len,
74
+ )
75
+
76
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(
77
+ graph,
78
+ bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos),
79
+ mlp_exactness=mlp_exactness,
80
+ )
81
+
82
+ craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources)
83
+
84
+ return craft_model_to_transformer.craft_model_to_transformer(
85
+ craft_model=craft_model,
86
+ graph=graph,
87
+ sink=sink,
88
+ max_seq_len=max_seq_len,
89
+ causal=causal,
90
+ compiler_bos=compiler_bos,
91
+ compiler_pad=compiler_pad,
92
+ )
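
A minimal usage sketch for this new entry point, assuming this commit's `tracr` package layout is importable and using `make_length` from `compiler/lib.py` below as the example program. The `.decoded` attribute on the result of `model.apply` follows the library's public examples and is an assumption here (assemble.py is not shown in this excerpt), and the printed output is indicative rather than authoritative.

```python
# Minimal usage sketch (assumes the `tracr` package from this commit is importable).
from tracr.compiler import compiling
from tracr.compiler import lib

# A RASP program from compiler/lib.py: every position outputs the input length.
program = lib.make_length()

# Compile to transformer weights. The BOS/PAD tokens are added by the compiler
# and must not appear in the vocab.
model = compiling.compile_rasp_to_model(
    program,
    vocab={"a", "b", "c"},
    max_seq_len=5,
    compiler_bos="BOS",
)

# The input to the compiled model must start with the compiler BOS token.
out = model.apply(["BOS", "a", "b", "c"])
print(out.decoded)  # Expected (roughly): ["BOS", 3, 3, 3]; the BOS slot is not meaningful.
```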
compiler/craft_graph_to_model.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Create a craft model from a computational graph."""
16
+
17
+ import collections
18
+ from typing import Sequence
19
+
20
+ import networkx as nx
21
+ from tracr.compiler import nodes
22
+ from tracr.craft import bases
23
+ from tracr.craft import transformers
24
+ from tracr.rasp import rasp
25
+
26
+ Node = nodes.Node
27
+ NodeID = nodes.NodeID
28
+
29
+
30
+ def _get_longest_path_length_to_node(graph: nx.DiGraph, sources: Sequence[Node],
31
+ node: Node) -> int:
32
+ """Returns the length of the longest path from sources to node.
33
+
34
+ Only SOps count towards the length of a path.
35
+
36
+ Args:
37
+ graph: DAG to compute longest path in.
38
+ sources: List of starting nodes; the longest path is a maximum over all of them.
39
+ node: Target node.
40
+
41
+ Returns:
42
+ Number of steps needed for the longest path from the source to the node, or
43
+ -1 if there is no path from any of the sources to the target node.
44
+ """
45
+ if node in sources:
46
+ return 0
47
+
48
+ def num_sops(path: Sequence[NodeID]) -> int:
49
+ num = 0
50
+ for node_id in path:
51
+ if isinstance(graph.nodes[node_id][nodes.EXPR], rasp.SOp):
52
+ num += 1
53
+ return num
54
+
55
+ result = -1
56
+ for source in sources:
57
+ all_paths = nx.all_simple_paths(graph, source[nodes.ID], node[nodes.ID])
58
+ longest_path_len = max(map(num_sops, all_paths), default=-1) - 1
59
+ if longest_path_len > result:
60
+ result = longest_path_len
61
+ return result
62
+
63
+
64
+ def _node_is_attn(node: Node) -> bool:
65
+ """Returns True if node is an attention layer."""
66
+ return nodes.MODEL_BLOCK in node and isinstance(
67
+ node[nodes.MODEL_BLOCK],
68
+ (transformers.AttentionHead, transformers.MultiAttentionHead))
69
+
70
+
71
+ def _node_is_mlp(node: Node) -> bool:
72
+ """Returns True if node is an MLP layer."""
73
+ return nodes.MODEL_BLOCK in node and isinstance(node[nodes.MODEL_BLOCK],
74
+ transformers.MLP)
75
+
76
+
77
+ def _node_is_residual_block(node: Node) -> bool:
78
+ """Returns True if node is a valid residual block (Attn followed by MLP)."""
79
+ block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None
80
+ if block and isinstance(block, transformers.SeriesWithResiduals):
81
+ if len(block.blocks) == 2:
82
+ attn, mlp = block.blocks
83
+ if (isinstance(
84
+ attn,
85
+ (transformers.AttentionHead, transformers.MultiAttentionHead)) and
86
+ isinstance(mlp, transformers.MLP)):
87
+ return True
88
+ return False
89
+
90
+
91
+ def _all_attn_nodes(node_list: Sequence[Node]) -> bool:
92
+ """Returns True iff all nodes are attention layers (or node_list is empty)."""
93
+ for node in node_list:
94
+ if not _node_is_attn(node):
95
+ return False
96
+ return True
97
+
98
+
99
+ def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
100
+ """Returns True iff all nodes are MLP layers (or node_list is empty)."""
101
+ for node in node_list:
102
+ if not _node_is_mlp(node):
103
+ return False
104
+ return True
105
+
106
+
107
+ def _allocate_modules_to_layers(graph: nx.DiGraph,
108
+ sources: Sequence[Node]) -> dict[int, int]:
109
+ """Allocate all nodes in compute graph to layers.
110
+
111
+ First, computes the longest path from the input to each node that is a model
112
+ component (not input and output nodes). The longest path to a model component
113
+ (its "depth") determines a layer in which we can place it while ensuring that
114
+ all necessary previous computations have already happened.
115
+
116
+ This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...]
117
+
118
+ In the special case where there are only Attention layers at one depth level
119
+ and only MLP layers at the next depth level, they are treated as if they
120
+ were at the same depth, because attention layers always come before MLP layers
121
+ for the same depth.
122
+
123
+ Args:
124
+ graph: RASP graph with craft blocks.
125
+ sources: List of input nodes
126
+
127
+ Returns:
128
+ A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
129
+ are in the order attention, mlp, attention, mlp, ...
130
+ """
131
+ layer_allocation: dict[int, int] = collections.defaultdict(lambda: -1)
132
+ depth_by_node_id: dict[int, int] = dict()
133
+ nodes_by_depth: dict[int, list[Node]] = collections.defaultdict(list)
134
+
135
+ # Compute depth of all model components (longest path from source to node)
136
+ for node_id, node in graph.nodes.items():
137
+ if (_node_is_mlp(node) or _node_is_attn(node)
138
+ or _node_is_residual_block(node)):
139
+ # Node is a model component
140
+ longest_path_len = _get_longest_path_length_to_node(graph, sources, node)
141
+ depth_by_node_id[node_id] = longest_path_len
142
+ nodes_by_depth[longest_path_len].append(node)
143
+
144
+ # If at level `depth` there are only attention heads and at level `depth + 1`
145
+ # there are only MLPs, we can condense them into one level
146
+ # TODO(b/255936816): Think about improving this heuristic. The heuristic is
147
+ # not optimal, and only catches very basic opportunities for optimization. It
148
+ # is easy to come up with opportunities for optimization that it does not
149
+ # catch.
150
+ min_depth, max_depth = min(nodes_by_depth.keys()), max(nodes_by_depth.keys())
151
+ depth = min_depth
152
+ while depth < max_depth:
153
+ if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes(
154
+ nodes_by_depth[depth + 1]):
155
+ # Condense by decrementing the depth of all nodes starting from depth+1
156
+ for update_depth in range(depth + 1, max_depth + 1):
157
+ for node in nodes_by_depth[update_depth]:
158
+ node_id = node[nodes.ID]
159
+ depth_by_node_id[node_id] = update_depth - 1
160
+ nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth])
161
+ nodes_by_depth[update_depth] = []
162
+ max_depth -= 1
163
+ depth += 1
164
+
165
+ # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
166
+ current_layer = 0
167
+ current_depth = 1
168
+ for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]):
169
+ while depth > current_depth:
170
+ current_depth += 1
171
+ current_layer += 2
172
+ if depth == current_depth:
173
+ if _node_is_residual_block(graph.nodes[node_id]):
174
+ layer_allocation[node_id] = current_layer
175
+ else:
176
+ is_mlp = _node_is_mlp(graph.nodes[node_id])
177
+ layer_allocation[node_id] = current_layer + int(is_mlp)
178
+
179
+ return layer_allocation
180
+
181
+
182
+ def craft_graph_to_model(
183
+ graph: nx.DiGraph,
184
+ sources: Sequence[Node]) -> transformers.SeriesWithResiduals:
185
+ """Translates a RASP graph with craft blocks into a full craft model.
186
+
187
+ 1. Allocates modules to layers, assuming layers are arranged as [Attention, MLP, Attention, MLP, ...].
188
+ 2. Creates subspaces for all inputs and outputs, and builds the residual stream.
189
+ 3. Assembles everything into a craft model and returns it.
190
+
191
+ Args:
192
+ graph: RASP graph with craft blocks.
193
+ sources: List of input nodes
194
+
195
+ Returns:
196
+ A craft model that can be compiled to model weights.
197
+
198
+ Raises:
199
+ ValueError: On invalid input (if the craft_graph does not have craft blocks
200
+ already specified)
201
+ """
202
+ layer_allocation = _allocate_modules_to_layers(graph, sources)
203
+ blocks_by_layer = collections.defaultdict(list)
204
+ model_blocks = []
205
+
206
+ residual_space = bases.VectorSpaceWithBasis([])
207
+
208
+ for node_id, layer_no in layer_allocation.items():
209
+ node = graph.nodes[node_id]
210
+ block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None
211
+
212
+ if _node_is_residual_block(node):
213
+ assert isinstance(block, transformers.SeriesWithResiduals)
214
+ assert len(block.blocks) == 2
215
+ residual_space = bases.join_vector_spaces(residual_space,
216
+ block.blocks[0].residual_space,
217
+ block.blocks[1].residual_space)
218
+ blocks_by_layer[layer_no].append(block.blocks[0])
219
+ blocks_by_layer[layer_no + 1].append(block.blocks[1])
220
+ elif block:
221
+ residual_space = bases.join_vector_spaces(
222
+ residual_space, node[nodes.MODEL_BLOCK].residual_space)
223
+ blocks_by_layer[layer_no].append(block)
224
+
225
+ for layer_no, layer_blocks in sorted(
226
+ blocks_by_layer.items(), key=lambda x: x[0]):
227
+ for block in layer_blocks:
228
+ block.residual_space = residual_space
229
+
230
+ if layer_blocks:
231
+ if layer_no % 2 == 0: # Attention Layer
232
+ multi_head_attn = transformers.MultiAttentionHead(layer_blocks)
233
+ model_blocks.append(multi_head_attn)
234
+ else: # MLP Layer
235
+ parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks)
236
+ model_blocks.append(parallel_mlp)
237
+
238
+ return transformers.SeriesWithResiduals(model_blocks)
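
To make the even/odd layer convention in `_allocate_modules_to_layers` and `craft_graph_to_model` concrete, here is a small self-contained sketch (not part of the module; the helper name is hypothetical and the attention/MLP condensing heuristic is omitted) of how a component's condensed depth maps to a layer index.

```python
# Illustrative sketch only: after depths have been condensed, a component at
# depth d (1-based, counting SOps on the longest path from an input) lands in
# layer 2*(d-1) if it is an attention head (even layers are attention) and in
# layer 2*(d-1) + 1 if it is an MLP (odd layers are MLPs). A residual block
# occupies both slots of its depth.

def layer_for(depth: int, is_mlp: bool) -> int:
  """Hypothetical helper mirroring the parity rule; not used by the compiler."""
  return 2 * (depth - 1) + int(is_mlp)

# The first attention/MLP pair:
assert layer_for(1, is_mlp=False) == 0
assert layer_for(1, is_mlp=True) == 1
# The third attention/MLP pair:
assert layer_for(3, is_mlp=False) == 4
assert layer_for(3, is_mlp=True) == 5
```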
compiler/craft_graph_to_model_test.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for compiler.craft_graph_to_model."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import networkx as nx
20
+ from tracr.compiler import craft_graph_to_model
21
+ from tracr.compiler import nodes
22
+ from tracr.compiler import rasp_to_graph
23
+ from tracr.craft import bases
24
+ from tracr.craft.chamber import categorical_attn
25
+ from tracr.craft.chamber import categorical_mlp
26
+ from tracr.rasp import rasp
27
+
28
+
29
+ class CraftAllocateModulesToLayersTest(parameterized.TestCase):
30
+
31
+ def _get_dummy_block(self, block_type):
32
+ if block_type == "ATTN":
33
+ return categorical_attn.categorical_attn(
34
+ query_space=bases.VectorSpaceWithBasis.from_names(["query"]),
35
+ key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]),
36
+ value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]),
37
+ output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
38
+ bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]),
39
+ one_space=bases.VectorSpaceWithBasis.from_names(["one"]),
40
+ attn_fn=lambda x, y: True,
41
+ )
42
+ elif block_type == "MLP":
43
+ return categorical_mlp.map_categorical_mlp(
44
+ input_space=bases.VectorSpaceWithBasis.from_names(["input"]),
45
+ output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
46
+ operation=lambda x: x,
47
+ )
48
+ else:
49
+ return None
50
+
51
+ def test_get_longest_path_length_to_node_returns_expected_result(self):
52
+ """Creates a graph and checks the longest path for each node."""
53
+
54
+ # Node IDs:
55
+ # 0 -- 1 -- 2 -- 3 ------------ 4
56
+ # / /
57
+ # 5 -- 6 ---------- 7 -- 8 -- 9
58
+ #
59
+ # 10
60
+ # Expected return values:
61
+ # 0 -- 1 -- 2 -- 3 ------------ 5
62
+ # / /
63
+ # 0 -- 1 ---------- 2 -- 3 -- 4
64
+ #
65
+ # -1
66
+
67
+ graph = nx.DiGraph()
68
+ node_ids = list(range(11))
69
+ expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1]
70
+ for node_id, res in zip(node_ids, expected_results):
71
+ graph.add_node(
72
+ node_id, **{
73
+ nodes.ID: node_id,
74
+ nodes.EXPR: rasp.ConstantSOp(1),
75
+ "expected_result": res
76
+ })
77
+ graph.add_edge(0, 1)
78
+ graph.add_edge(1, 2)
79
+ graph.add_edge(2, 3)
80
+ graph.add_edge(3, 4)
81
+ graph.add_edge(5, 6)
82
+ graph.add_edge(6, 7)
83
+ graph.add_edge(7, 8)
84
+ graph.add_edge(8, 9)
85
+ graph.add_edge(6, 3)
86
+ graph.add_edge(9, 4)
87
+ sources = [graph.nodes[0], graph.nodes[5]]
88
+
89
+ for node_id, node in graph.nodes.items():
90
+ result = craft_graph_to_model._get_longest_path_length_to_node(
91
+ graph, sources, node)
92
+ self.assertEqual(result, node["expected_result"])
93
+
94
+ def test_allocate_modules_to_layers_returns_expected_result(self):
95
+ """Creates a graph and checks if the correct layer assignment is returned."""
96
+
97
+ # Computation Graph:
98
+ # INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT
99
+ # / / /
100
+ # INPUT -- MLP --- MLP ATTN
101
+ # \ /
102
+ # ATTN
103
+ # Node IDs:
104
+ # 0 -- 1 -- 2 -- 3 -- 4 -- 5
105
+ # / / /
106
+ # 6 -- 7 ---- 8 9
107
+ # \ /
108
+ # 10
109
+ # Expected layer allocation:
110
+ # -1 -- 0 -- 3 -- 4 -- 7 -- -1
111
+ # / / /
112
+ # -1 -- 1 --- 3 6
113
+ # \ /
114
+ # 4
115
+
116
+ graph = nx.DiGraph()
117
+ node_ids = list(range(11))
118
+ types = [
119
+ "INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP", "MLP",
120
+ "ATTN", "ATTN"
121
+ ]
122
+ expected_results = [-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4]
123
+ for node_id, node_type, res in zip(node_ids, types, expected_results):
124
+ graph.add_node(
125
+ node_id, **{
126
+ nodes.ID: node_id,
127
+ nodes.EXPR: rasp.ConstantSOp(1),
128
+ nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
129
+ "expected_result": res
130
+ })
131
+
132
+ graph.add_edge(0, 1)
133
+ graph.add_edge(1, 2)
134
+ graph.add_edge(2, 3)
135
+ graph.add_edge(3, 4)
136
+ graph.add_edge(4, 5)
137
+ graph.add_edge(6, 7)
138
+ graph.add_edge(7, 2)
139
+ graph.add_edge(7, 8)
140
+ graph.add_edge(8, 3)
141
+ graph.add_edge(8, 10)
142
+ graph.add_edge(9, 4)
143
+ graph.add_edge(10, 9)
144
+
145
+ craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
146
+ graph=graph,
147
+ sink=graph.nodes[10],
148
+ sources=[graph.nodes[0], graph.nodes[6]])
149
+
150
+ layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
151
+ craft_graph.graph, craft_graph.sources)
152
+ for node_id, node in graph.nodes.items():
153
+ self.assertEqual(layer_allocation[node_id], node["expected_result"])
154
+
155
+ def test_allocate_modules_to_layers_returns_expected_result_for_chain(self):
156
+ """Tests a chain of alternating attention layers and MLPs."""
157
+
158
+ # Computation Graph:
159
+ # INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT
160
+ # Node IDs:
161
+ # 0 -- 1 -- 2 -- 3 -- 4 -- 5
162
+ # Expected layer allocation:
163
+ # -1 -- 0 -- 1 -- 2 -- 3 -- -1
164
+
165
+ graph = nx.DiGraph()
166
+ node_ids = list(range(6))
167
+ types = ["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"]
168
+ expected_results = [-1, 0, 1, 2, 3, -1]
169
+ for node_id, node_type, res in zip(node_ids, types, expected_results):
170
+ graph.add_node(
171
+ node_id, **{
172
+ nodes.ID: node_id,
173
+ nodes.EXPR: rasp.ConstantSOp(1),
174
+ nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
175
+ "expected_result": res
176
+ })
177
+
178
+ graph.add_edge(0, 1)
179
+ graph.add_edge(1, 2)
180
+ graph.add_edge(2, 3)
181
+ graph.add_edge(3, 4)
182
+ graph.add_edge(4, 5)
183
+
184
+ craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
185
+ graph=graph, sink=graph.nodes[5], sources=[graph.nodes[0]])
186
+
187
+ layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
188
+ craft_graph.graph, craft_graph.sources)
189
+ for node_id, node in graph.nodes.items():
190
+ self.assertEqual(layer_allocation[node_id], node["expected_result"])
191
+
192
+
193
+ if __name__ == "__main__":
194
+ absltest.main()
compiler/craft_model_to_transformer.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Convert craft model into transformer with the correct input/output spaces."""
16
+
17
+ import networkx as nx
18
+ from tracr.compiler import assemble
19
+ from tracr.compiler import nodes
20
+ from tracr.craft import bases
21
+ from tracr.craft import transformers
22
+ from tracr.rasp import rasp
23
+ from tracr.transformer import encoder
24
+
25
+
26
+ def craft_model_to_transformer(
27
+ craft_model: transformers.SeriesWithResiduals,
28
+ graph: nx.DiGraph,
29
+ sink: nodes.Node,
30
+ max_seq_len: int,
31
+ compiler_bos: str,
32
+ compiler_pad: str,
33
+ causal: bool = False,
34
+ ) -> assemble.AssembledTransformerModel:
35
+ """Turn a craft model into a transformer model."""
36
+
37
+ # Add the compiler BOS and PAD tokens.
38
+ tokens_value_set = (
39
+ graph.nodes[rasp.tokens.label][nodes.VALUE_SET].union(
40
+ {compiler_bos, compiler_pad}))
41
+ tokens_space = bases.VectorSpaceWithBasis.from_values(rasp.tokens.label,
42
+ tokens_value_set)
43
+
44
+ indices_space = bases.VectorSpaceWithBasis.from_values(
45
+ rasp.indices.label, range(max_seq_len))
46
+
47
+ categorical_output = rasp.is_categorical(sink[nodes.EXPR])
48
+ output_space = bases.VectorSpaceWithBasis(sink[nodes.OUTPUT_BASIS])
49
+
50
+ assembled_model = assemble.assemble_craft_model(
51
+ craft_model=craft_model,
52
+ tokens_space=tokens_space,
53
+ indices_space=indices_space,
54
+ output_space=output_space,
55
+ categorical_output=categorical_output,
56
+ causal=causal,
57
+ )
58
+
59
+ assembled_model.input_encoder = encoder.CategoricalEncoder(
60
+ basis=tokens_space.basis,
61
+ enforce_bos=compiler_bos is not None,
62
+ bos_token=compiler_bos,
63
+ pad_token=compiler_pad,
64
+ max_seq_len=max_seq_len + 1 if compiler_bos is not None else max_seq_len,
65
+ )
66
+
67
+ if categorical_output:
68
+ assembled_model.output_encoder = encoder.CategoricalEncoder(
69
+ basis=output_space.basis,
70
+ enforce_bos=False,
71
+ bos_token=None,
72
+ pad_token=None)
73
+ else:
74
+ assembled_model.output_encoder = encoder.NumericalEncoder()
75
+
76
+ return assembled_model
compiler/expr_to_craft_graph.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Add craft model blocks to graph of RASPExpr."""
16
+
17
+ from typing import Any, Callable, Optional
18
+
19
+ import networkx as nx
20
+ from tracr.compiler import nodes
21
+ from tracr.craft import bases
22
+ from tracr.craft.chamber import categorical_attn
23
+ from tracr.craft.chamber import categorical_mlp
24
+ from tracr.craft.chamber import numerical_mlp
25
+ from tracr.craft.chamber import selector_width
26
+ from tracr.rasp import rasp
27
+
28
+
29
+ def _transform_fun_to_basis_fun(
30
+ fun: Callable[..., Any],
31
+ output_direction_name: Optional[str] = None) -> Callable[..., Any]:
32
+ """Transforms a function acting on values into one acting on directions."""
33
+
34
+ def bases_fun(*args):
35
+ values = [d.value for d in args]
36
+ result = fun(*values)
37
+ if output_direction_name:
38
+ return bases.BasisDirection(output_direction_name, result)
39
+ return result
40
+
41
+ return bases_fun
42
+
43
+
44
+ def _check_selector_expression(expr, graph):
45
+ """Check graph structure and encodings for an aggregate or selector width."""
46
+ sel_expr = expr.selector
47
+
48
+ # Check graph structure
49
+ assert sel_expr.label in graph.predecessors(expr.label)
50
+ assert sel_expr.keys.label in graph.predecessors(sel_expr.label)
51
+ assert sel_expr.queries.label in graph.predecessors(sel_expr.label)
52
+
53
+ if (not rasp.is_categorical(sel_expr.queries) or
54
+ not rasp.is_categorical(sel_expr.keys)):
55
+ raise ValueError("Selector keys and queries must be categorical.")
56
+
57
+
58
+ def add_craft_components_to_rasp_graph(
59
+ graph: nx.DiGraph,
60
+ bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"),
61
+ one_dir: bases.BasisDirection = bases.BasisDirection("one"),
62
+ causal: bool = False,
63
+ mlp_exactness: float = 100,
64
+ ) -> None:
65
+ """Translates expressions to craft blocks and attaches them to the graph.
66
+
67
+ Sets the `MODEL_BLOCK` attribute for all nodes in `graph`.
68
+
69
+ Args:
70
+ graph: RASP graph with `VALUE_SET` but not `MODEL_BLOCK` attributes.
71
+ bos_dir: Basis direction representing beginning of sequence (bos) token.
72
+ one_dir: Auxiliary basis direction that must contain 1.
73
+ causal: If True, marks attention blocks as causal.
74
+ mlp_exactness: Controls the approximation of the MLP layers.
75
+
76
+ Raises:
77
+ ValueError: On invalid input (if `MODEL_BLOCK` is set already, or
78
+ `VALUE_SET` is not set already)
79
+ NotImplementedError: If the graph contains an unsupported expression.
80
+ """
81
+ one_space = bases.VectorSpaceWithBasis([one_dir])
82
+
83
+ for node_id, node in graph.nodes.items():
84
+ expr = node[nodes.EXPR]
85
+
86
+ if not isinstance(expr, rasp.SOp):
87
+ continue
88
+
89
+ if nodes.MODEL_BLOCK in node and node[nodes.MODEL_BLOCK]:
90
+ raise ValueError("Input graph cannot have model blocks set already.")
91
+ if nodes.VALUE_SET not in node:
92
+ raise ValueError(
93
+ "Craft components can only be added after basis inference.")
94
+
95
+ if expr is rasp.tokens or expr is rasp.indices:
96
+ block = None
97
+ elif isinstance(expr, rasp.Map):
98
+ inner_expr, inner_node = expr.inner, graph.nodes[expr.inner.label]
99
+ assert inner_expr.label in graph.predecessors(node_id)
100
+ input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS])
101
+ output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])
102
+
103
+ if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr):
104
+ basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
105
+ block = categorical_mlp.map_categorical_mlp(
106
+ input_space=input_space,
107
+ output_space=output_space,
108
+ operation=basis_fun)
109
+ elif rasp.is_categorical(inner_expr) and rasp.is_numerical(expr):
110
+ block = categorical_mlp.map_categorical_to_numerical_mlp(
111
+ input_space=input_space,
112
+ output_space=output_space,
113
+ operation=expr.f,
114
+ )
115
+ elif rasp.is_numerical(inner_expr) and rasp.is_categorical(expr):
116
+ block = numerical_mlp.map_numerical_to_categorical_mlp(
117
+ f=expr.f,
118
+ input_space=input_space,
119
+ output_space=output_space,
120
+ input_value_set=inner_node[nodes.VALUE_SET],
121
+ one_space=one_space,
122
+ hidden_name=f"_hidden_{expr.label}_",
123
+ large_number=mlp_exactness)
124
+ elif rasp.is_numerical(inner_expr) and rasp.is_numerical(expr):
125
+ block = numerical_mlp.map_numerical_mlp(
126
+ f=expr.f,
127
+ input_space=input_space,
128
+ output_space=output_space,
129
+ input_value_set=inner_node[nodes.VALUE_SET],
130
+ one_space=one_space,
131
+ hidden_name=f"_hidden_{expr.label}_",
132
+ large_number=mlp_exactness)
133
+ else:
134
+ raise NotImplementedError("Map does not support "
135
+ f"in_type '{inner_expr.type}' and"
136
+ f" out_type '{expr.type}'!")
137
+
138
+ elif isinstance(expr, rasp.SequenceMap):
139
+ fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label]
140
+ snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label]
141
+
142
+ # Check graph structure
143
+ assert fst_expr.label in graph.predecessors(node_id)
144
+ assert snd_expr.label in graph.predecessors(node_id)
145
+
146
+ fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS])
147
+ snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS])
148
+ out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])
149
+
150
+ if (isinstance(expr, rasp.LinearSequenceMap) and
151
+ not all(rasp.is_numerical(x) for x in (fst_expr, snd_expr, expr))):
152
+ raise NotImplementedError("Linear SequenceMap only supports numerical "
153
+ "inputs/outputs.")
154
+ elif (
155
+ not isinstance(expr, rasp.LinearSequenceMap) and
156
+ not all(rasp.is_categorical(x) for x in (fst_expr, snd_expr, expr))):
157
+ raise NotImplementedError("(Non-linear) SequenceMap only supports "
158
+ "categorical inputs/outputs.")
159
+
160
+ if isinstance(expr, rasp.LinearSequenceMap):
161
+ assert len(fst_space.basis) == 1
162
+ assert len(snd_space.basis) == 1
163
+ assert len(out_space.basis) == 1
164
+ block = numerical_mlp.linear_sequence_map_numerical_mlp(
165
+ input1_basis_direction=fst_space.basis[0],
166
+ input2_basis_direction=snd_space.basis[0],
167
+ output_basis_direction=out_space.basis[0],
168
+ input1_factor=expr.fst_fac,
169
+ input2_factor=expr.snd_fac,
170
+ hidden_name=f"_hidden_{expr.label}_")
171
+ elif fst_space == snd_space:
172
+ # It's okay to use the local variable expr.f because it is
173
+ # only used within the same loop iteration to create the MLP.
174
+ # pylint: disable=cell-var-from-loop
175
+ basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x),
176
+ expr.label)
177
+ block = categorical_mlp.map_categorical_mlp(
178
+ input_space=fst_space, output_space=out_space, operation=basis_fun)
179
+ else:
180
+ basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
181
+ block = categorical_mlp.sequence_map_categorical_mlp(
182
+ input1_space=fst_space,
183
+ input2_space=snd_space,
184
+ output_space=out_space,
185
+ operation=basis_fun,
186
+ one_space=one_space,
187
+ hidden_name=f"_hidden_{expr.label}_")
188
+ elif isinstance(expr, rasp.Aggregate):
189
+ sel_expr: rasp.Select = expr.selector
190
+ agg_expr: rasp.Aggregate = expr
191
+
192
+ if not isinstance(sel_expr, rasp.Select):
193
+ raise TypeError("Compiling composite Selectors is not supported. "
194
+ f"Got a {sel_expr}.")
195
+
196
+ queries = graph.nodes[sel_expr.queries.label]
197
+ keys = graph.nodes[sel_expr.keys.label]
198
+ sop = graph.nodes[agg_expr.sop.label]
199
+
200
+ _check_selector_expression(expr, graph)
201
+ assert agg_expr.sop.label in graph.predecessors(node_id)
202
+ if rasp.get_encoding(agg_expr.sop) != rasp.get_encoding(agg_expr):
203
+ raise ValueError(
204
+ "sop encoding must match output encoding of the aggregate.")
205
+ if rasp.is_categorical(agg_expr) and agg_expr.default is not None:
206
+ raise ValueError("Default for a categorical aggregate must be None. "
207
+ f"Got {agg_expr.default}")
208
+ if rasp.is_numerical(agg_expr) and agg_expr.default != 0:
209
+ raise ValueError("Default for a numerical aggregate must be 0. "
210
+ f"Got {agg_expr.default}")
211
+
212
+ bos_space = bases.VectorSpaceWithBasis([bos_dir])
213
+ one_space = bases.VectorSpaceWithBasis([one_dir])
214
+ query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
215
+ key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
216
+ value_space = bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS])
217
+ output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])
218
+
219
+ # Argument order is different in craft / transformers than RASP selectors
220
+ def attn_basis_fn(query: bases.BasisDirection,
221
+ key: bases.BasisDirection) -> bool:
222
+ # It's okay to use the local variable sel_expr because this function is
223
+ # only used within the same loop iteration to create an attention head.
224
+ # pylint: disable=cell-var-from-loop
225
+ selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate)
226
+ return selector_basis_fn(key, query)
227
+
228
+ block = categorical_attn.categorical_attn(
229
+ query_space=query_space,
230
+ key_space=key_space,
231
+ value_space=value_space,
232
+ output_space=output_space,
233
+ bos_space=bos_space,
234
+ one_space=one_space,
235
+ attn_fn=attn_basis_fn,
236
+ default_output=output_space.null_vector(),
237
+ causal=causal,
238
+ always_attend_to_bos=False,
239
+ use_bos_for_default_output=True,
240
+ softmax_coldness=100)
241
+ elif isinstance(expr, rasp.SelectorWidth):
242
+ sel_expr = expr.selector
243
+ queries = graph.nodes[sel_expr.queries.label]
244
+ keys = graph.nodes[sel_expr.keys.label]
245
+ _check_selector_expression(expr, graph)
246
+
247
+ bos_space = bases.VectorSpaceWithBasis([bos_dir])
248
+ query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
249
+ key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
250
+ output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])
251
+
252
+ # Argument order is different in craft / transformers than RASP selectors
253
+ def attn_basis_fn(query: bases.BasisDirection,
254
+ key: bases.BasisDirection) -> bool:
255
+ # It's okay to use the local variable sel_expr because this function is
256
+ # only used within the same loop iteration to create an attention head.
257
+ selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate) # pylint: disable=cell-var-from-loop
258
+ return selector_basis_fn(key, query)
259
+
260
+ block = selector_width.selector_width(
261
+ query_space=query_space,
262
+ key_space=key_space,
263
+ output_space=output_space,
264
+ bos_space=bos_space,
265
+ one_space=one_space,
266
+ attn_fn=attn_basis_fn,
267
+ out_value_set=node[nodes.VALUE_SET],
268
+ categorical_output=rasp.is_categorical(expr),
269
+ causal=False,
270
+ softmax_coldness=100,
271
+ mlp_large_number=mlp_exactness,
272
+ label=expr.label)
273
+ else:
274
+ raise NotImplementedError(f"Expression {expr} cannot be translated to "
275
+ "a model component.")
276
+
277
+ graph.nodes[node_id][nodes.MODEL_BLOCK] = block
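
The `_transform_fun_to_basis_fun` helper above is the glue between RASP-level functions on values and craft-level functions on basis directions. The standalone sketch below illustrates the idea; `Direction` and `lift_to_directions` are hypothetical stand-ins for `bases.BasisDirection` and the private helper (only the `.value` attribute is mirrored), so the snippet runs without the rest of the package.

```python
# Standalone illustration of the idea behind _transform_fun_to_basis_fun.
# `Direction` is a hypothetical stand-in for bases.BasisDirection.
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass(frozen=True)
class Direction:
  name: str
  value: Any = None


def lift_to_directions(
    fun: Callable[..., Any],
    output_name: Optional[str] = None) -> Callable[..., Any]:
  """Wraps `fun` so it reads `.value` from its arguments, optionally re-wrapping."""

  def lifted(*args):
    result = fun(*(d.value for d in args))
    return Direction(output_name, result) if output_name else result

  return lifted


# A categorical Map such as rasp.Map(lambda x: x + 1, rasp.tokens) becomes a
# lookup on directions: ("tokens", 2) -> ("map_out", 3).
plus_one = lift_to_directions(lambda x: x + 1, output_name="map_out")
assert plus_one(Direction("tokens", 2)) == Direction("map_out", 3)

# A selector predicate stays a plain bool computed on the direction values.
eq = lift_to_directions(lambda key, query: key == query)
assert eq(Direction("tokens", "a"), Direction("tokens", "a"))
```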
compiler/expr_to_craft_graph_test.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for compiler.expr_to_craft_graph."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ from tracr.compiler import basis_inference
20
+ from tracr.compiler import expr_to_craft_graph
21
+ from tracr.compiler import lib
22
+ from tracr.compiler import nodes
23
+ from tracr.compiler import rasp_to_graph
24
+ from tracr.craft import bases
25
+ from tracr.craft import transformers
26
+ from tracr.rasp import rasp
27
+
28
+
29
+ class ExprToCraftGraphTest(parameterized.TestCase):
30
+
31
+ def _check_block_types_are_correct(self, graph):
32
+ for _, node in graph.nodes.items():
33
+ expr = node[nodes.EXPR]
34
+ if isinstance(expr, rasp.SOp):
35
+ block = node[nodes.MODEL_BLOCK]
36
+ if isinstance(expr, (rasp.Map, rasp.SequenceMap)):
37
+ self.assertIsInstance(block, transformers.MLP)
38
+ elif isinstance(expr, rasp.Aggregate):
39
+ self.assertIsInstance(block, transformers.AttentionHead)
40
+
41
+ def _get_input_space_from_node(self, node):
42
+ block = node[nodes.MODEL_BLOCK]
43
+ if isinstance(block, transformers.MLP):
44
+ return block.fst.input_space
45
+ elif isinstance(block, transformers.AttentionHead):
46
+ return bases.join_vector_spaces(block.w_qk.left_space,
47
+ block.w_qk.right_space,
48
+ block.w_ov.input_space)
49
+ else:
50
+ return None
51
+
52
+ def _check_spaces_are_consistent(self, graph):
53
+ """Check that for each edge the output is a subspace of the input."""
54
+ for u, v in graph.edges:
55
+ u_node, v_node = graph.nodes[u], graph.nodes[v]
56
+ if isinstance(u_node[nodes.EXPR], rasp.SOp) and isinstance(
57
+ v_node[nodes.EXPR], rasp.SOp):
58
+ u_out_basis = u_node[nodes.OUTPUT_BASIS]
59
+ u_out_space = bases.VectorSpaceWithBasis(u_out_basis)
60
+ v_in_space = self._get_input_space_from_node(v_node)
61
+ self.assertTrue(u_out_space.issubspace(v_in_space))
62
+
63
+ @parameterized.named_parameters(
64
+ dict(
65
+ testcase_name="single_map",
66
+ program=rasp.Map(lambda x: x + 1, rasp.tokens)),
67
+ dict(
68
+ testcase_name="single_sequence_map",
69
+ program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
70
+ rasp.indices)),
71
+ dict(
72
+ testcase_name="single_select_aggregate",
73
+ program=rasp.Aggregate(
74
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
75
+ rasp.tokens,
76
+ )),
77
+ dict(testcase_name="reverse", program=lib.make_reverse(rasp.tokens)),
78
+ dict(testcase_name="length", program=lib.make_length()))
79
+ def test_compiling_rasp_programs(self, program):
80
+ vocab = {0, 1, 2}
81
+ extracted = rasp_to_graph.extract_rasp_graph(program)
82
+ basis_inference.infer_bases(
83
+ extracted.graph,
84
+ extracted.sink,
85
+ vocab,
86
+ max_seq_len=3,
87
+ )
88
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
89
+ self._check_block_types_are_correct(extracted.graph)
90
+ self._check_spaces_are_consistent(extracted.graph)
91
+
92
+ def test_add_craft_components_raises_value_error_if_called_before_basis_inference(
93
+ self):
94
+ program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
95
+ extracted = rasp_to_graph.extract_rasp_graph(program)
96
+
97
+ with self.assertRaisesRegex(
98
+ ValueError,
99
+ r"^.*Craft components can only be added after basis inference.*$"):
100
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
101
+
102
+ def test_add_craft_components_raises_value_error_if_called_twice(self):
103
+ vocab = {0, 1, 2}
104
+ program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
105
+ extracted = rasp_to_graph.extract_rasp_graph(program)
106
+
107
+ basis_inference.infer_bases(
108
+ extracted.graph,
109
+ extracted.sink,
110
+ vocab,
111
+ max_seq_len=1,
112
+ )
113
+
114
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
115
+ with self.assertRaisesRegex(
116
+ ValueError, r"^.*Input graph cannot have model blocks set already.*$"):
117
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
118
+
119
+
120
+ if __name__ == "__main__":
121
+ absltest.main()
compiler/lib.py ADDED
@@ -0,0 +1,371 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """RASP programs only using the subset of RASP supported by the compiler."""
16
+
17
+ from typing import Sequence
18
+
19
+ from tracr.rasp import rasp
20
+
21
+ ### Programs that work only under non-causal evaluation.
22
+
23
+
24
+ def make_length() -> rasp.SOp:
25
+ """Creates the `length` SOp using the selector width primitive.
26
+
27
+ Example usage:
28
+ length = make_length()
29
+ length("abcdefg")
30
+ >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]
31
+
32
+ Returns:
33
+ length: SOp mapping an input to a sequence, where every element
34
+ is the length of that sequence.
35
+ """
36
+ all_true_selector = rasp.Select(
37
+ rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector")
38
+ return rasp.SelectorWidth(all_true_selector).named("length")
39
+
40
+
41
+ length = make_length()
42
+
43
+
44
+ def make_reverse(sop: rasp.SOp) -> rasp.SOp:
45
+ """Create an SOp that reverses a sequence, using length primitive.
46
+
47
+ Example usage:
48
+ reverse = make_reverse(rasp.tokens)
49
+ reverse("Hello")
50
+ >> ['o', 'l', 'l', 'e', 'H']
51
+
52
+ Args:
53
+ sop: an SOp
54
+
55
+ Returns:
56
+ reverse : SOp that reverses the input sequence.
57
+ """
58
+ opp_idx = (length - rasp.indices).named("opp_idx")
59
+ opp_idx = (opp_idx - 1).named("opp_idx-1")
60
+ reverse_selector = rasp.Select(rasp.indices, opp_idx,
61
+ rasp.Comparison.EQ).named("reverse_selector")
62
+ return rasp.Aggregate(reverse_selector, sop).named("reverse")
63
+
64
+
65
+ def make_pair_balance(sop: rasp.SOp, open_token: str,
66
+ close_token: str) -> rasp.SOp:
67
+ """Return fraction of previous open tokens minus the fraction of close tokens.
68
+
69
+ (As implemented in the RASP paper.)
70
+
71
+ If the outputs are always non-negative and end in 0, that implies the input
72
+ has balanced parentheses.
73
+
74
+ Example usage:
75
+ num_l = make_pair_balance(rasp.tokens, "(", ")")
76
+ num_l("a()b(c))")
77
+ >> [0, 1/2, 0, 0, 1/5, 1/6, 0, -1/8]
78
+
79
+ Args:
80
+ sop: Input SOp.
81
+ open_token: Token that counts positive.
82
+ close_token: Token that counts negative.
83
+
84
+ Returns:
85
+ pair_balance: SOp mapping an input to a sequence, where every element
86
+ is the fraction of previous open tokens minus previous close tokens.
87
+ """
88
+ bools_open = rasp.numerical(sop == open_token).named("bools_open")
89
+ opens = rasp.numerical(make_frac_prevs(bools_open)).named("opens")
90
+
91
+ bools_close = rasp.numerical(sop == close_token).named("bools_close")
92
+ closes = rasp.numerical(make_frac_prevs(bools_close)).named("closes")
93
+
94
+ pair_balance = rasp.numerical(rasp.LinearSequenceMap(opens, closes, 1, -1))
95
+ return pair_balance.named("pair_balance")
96
+
97
+
98
+ def make_shuffle_dyck(pairs: list[str]) -> rasp.SOp:
99
+ """Returns 1 if all parenthesis pairs are balanced, 0 otherwise.
100
+
101
+ (As implemented in the RASP paper.)
102
+
103
+ Example usage:
104
+ shuffle_dyck2 = make_shuffle_dyck(pairs=["()", "{}"])
105
+ shuffle_dyck2("({)}")
106
+ >> [1, 1, 1, 1]
107
+ shuffle_dyck2("(){)}")
108
+ >> [0, 0, 0, 0, 0]
109
+
110
+ Args:
111
+ pairs: List of pairs of open and close tokens that each should be balanced.
112
+ """
113
+ assert len(pairs) >= 1
114
+
115
+ # Compute running balance of each type of parenthesis
116
+ balances = []
117
+ for pair in pairs:
118
+ assert len(pair) == 2
119
+ open_token, close_token = pair
120
+ balance = make_pair_balance(
121
+ rasp.tokens, open_token=open_token,
122
+ close_token=close_token).named(f"balance_{pair}")
123
+ balances.append(balance)
124
+
125
+ # Check if balances were negative anywhere -> parentheses not balanced
126
+ any_negative = balances[0] < 0
127
+ for balance in balances[1:]:
128
+ any_negative = any_negative | (balance < 0)
129
+
130
+ # Convert to numerical SOp
131
+ any_negative = rasp.numerical(rasp.Map(lambda x: x,
132
+ any_negative)).named("any_negative")
133
+
134
+ select_all = rasp.Select(rasp.indices, rasp.indices,
135
+ rasp.Comparison.TRUE).named("select_all")
136
+ has_neg = rasp.numerical(rasp.Aggregate(select_all, any_negative,
137
+ default=0)).named("has_neg")
138
+
139
+ # Check if all balances are 0 at the end -> closed all parentheses
140
+ all_zero = balances[0] == 0
141
+ for balance in balances[1:]:
142
+ all_zero = all_zero & (balance == 0)
143
+
144
+ select_last = rasp.Select(rasp.indices, length - 1,
145
+ rasp.Comparison.EQ).named("select_last")
146
+ last_zero = rasp.Aggregate(select_last, all_zero).named("last_zero")
147
+
148
+ not_has_neg = (~has_neg).named("not_has_neg")
149
+ return (last_zero & not_has_neg).named("shuffle_dyck")
150
+
151
+
152
+ def make_shuffle_dyck2() -> rasp.SOp:
153
+ return make_shuffle_dyck(pairs=["()", "{}"]).named("shuffle_dyck2")
154
+
155
+
156
+ def make_hist() -> rasp.SOp:
157
+ """Returns the number of times each token occurs in the input.
158
+
159
+ (As implemented in the RASP paper.)
160
+
161
+ Example usage:
162
+ hist = make_hist()
163
+ hist("abac")
164
+ >> [2, 1, 2, 1]
165
+ """
166
+ same_tok = rasp.Select(rasp.tokens, rasp.tokens,
167
+ rasp.Comparison.EQ).named("same_tok")
168
+ return rasp.SelectorWidth(same_tok).named("hist")
169
+
170
+
171
+ def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
172
+ """Returns vals sorted by < relation on keys.
173
+
174
+ Only supports unique keys.
175
+
176
+ Example usage:
177
+ sort = make_sort_unique(rasp.tokens, rasp.tokens)
178
+ sort([2, 4, 3, 1])
179
+ >> [1, 2, 3, 4]
180
+
181
+ Args:
182
+ vals: Values to sort.
183
+ keys: Keys for sorting.
184
+ """
185
+ smaller = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
186
+ target_pos = rasp.SelectorWidth(smaller).named("target_pos")
187
+ sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ)
188
+ return rasp.Aggregate(sel_new, vals).named("sort")
189
+
190
+
191
+ def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
192
+ min_key: float) -> rasp.SOp:
193
+ """Returns vals sorted by < relation on keys, which don't need to be unique.
194
+
195
+ The implementation differs from the RASP paper, as it avoids using
196
+ compositions of selectors to break ties. Instead, it uses the arguments
197
+ max_seq_len and min_key to ensure the keys are unique.
198
+
199
+ Note that this approach only works for numerical keys.
200
+
201
+ Example usage:
202
+ sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
203
+ sort([2, 4, 3, 1])
204
+ >> [1, 2, 3, 4]
205
+ sort([2, 4, 1, 2])
206
+ >> [1, 2, 2, 4]
207
+
208
+ Args:
209
+ vals: Values to sort.
210
+ keys: Keys for sorting.
211
+ max_seq_len: Maximum sequence length (used to ensure keys are unique)
212
+ min_key: Minimum key value (used to ensure keys are unique)
213
+
214
+ Returns:
215
+ Output SOp of sort program.
216
+ """
217
+ keys = rasp.SequenceMap(lambda x, i: x + min_key * i / max_seq_len, keys,
218
+ rasp.indices)
219
+ return make_sort_unique(vals, keys)
220
+
221
+
222
+ def make_sort_freq(max_seq_len: int) -> rasp.SOp:
223
+ """Returns tokens sorted by the frequency they appear in the input.
224
+
225
+ Tokens that appear the same number of times are output in the same order as in
226
+ the input.
227
+
228
+ Example usage:
229
+ sort = make_sort_freq(max_seq_len=5)
230
+ sort([2, 4, 2, 1])
231
+ >> [2, 2, 4, 1]
232
+
233
+ Args:
234
+ max_seq_len: Maximum sequence length (used to ensure keys are unique)
235
+ """
236
+ hist = -1 * make_hist().named("hist")
237
+ return make_sort(
238
+ rasp.tokens, hist, max_seq_len=max_seq_len, min_key=1).named("sort_freq")
239
+
240
+
241
+ ### Programs that work under both causal and regular evaluation.
242
+
243
+
244
+ def make_frac_prevs(bools: rasp.SOp) -> rasp.SOp:
245
+ """Count the fraction of previous tokens where a specific condition was True.
246
+
247
+ (As implemented in the RASP paper.)
248
+
249
+ Example usage:
250
+ num_l = make_frac_prevs(rasp.tokens=="l")
251
+ num_l("hello")
252
+ >> [0, 0, 1/3, 1/2, 2/5]
253
+
254
+ Args:
255
+ bools: SOp mapping a sequence to a sequence of booleans.
256
+
257
+ Returns:
258
+ frac_prevs: SOp mapping an input to a sequence, where every element
259
+ is the fraction of previous "True" tokens.
260
+ """
261
+ bools = rasp.numerical(bools)
262
+ prevs = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
263
+ return rasp.numerical(rasp.Aggregate(prevs, bools,
264
+ default=0)).named("frac_prevs")
265
+
266
+
267
+ def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
268
+ """Returns the sop, shifted by `offset`, None-padded."""
269
+ select_off_by_offset = rasp.Select(rasp.indices, rasp.indices,
270
+ lambda k, q: q == k + offset)
271
+ out = rasp.Aggregate(select_off_by_offset, sop, default=None)
272
+ return out.named(f"shift_by({offset})")
273
+
274
+
275
+ def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp:
276
+ """Returns an SOp which is True at the final element of the pattern.
277
+
278
+ The first len(pattern) - 1 elements of the output SOp are None-padded.
279
+
280
+ detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]
281
+
282
+ Args:
283
+ sop: the SOp in which to look for patterns.
284
+ pattern: a sequence of values to look for.
285
+
286
+ Returns:
287
+ a sop which detects the pattern.
288
+ """
289
+
290
+ if len(pattern) < 1:
291
+ raise ValueError(f"Length of `pattern` must be at least 1. Got {pattern}")
292
+
293
+ # detectors[i] will be a boolean-valued SOp which is true at position j iff
294
+ # the i'th (from the end) element of the pattern was detected at position j-i.
295
+ detectors = []
296
+ for i, element in enumerate(reversed(pattern)):
297
+ detector = sop == element
298
+ if i != 0:
299
+ detector = shift_by(i, detector)
300
+ detectors.append(detector)
301
+
302
+ # All that's left is to take the AND over all detectors.
303
+ pattern_detected = detectors.pop()
304
+ while detectors:
305
+ pattern_detected = pattern_detected & detectors.pop()
306
+
307
+ return pattern_detected.named(f"detect_pattern({pattern})")
308
+
309
+
310
+ def make_count_less_freq(n: int) -> rasp.SOp:
311
+ """Returns how many tokens appear at most n times in the input.
312
+
313
+ The output sequence contains this count in each position.
314
+
315
+ Example usage:
316
+ count_less_freq = make_count_less_freq(2)
317
+ count_less_freq(["a", "a", "a", "b", "b", "c"])
318
+ >> [3, 3, 3, 3, 3, 3]
319
+ count_less_freq(["a", "a", "c", "b", "b", "c"])
320
+ >> [6, 6, 6, 6, 6, 6]
321
+
322
+ Args:
323
+ n: Integer to compare token frequencies to.
324
+ """
325
+ hist = make_hist().named("hist")
326
+ select_less = rasp.Select(hist, hist,
327
+ lambda x, y: x <= n).named("select_less")
328
+ return rasp.SelectorWidth(select_less).named("count_less_freq")
329
+
330
+
331
+ def make_count(sop, token):
332
+ """Returns the count of `token` in `sop`.
333
+
334
+ The output sequence contains this count in each position.
335
+
336
+ Example usage:
337
+ count = make_count(tokens, "a")
338
+ count(["a", "a", "a", "b", "b", "c"])
339
+ >> [3, 3, 3, 3, 3, 3]
340
+ count(["c", "a", "b", "c"])
341
+ >> [1, 1, 1, 1]
342
+
343
+ Args:
344
+ sop: Sop to count tokens in.
345
+ token: Token to count.
346
+ """
347
+ return rasp.SelectorWidth(rasp.Select(
348
+ sop, sop, lambda k, q: k == token)).named(f"count_{token}")
349
+
350
+
351
+ def make_nary_sequencemap(f, *sops):
352
+ """Returns an SOp that simulates an n-ary SequenceMap.
353
+
354
+ Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n
355
+ into a single SOp that takes n-tuples as values. The n-ary sequence
356
+ map implementing f is then a Map on this resulting SOp.
357
+
358
+ Note that the intermediate variables representing tuples of varying length
359
+ will be encoded categorically, and can become very high-dimensional. So,
360
+ using this function might lead to very large compiled models.
361
+
362
+ Args:
363
+ f: Function with n arguments.
364
+ *sops: Sequence of SOps, one for each argument of f.
365
+ """
366
+ values, *sops = sops
367
+ for sop in sops:
368
+ # x is a single entry in the first iteration but a tuple in later iterations
369
+ values = rasp.SequenceMap(
370
+ lambda x, y: (*x, y) if isinstance(x, tuple) else (x, y), values, sop)
371
+ return rasp.Map(lambda args: f(*args), values)
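
`make_nary_sequencemap` is the only helper in this file without a usage example, so here is a hedged sketch. It follows the convention used by the other docstrings of applying an SOp directly to an input sequence; the expected output is worked out by hand rather than taken from a test.

```python
# Usage sketch for make_nary_sequencemap (expected values worked out by hand).
from tracr.compiler import lib
from tracr.rasp import rasp

# Ternary function applied position-wise to three SOps.
add3 = lib.make_nary_sequencemap(
    lambda x, y, z: x + y + z, rasp.tokens, rasp.indices, rasp.indices)

# tokens = [1, 2, 3], indices = [0, 1, 2]
# -> [1 + 0 + 0, 2 + 1 + 1, 3 + 2 + 2]
print(add3([1, 2, 3]))  # Expected: [1, 4, 7]
```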
compiler/lib_test.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for compiler.lib."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ from tracr.compiler import test_cases
20
+ from tracr.rasp import causal_eval
21
+ from tracr.rasp import rasp
22
+
23
+
24
+ class LibTest(parameterized.TestCase):
25
+
26
+ @parameterized.named_parameters(*test_cases.TEST_CASES)
27
+ def test_program_produces_expected_output(self, program, test_input,
28
+ expected_output, **kwargs):
29
+ del kwargs
30
+ self.assertEqual(rasp.evaluate(program, test_input), expected_output)
31
+
32
+ @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES)
33
+ def test_causal_program_produces_expected_output(self, program, test_input,
34
+ expected_output, **kwargs):
35
+ del kwargs
36
+ self.assertEqual(causal_eval.evaluate(program, test_input), expected_output)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ absltest.main()
compiler/nodes.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Documents the data stored in nodes after each compiler pass."""
16
+
17
+ from typing import Any
18
+
19
+ Node = dict[str, Any]
20
+ NodeID = str
21
+
22
+ # RASP -> Graph
23
+ ID = "ID" # unique ID of the node
24
+ EXPR = "EXPR" # the RASPExpr of the node
25
+
26
+ # Basis inference
27
+ # Note that only S-Op expressions will have these keys set.
28
+ VALUE_SET = "VALUE_SET" # possible values taken on by this SOp.
29
+ OUTPUT_BASIS = "OUTPUT_BASIS" # the corresponding named basis.
30
+
31
+ # RASP Graph -> Craft Graph
32
+ MODEL_BLOCK = "MODEL_BLOCK" # craft block representing a RASPExpr
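
Since this module only names the node attributes, a sketch of what a node dict might look like once all compiler passes have run can help. The concrete values below are invented placeholders (in practice `EXPR` is a `RASPExpr`, `OUTPUT_BASIS` is a list of `bases.BasisDirection`, and `MODEL_BLOCK` is a craft block); only the keys are real.

```python
# Illustrative node dict after all compiler passes; values are placeholders.
example_node = {
    # RASP -> Graph
    "ID": "map_1",                      # nodes.ID
    "EXPR": "<rasp.Map instance>",      # nodes.EXPR
    # Basis inference (only set for S-Op expressions)
    "VALUE_SET": {2, 3, 4},             # nodes.VALUE_SET
    "OUTPUT_BASIS": ["map_1: 2", "map_1: 3", "map_1: 4"],  # nodes.OUTPUT_BASIS
    # RASP Graph -> Craft Graph
    "MODEL_BLOCK": "<craft MLP block>",  # nodes.MODEL_BLOCK
}
```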
compiler/rasp_to_craft_integration_test.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Integration tests for the RASP -> craft stages of the compiler."""
16
+
17
+ import unittest
18
+
19
+ from absl.testing import absltest
20
+ from absl.testing import parameterized
21
+ import numpy as np
22
+ from tracr.compiler import basis_inference
23
+ from tracr.compiler import craft_graph_to_model
24
+ from tracr.compiler import expr_to_craft_graph
25
+ from tracr.compiler import nodes
26
+ from tracr.compiler import rasp_to_graph
27
+ from tracr.compiler import test_cases
28
+ from tracr.craft import bases
29
+ from tracr.craft import tests_common
30
+ from tracr.rasp import rasp
31
+
32
+ _BOS_DIRECTION = "rasp_to_craft_integration_test_BOS"
33
+ _ONE_DIRECTION = "rasp_to_craft_integration_test_ONE"
34
+
35
+
36
+ def _make_input_space(vocab, max_seq_len):
37
+ tokens_space = bases.VectorSpaceWithBasis.from_values("tokens", vocab)
38
+ indices_space = bases.VectorSpaceWithBasis.from_values(
39
+ "indices", range(max_seq_len))
40
+ one_space = bases.VectorSpaceWithBasis.from_names([_ONE_DIRECTION])
41
+ bos_space = bases.VectorSpaceWithBasis.from_names([_BOS_DIRECTION])
42
+ input_space = bases.join_vector_spaces(tokens_space, indices_space, one_space,
43
+ bos_space)
44
+
45
+ return input_space
46
+
47
+
48
+ def _embed_input(input_seq, input_space):
49
+ bos_vec = input_space.vector_from_basis_direction(
50
+ bases.BasisDirection(_BOS_DIRECTION))
51
+ one_vec = input_space.vector_from_basis_direction(
52
+ bases.BasisDirection(_ONE_DIRECTION))
53
+ embedded_input = [bos_vec + one_vec]
54
+ for i, val in enumerate(input_seq):
55
+ i_vec = input_space.vector_from_basis_direction(
56
+ bases.BasisDirection("indices", i))
57
+ val_vec = input_space.vector_from_basis_direction(
58
+ bases.BasisDirection("tokens", val))
59
+ embedded_input.append(i_vec + val_vec + one_vec)
60
+ return bases.VectorInBasis.stack(embedded_input)
61
+
62
+
63
+ def _embed_output(output_seq, output_space, categorical_output):
64
+ embedded_output = []
65
+ output_label = output_space.basis[0].name
66
+ for x in output_seq:
67
+ if x is None:
68
+ out_vec = output_space.null_vector()
69
+ elif categorical_output:
70
+ out_vec = output_space.vector_from_basis_direction(
71
+ bases.BasisDirection(output_label, x))
72
+ else:
73
+ out_vec = x * output_space.vector_from_basis_direction(
74
+ output_space.basis[0])
75
+ embedded_output.append(out_vec)
76
+ return bases.VectorInBasis.stack(embedded_output)
77
+
78
+
79
+ class CompilerIntegrationTest(tests_common.VectorFnTestCase):
80
+
81
+ @parameterized.named_parameters(
82
+ dict(
83
+ testcase_name="map",
84
+ program=rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))),
85
+ dict(
86
+ testcase_name="sequence_map",
87
+ program=rasp.categorical(
88
+ rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.indices))),
89
+ dict(
90
+ testcase_name="sequence_map_with_same_input",
91
+ program=rasp.categorical(
92
+ rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens))),
93
+ dict(
94
+ testcase_name="select_aggregate",
95
+ program=rasp.Aggregate(
96
+ rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
97
+ rasp.Map(lambda x: 1, rasp.tokens))))
98
+ def test_rasp_program_and_craft_model_produce_same_output(self, program):
99
+ vocab = {0, 1, 2}
100
+ max_seq_len = 3
101
+
102
+ extracted = rasp_to_graph.extract_rasp_graph(program)
103
+ basis_inference.infer_bases(
104
+ extracted.graph,
105
+ extracted.sink,
106
+ vocab,
107
+ max_seq_len=max_seq_len,
108
+ )
109
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(
110
+ extracted.graph,
111
+ bos_dir=bases.BasisDirection(_BOS_DIRECTION),
112
+ one_dir=bases.BasisDirection(_ONE_DIRECTION),
113
+ )
114
+ model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
115
+ extracted.sources)
116
+ input_space = _make_input_space(vocab, max_seq_len)
117
+ output_space = bases.VectorSpaceWithBasis(
118
+ extracted.sink[nodes.OUTPUT_BASIS])
119
+
120
+ for val in vocab:
121
+ test_input = _embed_input([val], input_space)
122
+ rasp_output = program([val])
123
+ expected_output = _embed_output(
124
+ output_seq=rasp_output,
125
+ output_space=output_space,
126
+ categorical_output=True)
127
+ test_output = model.apply(test_input).project(output_space)
128
+ self.assertVectorAllClose(
129
+ tests_common.strip_bos_token(test_output), expected_output)
130
+
131
+ @parameterized.named_parameters(*test_cases.TEST_CASES)
132
+ def test_compiled_models_produce_expected_output(self, program, vocab,
133
+ test_input, expected_output,
134
+ max_seq_len, **kwargs):
135
+ del kwargs
136
+ categorical_output = rasp.is_categorical(program)
137
+
138
+ extracted = rasp_to_graph.extract_rasp_graph(program)
139
+ basis_inference.infer_bases(
140
+ extracted.graph,
141
+ extracted.sink,
142
+ vocab,
143
+ max_seq_len=max_seq_len,
144
+ )
145
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(
146
+ extracted.graph,
147
+ bos_dir=bases.BasisDirection(_BOS_DIRECTION),
148
+ one_dir=bases.BasisDirection(_ONE_DIRECTION),
149
+ )
150
+ model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
151
+ extracted.sources)
152
+ input_space = _make_input_space(vocab, max_seq_len)
153
+ output_space = bases.VectorSpaceWithBasis(
154
+ extracted.sink[nodes.OUTPUT_BASIS])
155
+ if not categorical_output:
156
+ self.assertLen(output_space.basis, 1)
157
+
158
+ test_input_vector = _embed_input(test_input, input_space)
159
+ expected_output_vector = _embed_output(
160
+ output_seq=expected_output,
161
+ output_space=output_space,
162
+ categorical_output=categorical_output)
163
+ test_output = model.apply(test_input_vector).project(output_space)
164
+ self.assertVectorAllClose(
165
+ tests_common.strip_bos_token(test_output), expected_output_vector)
166
+
167
+ @unittest.expectedFailure
168
+ def test_setting_default_values_can_lead_to_wrong_outputs_in_compiled_model(
169
+ self):
170
+ # This is an example program in which setting a default value for aggregate
171
+ # writes a value to the bos token position, which interferes with a later
172
+ # aggregate operation, causing the compiled model to have the wrong output.
173
+
174
+ vocab = {"a", "b"}
175
+ test_input = ["a"]
176
+ max_seq_len = 2
177
+
178
+ # RASP: [False, True]
179
+ # compiled: [False, False, True]
180
+ not_a = rasp.Map(lambda x: x != "a", rasp.tokens)
181
+
182
+ # RASP:
183
+ # [[True, False],
184
+ # [False, False]]
185
+ # compiled:
186
+ # [[False,True, False],
187
+ # [True, False, False]]
188
+ sel1 = rasp.Select(rasp.tokens, rasp.tokens,
189
+ lambda k, q: k == "a" and q == "a")
190
+
191
+ # RASP: [False, True]
192
+ # compiled: [True, False, True]
193
+ agg1 = rasp.Aggregate(sel1, not_a, default=True)
194
+
195
+ # RASP:
196
+ # [[False, True]
197
+ # [True, True]]
198
+ # compiled:
199
+ # [[True, False, False]
200
+ # [True, False, False]]
201
+ # because pre-softmax we get
202
+ # [[1.5, 1, 1]
203
+ # [1.5, 1, 1]]
204
+ # instead of
205
+ # [[0.5, 1, 1]
206
+ # [0.5, 1, 1]]
207
+ # Because agg1 = True is stored on the BOS token position
208
+ sel2 = rasp.Select(agg1, agg1, lambda k, q: k or q)
209
+
210
+ # RASP: [1, 0.5]
211
+ # compiled:
212
+ # [1, 1, 1]
213
+ program = rasp.numerical(
214
+ rasp.Aggregate(sel2, rasp.numerical(not_a), default=1))
215
+ expected_output = [1, 0.5]
216
+
217
+ # RASP program gives the correct output
218
+ program_output = program(test_input)
219
+ np.testing.assert_allclose(program_output, expected_output)
220
+
221
+ extracted = rasp_to_graph.extract_rasp_graph(program)
222
+ basis_inference.infer_bases(
223
+ extracted.graph,
224
+ extracted.sink,
225
+ vocab,
226
+ max_seq_len=max_seq_len,
227
+ )
228
+ expr_to_craft_graph.add_craft_components_to_rasp_graph(
229
+ extracted.graph,
230
+ bos_dir=bases.BasisDirection(_BOS_DIRECTION),
231
+ one_dir=bases.BasisDirection(_ONE_DIRECTION),
232
+ )
233
+ model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
234
+ extracted.sources)
235
+
236
+ input_space = _make_input_space(vocab, max_seq_len)
237
+ output_space = bases.VectorSpaceWithBasis(
238
+ extracted.sink[nodes.OUTPUT_BASIS])
239
+
240
+ test_input_vector = _embed_input(test_input, input_space)
241
+ expected_output_vector = _embed_output(
242
+ output_seq=expected_output,
243
+ output_space=output_space,
244
+ categorical_output=True)
245
+ compiled_model_output = model.apply(test_input_vector).project(output_space)
246
+
247
+ # Expected to fail: the compiled craft model does not give the correct output.
248
+ self.assertVectorAllClose(
249
+ tests_common.strip_bos_token(compiled_model_output),
250
+ expected_output_vector)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ absltest.main()
compiler/rasp_to_graph.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Converting a RaspExpr to a graph."""
16
+
17
+ import dataclasses
18
+ import queue
19
+
20
+ import networkx as nx
21
+ from tracr.compiler import nodes
22
+ from tracr.rasp import rasp
23
+
24
+ Node = nodes.Node
25
+ NodeID = nodes.NodeID
26
+
27
+
28
+ @dataclasses.dataclass
29
+ class ExtractRaspGraphOutput:
30
+ graph: nx.DiGraph
31
+ sink: Node # the program's output.
32
+ sources: list[Node] # the primitive S-Ops.
33
+
34
+
35
+ def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput:
36
+ """Converts a RASP program into a graph representation."""
37
+ expr_queue = queue.Queue()
38
+ graph = nx.DiGraph()
39
+ sources: list[NodeID] = []
40
+
41
+ def ensure_node(expr: rasp.RASPExpr) -> NodeID:
42
+ """Finds or creates a graph node corresponding to expr; returns its ID."""
43
+ node_id = expr.label
44
+ if node_id not in graph:
45
+ graph.add_node(node_id, **{nodes.ID: node_id, nodes.EXPR: expr})
46
+
47
+ return node_id
48
+
49
+ # Breadth-first search over the RASP expression graph.
50
+
51
+ def visit_raspexpr(expr: rasp.RASPExpr):
52
+ parent_id = ensure_node(expr)
53
+
54
+ for child_expr in expr.children:
55
+ expr_queue.put(child_expr)
56
+ child_id = ensure_node(child_expr)
57
+ graph.add_edge(child_id, parent_id)
58
+
59
+ if not expr.children:
60
+ sources.append(graph.nodes[parent_id])
61
+
62
+ expr_queue.put(tip)
63
+ sink = graph.nodes[ensure_node(tip)]
64
+ while not expr_queue.empty():
65
+ visit_raspexpr(expr_queue.get())
66
+
67
+ return ExtractRaspGraphOutput(graph=graph, sink=sink, sources=sources)
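The test file that follows exercises this function in detail; as a quicker orientation, a minimal usage sketch (assuming the `tracr` package is importable, as in the imports above) could look like this:

from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.rasp import rasp

program = rasp.Map(lambda x: x + 1, rasp.tokens)
extracted = rasp_to_graph.extract_rasp_graph(program)

# The sink node holds the program's output S-Op; the sources are the
# primitive S-Ops (here only rasp.tokens).
assert extracted.sink[nodes.EXPR] == program
assert [src[nodes.EXPR] for src in extracted.sources] == [rasp.tokens]

# Edges point from child expressions to their parents.
print(list(extracted.graph.edges))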
compiler/rasp_to_graph_test.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for compiler.rasp_to_graph."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ from tracr.compiler import nodes
20
+ from tracr.compiler import rasp_to_graph
21
+ from tracr.rasp import rasp
22
+
23
+
24
+ class ExtractRaspGraphTest(parameterized.TestCase):
25
+
26
+ def test_primitives_have_no_edges(self):
27
+ tokens_graph = rasp_to_graph.extract_rasp_graph(rasp.tokens).graph
28
+ self.assertEmpty(tokens_graph.edges)
29
+
30
+ indices_graph = rasp_to_graph.extract_rasp_graph(rasp.indices).graph
31
+ self.assertEmpty(indices_graph.edges)
32
+
33
+ full_graph = rasp_to_graph.extract_rasp_graph(rasp.Full(1)).graph
34
+ self.assertEmpty(full_graph.edges)
35
+
36
+ def test_one_edge(self):
37
+ program = rasp.Map(lambda x: x + 1, rasp.tokens)
38
+
39
+ graph = rasp_to_graph.extract_rasp_graph(program).graph
40
+
41
+ self.assertLen(graph.edges, 1)
42
+ (u, v), = graph.edges
43
+ self.assertEqual(graph.nodes[u][nodes.EXPR], rasp.tokens)
44
+ self.assertEqual(graph.nodes[v][nodes.EXPR], program)
45
+
46
+ def test_aggregate(self):
47
+ program = rasp.Aggregate(
48
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
49
+ rasp.indices,
50
+ )
51
+
52
+ extracted = rasp_to_graph.extract_rasp_graph(program)
53
+
54
+ # Expected graph:
55
+ #
56
+ # indices \ --------
57
+ # \ \
58
+ # select -- program
59
+ # tokens /
60
+
61
+ self.assertLen(extracted.graph.edges, 4)
62
+ self.assertEqual(extracted.sink[nodes.EXPR], program)
63
+ for source in extracted.sources:
64
+ self.assertIn(
65
+ source[nodes.EXPR],
66
+ [rasp.tokens, rasp.indices],
67
+ )
68
+
69
+
70
+ if __name__ == "__main__":
71
+ absltest.main()
compiler/rasp_to_transformer_integration_test.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Integration tests for the full RASP -> transformer compilation."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import jax
20
+ import numpy as np
21
+
22
+ from tracr.compiler import compiling
23
+ from tracr.compiler import lib
24
+ from tracr.compiler import test_cases
25
+ from tracr.craft import tests_common
26
+ from tracr.rasp import rasp
27
+
28
+ _COMPILER_BOS = "rasp_to_transformer_integration_test_BOS"
29
+ _COMPILER_PAD = "rasp_to_transformer_integration_test_PAD"
30
+
31
+ # Force float32 precision on TPU, which otherwise defaults to bfloat16.
32
+ jax.config.update("jax_default_matmul_precision", "float32")
33
+
34
+
35
+ class CompilerIntegrationTest(tests_common.VectorFnTestCase):
36
+
37
+ def assertSequenceEqualWhenExpectedIsNotNone(self, actual_seq, expected_seq):
38
+ for actual, expected in zip(actual_seq, expected_seq):
39
+ if expected is not None and actual != expected:
40
+ self.fail(f"{actual_seq} does not match (ignoring Nones) "
41
+ f"{expected_seq=}")
42
+
43
+ @parameterized.named_parameters(
44
+ dict(
45
+ testcase_name="map",
46
+ program=rasp.Map(lambda x: x + 1, rasp.tokens)),
47
+ dict(
48
+ testcase_name="sequence_map",
49
+ program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
50
+ rasp.indices)),
51
+ dict(
52
+ testcase_name="sequence_map_with_same_input",
53
+ program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
54
+ rasp.indices)),
55
+ dict(
56
+ testcase_name="select_aggregate",
57
+ program=rasp.Aggregate(
58
+ rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
59
+ rasp.Map(lambda x: 1, rasp.tokens))))
60
+ def test_rasp_program_and_transformer_produce_same_output(self, program):
61
+ vocab = {0, 1, 2}
62
+ max_seq_len = 3
63
+ assembled_model = compiling.compile_rasp_to_model(
64
+ program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)
65
+
66
+ test_outputs = {}
67
+ rasp_outputs = {}
68
+ for val in vocab:
69
+ test_outputs[val] = assembled_model.apply([_COMPILER_BOS, val]).decoded[1]
70
+ rasp_outputs[val] = program([val])[0]
71
+
72
+ with self.subTest(val=0):
73
+ self.assertEqual(test_outputs[0], rasp_outputs[0])
74
+ with self.subTest(val=1):
75
+ self.assertEqual(test_outputs[1], rasp_outputs[1])
76
+ with self.subTest(val=2):
77
+ self.assertEqual(test_outputs[2], rasp_outputs[2])
78
+
79
+ @parameterized.named_parameters(*test_cases.TEST_CASES)
80
+ def test_compiled_models_produce_expected_output(self, program, vocab,
81
+ test_input, expected_output,
82
+ max_seq_len, **kwargs):
83
+ del kwargs
84
+ assembled_model = compiling.compile_rasp_to_model(
85
+ program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)
86
+ test_output = assembled_model.apply([_COMPILER_BOS] + test_input)
87
+
88
+ if isinstance(expected_output[0], (int, float)):
89
+ np.testing.assert_allclose(
90
+ test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005)
91
+ else:
92
+ self.assertSequenceEqualWhenExpectedIsNotNone(test_output.decoded[1:],
93
+ expected_output)
94
+
95
+ @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES)
96
+ def test_compiled_causal_models_produce_expected_output(
97
+ self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
98
+ del kwargs
99
+ assembled_model = compiling.compile_rasp_to_model(
100
+ program,
101
+ vocab,
102
+ max_seq_len,
103
+ causal=True,
104
+ compiler_bos=_COMPILER_BOS,
105
+ compiler_pad=_COMPILER_PAD)
106
+ test_output = assembled_model.apply([_COMPILER_BOS] + test_input)
107
+
108
+ if isinstance(expected_output[0], (int, float)):
109
+ np.testing.assert_allclose(
110
+ test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005)
111
+ else:
112
+ self.assertSequenceEqualWhenExpectedIsNotNone(test_output.decoded[1:],
113
+ expected_output)
114
+
115
+ @parameterized.named_parameters(
116
+ dict(
117
+ testcase_name="reverse_1",
118
+ program=lib.make_reverse(rasp.tokens),
119
+ vocab={"a", "b", "c", "d"},
120
+ test_input=list("abcd"),
121
+ expected_output=list("dcba"),
122
+ max_seq_len=5),
123
+ dict(
124
+ testcase_name="reverse_2",
125
+ program=lib.make_reverse(rasp.tokens),
126
+ vocab={"a", "b", "c", "d"},
127
+ test_input=list("abc"),
128
+ expected_output=list("cba"),
129
+ max_seq_len=5),
130
+ dict(
131
+ testcase_name="reverse_3",
132
+ program=lib.make_reverse(rasp.tokens),
133
+ vocab={"a", "b", "c", "d"},
134
+ test_input=list("ad"),
135
+ expected_output=list("da"),
136
+ max_seq_len=5),
137
+ dict(
138
+ testcase_name="reverse_4",
139
+ program=lib.make_reverse(rasp.tokens),
140
+ vocab={"a", "b", "c", "d"},
141
+ test_input=["c"],
142
+ expected_output=["c"],
143
+ max_seq_len=5),
144
+ dict(
145
+ testcase_name="length_categorical_1",
146
+ program=rasp.categorical(lib.make_length()),
147
+ vocab={"a", "b", "c", "d"},
148
+ test_input=list("abc"),
149
+ expected_output=[3, 3, 3],
150
+ max_seq_len=5),
151
+ dict(
152
+ testcase_name="length_categorical_2",
153
+ program=rasp.categorical(lib.make_length()),
154
+ vocab={"a", "b", "c", "d"},
155
+ test_input=list("ad"),
156
+ expected_output=[2, 2],
157
+ max_seq_len=5),
158
+ dict(
159
+ testcase_name="length_categorical_3",
160
+ program=rasp.categorical(lib.make_length()),
161
+ vocab={"a", "b", "c", "d"},
162
+ test_input=["c"],
163
+ expected_output=[1],
164
+ max_seq_len=5),
165
+ dict(
166
+ testcase_name="length_numerical_1",
167
+ program=rasp.numerical(lib.make_length()),
168
+ vocab={"a", "b", "c", "d"},
169
+ test_input=list("abc"),
170
+ expected_output=[3, 3, 3],
171
+ max_seq_len=5),
172
+ dict(
173
+ testcase_name="length_numerical_2",
174
+ program=rasp.numerical(lib.make_length()),
175
+ vocab={"a", "b", "c", "d"},
176
+ test_input=list("ad"),
177
+ expected_output=[2, 2],
178
+ max_seq_len=5),
179
+ dict(
180
+ testcase_name="length_numerical_3",
181
+ program=rasp.numerical(lib.make_length()),
182
+ vocab={"a", "b", "c", "d"},
183
+ test_input=["c"],
184
+ expected_output=[1],
185
+ max_seq_len=5),
186
+ )
187
+ def test_compiled_models_produce_expected_output_with_padding(
188
+ self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
189
+ del kwargs
190
+ assembled_model = compiling.compile_rasp_to_model(
191
+ program,
192
+ vocab,
193
+ max_seq_len,
194
+ compiler_bos=_COMPILER_BOS,
195
+ compiler_pad=_COMPILER_PAD)
196
+
197
+ pad_len = (max_seq_len - len(test_input))
198
+ test_input = test_input + [_COMPILER_PAD] * pad_len
199
+ test_input = [_COMPILER_BOS] + test_input
200
+ test_output = assembled_model.apply(test_input)
201
+ output = test_output.decoded
202
+ output_len = len(output)
203
+ output_stripped = test_output.decoded[1:output_len - pad_len]
204
+
205
+ self.assertEqual(output[0], _COMPILER_BOS)
206
+ if isinstance(expected_output[0], (int, float)):
207
+ np.testing.assert_allclose(
208
+ output_stripped, expected_output, atol=1e-7, rtol=0.005)
209
+ else:
210
+ self.assertEqual(output_stripped, expected_output)
211
+
212
+
213
+ if __name__ == "__main__":
214
+ absltest.main()
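For a compact view of the API these tests drive, the sketch below compiles one small RASP program and runs it. The "BOS" string is an arbitrary stand-in for the `_COMPILER_BOS` marker used above, and the expected output simply mirrors what the RASP program itself computes.

from tracr.compiler import compiling
from tracr.rasp import rasp

program = rasp.Map(lambda x: x + 1, rasp.tokens)
assembled_model = compiling.compile_rasp_to_model(
    program, {0, 1, 2}, 3, compiler_bos="BOS")

out = assembled_model.apply(["BOS", 0, 1, 2])
# Position 0 corresponds to the BOS token; the tests above therefore compare
# out.decoded[1:] against the RASP program's own output.
print(out.decoded[1:])  # expected: [1, 2, 3]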
compiler/test_cases.py ADDED
@@ -0,0 +1,357 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """A set of RASP programs and input/output pairs used in integration tests."""
16
+
17
+ from tracr.compiler import lib
18
+ from tracr.rasp import rasp
19
+
20
+ UNIVERSAL_TEST_CASES = [
21
+ dict(
22
+ testcase_name="frac_prevs_1",
23
+ program=lib.make_frac_prevs(rasp.tokens == "l"),
24
+ vocab={"h", "e", "l", "o"},
25
+ test_input=list("hello"),
26
+ expected_output=[0.0, 0.0, 1 / 3, 1 / 2, 2 / 5],
27
+ max_seq_len=5),
28
+ dict(
29
+ testcase_name="frac_prevs_2",
30
+ program=lib.make_frac_prevs(rasp.tokens == "("),
31
+ vocab={"a", "b", "c", "(", ")"},
32
+ test_input=list("a()b(c))"),
33
+ expected_output=[0.0, 1 / 2, 1 / 3, 1 / 4, 2 / 5, 2 / 6, 2 / 7, 2 / 8],
34
+ max_seq_len=10),
35
+ dict(
36
+ testcase_name="frac_prevs_3",
37
+ program=lib.make_frac_prevs(rasp.tokens == ")"),
38
+ vocab={"a", "b", "c", "(", ")"},
39
+ test_input=list("a()b(c))"),
40
+ expected_output=[0.0, 0.0, 1 / 3, 1 / 4, 1 / 5, 1 / 6, 2 / 7, 3 / 8],
41
+ max_seq_len=10,
42
+ ),
43
+ dict(
44
+ testcase_name="shift_by_one",
45
+ program=lib.shift_by(1, rasp.tokens),
46
+ vocab={"a", "b", "c", "d"},
47
+ test_input=list("abcd"),
48
+ expected_output=[None, "a", "b", "c"],
49
+ max_seq_len=5,
50
+ ),
51
+ dict(
52
+ testcase_name="shift_by_two",
53
+ program=lib.shift_by(2, rasp.tokens),
54
+ vocab={"a", "b", "c", "d"},
55
+ test_input=list("abcd"),
56
+ expected_output=[None, None, "a", "b"],
57
+ max_seq_len=5,
58
+ ),
59
+ dict(
60
+ testcase_name="detect_pattern_a",
61
+ program=lib.detect_pattern(rasp.tokens, "a"),
62
+ vocab={"a", "b", "c", "d"},
63
+ test_input=list("bacd"),
64
+ expected_output=[False, True, False, False],
65
+ max_seq_len=5,
66
+ ),
67
+ dict(
68
+ testcase_name="detect_pattern_ab",
69
+ program=lib.detect_pattern(rasp.tokens, "ab"),
70
+ vocab={"a", "b"},
71
+ test_input=list("aaba"),
72
+ expected_output=[None, False, True, False],
73
+ max_seq_len=5,
74
+ ),
75
+ dict(
76
+ testcase_name="detect_pattern_ab_2",
77
+ program=lib.detect_pattern(rasp.tokens, "ab"),
78
+ vocab={"a", "b"},
79
+ test_input=list("abaa"),
80
+ expected_output=[None, True, False, False],
81
+ max_seq_len=5,
82
+ ),
83
+ dict(
84
+ testcase_name="detect_pattern_ab_3",
85
+ program=lib.detect_pattern(rasp.tokens, "ab"),
86
+ vocab={"a", "b"},
87
+ test_input=list("aaaa"),
88
+ expected_output=[None, False, False, False],
89
+ max_seq_len=5,
90
+ ),
91
+ dict(
92
+ testcase_name="detect_pattern_abc",
93
+ program=lib.detect_pattern(rasp.tokens, "abc"),
94
+ vocab={"a", "b", "c"},
95
+ test_input=list("abcabc"),
96
+ expected_output=[None, None, True, False, False, True],
97
+ max_seq_len=6,
98
+ ),
99
+ ]
100
+
101
+ TEST_CASES = UNIVERSAL_TEST_CASES + [
102
+ dict(
103
+ testcase_name="reverse_1",
104
+ program=lib.make_reverse(rasp.tokens),
105
+ vocab={"a", "b", "c", "d"},
106
+ test_input=list("abcd"),
107
+ expected_output=list("dcba"),
108
+ max_seq_len=5),
109
+ dict(
110
+ testcase_name="reverse_2",
111
+ program=lib.make_reverse(rasp.tokens),
112
+ vocab={"a", "b", "c", "d"},
113
+ test_input=list("abc"),
114
+ expected_output=list("cba"),
115
+ max_seq_len=5),
116
+ dict(
117
+ testcase_name="reverse_3",
118
+ program=lib.make_reverse(rasp.tokens),
119
+ vocab={"a", "b", "c", "d"},
120
+ test_input=list("ad"),
121
+ expected_output=list("da"),
122
+ max_seq_len=5),
123
+ dict(
124
+ testcase_name="reverse_4",
125
+ program=lib.make_reverse(rasp.tokens),
126
+ vocab={"a", "b", "c", "d"},
127
+ test_input=["c"],
128
+ expected_output=["c"],
129
+ max_seq_len=5),
130
+ dict(
131
+ testcase_name="length_categorical_1",
132
+ program=rasp.categorical(lib.make_length()),
133
+ vocab={"a", "b", "c", "d"},
134
+ test_input=list("abc"),
135
+ expected_output=[3, 3, 3],
136
+ max_seq_len=3),
137
+ dict(
138
+ testcase_name="length_categorical_2",
139
+ program=rasp.categorical(lib.make_length()),
140
+ vocab={"a", "b", "c", "d"},
141
+ test_input=list("ad"),
142
+ expected_output=[2, 2],
143
+ max_seq_len=3),
144
+ dict(
145
+ testcase_name="length_categorical_3",
146
+ program=rasp.categorical(lib.make_length()),
147
+ vocab={"a", "b", "c", "d"},
148
+ test_input=["c"],
149
+ expected_output=[1],
150
+ max_seq_len=3),
151
+ dict(
152
+ testcase_name="length_numerical_1",
153
+ program=rasp.numerical(lib.make_length()),
154
+ vocab={"a", "b", "c", "d"},
155
+ test_input=list("abc"),
156
+ expected_output=[3, 3, 3],
157
+ max_seq_len=3),
158
+ dict(
159
+ testcase_name="length_numerical_2",
160
+ program=rasp.numerical(lib.make_length()),
161
+ vocab={"a", "b", "c", "d"},
162
+ test_input=list("ad"),
163
+ expected_output=[2, 2],
164
+ max_seq_len=3),
165
+ dict(
166
+ testcase_name="length_numerical_3",
167
+ program=rasp.numerical(lib.make_length()),
168
+ vocab={"a", "b", "c", "d"},
169
+ test_input=["c"],
170
+ expected_output=[1],
171
+ max_seq_len=3),
172
+ dict(
173
+ testcase_name="pair_balance_1",
174
+ program=lib.make_pair_balance(rasp.tokens, "(", ")"),
175
+ vocab={"a", "b", "c", "(", ")"},
176
+ test_input=list("a()b(c))"),
177
+ expected_output=[0.0, 1 / 2, 0.0, 0.0, 1 / 5, 1 / 6, 0.0, -1 / 8],
178
+ max_seq_len=10),
179
+ dict(
180
+ testcase_name="shuffle_dyck2_1",
181
+ program=lib.make_shuffle_dyck(pairs=["()", "{}"]),
182
+ vocab={"(", ")", "{", "}"},
183
+ test_input=list("({)}"),
184
+ expected_output=[1, 1, 1, 1],
185
+ max_seq_len=5),
186
+ dict(
187
+ testcase_name="shuffle_dyck2_2",
188
+ program=lib.make_shuffle_dyck(pairs=["()", "{}"]),
189
+ vocab={"(", ")", "{", "}"},
190
+ test_input=list("(){)}"),
191
+ expected_output=[0, 0, 0, 0, 0],
192
+ max_seq_len=5),
193
+ dict(
194
+ testcase_name="shuffle_dyck2_3",
195
+ program=lib.make_shuffle_dyck(pairs=["()", "{}"]),
196
+ vocab={"(", ")", "{", "}"},
197
+ test_input=list("{}("),
198
+ expected_output=[0, 0, 0],
199
+ max_seq_len=5),
200
+ dict(
201
+ testcase_name="shuffle_dyck3_1",
202
+ program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
203
+ vocab={"(", ")", "{", "}", "[", "]"},
204
+ test_input=list("({)[}]"),
205
+ expected_output=[1, 1, 1, 1, 1, 1],
206
+ max_seq_len=6),
207
+ dict(
208
+ testcase_name="shuffle_dyck3_2",
209
+ program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
210
+ vocab={"(", ")", "{", "}", "[", "]"},
211
+ test_input=list("(){)}"),
212
+ expected_output=[0, 0, 0, 0, 0],
213
+ max_seq_len=6),
214
+ dict(
215
+ testcase_name="shuffle_dyck3_3",
216
+ program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
217
+ vocab={"(", ")", "{", "}", "[", "]"},
218
+ test_input=list("{}[(]"),
219
+ expected_output=[0, 0, 0, 0, 0],
220
+ max_seq_len=6),
221
+ dict(
222
+ testcase_name="hist",
223
+ program=lib.make_hist(),
224
+ vocab={"a", "b", "c", "d"},
225
+ test_input=list("abac"),
226
+ expected_output=[2, 1, 2, 1],
227
+ max_seq_len=5,
228
+ ),
229
+ dict(
230
+ testcase_name="sort_unique_1",
231
+ program=lib.make_sort_unique(vals=rasp.tokens, keys=rasp.tokens),
232
+ vocab={1, 2, 3, 4},
233
+ test_input=[2, 4, 3, 1],
234
+ expected_output=[1, 2, 3, 4],
235
+ max_seq_len=5),
236
+ dict(
237
+ testcase_name="sort_unique_2",
238
+ program=lib.make_sort_unique(vals=rasp.tokens, keys=1 - rasp.indices),
239
+ vocab={"a", "b", "c", "d"},
240
+ test_input=list("abcd"),
241
+ expected_output=["d", "c", "b", "a"],
242
+ max_seq_len=5),
243
+ dict(
244
+ testcase_name="sort_1",
245
+ program=lib.make_sort(
246
+ vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1),
247
+ vocab={1, 2, 3, 4},
248
+ test_input=[2, 4, 3, 1],
249
+ expected_output=[1, 2, 3, 4],
250
+ max_seq_len=5),
251
+ dict(
252
+ testcase_name="sort_2",
253
+ program=lib.make_sort(
254
+ vals=rasp.tokens, keys=1 - rasp.indices, max_seq_len=5, min_key=1),
255
+ vocab={"a", "b", "c", "d"},
256
+ test_input=list("abcd"),
257
+ expected_output=["d", "c", "b", "a"],
258
+ max_seq_len=5),
259
+ dict(
260
+ testcase_name="sort_3",
261
+ program=lib.make_sort(
262
+ vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1),
263
+ vocab={1, 2, 3, 4},
264
+ test_input=[2, 4, 1, 2],
265
+ expected_output=[1, 2, 2, 4],
266
+ max_seq_len=5),
267
+ dict(
268
+ testcase_name="sort_freq",
269
+ program=lib.make_sort_freq(max_seq_len=5),
270
+ vocab={1, 2, 3, 4},
271
+ test_input=[2, 4, 2, 1],
272
+ expected_output=[2, 2, 4, 1],
273
+ max_seq_len=5),
274
+ dict(
275
+ testcase_name="make_count_less_freq_categorical_1",
276
+ program=lib.make_count_less_freq(n=2),
277
+ vocab={"a", "b", "c", "d"},
278
+ test_input=["a", "a", "a", "b", "b", "c"],
279
+ expected_output=[3, 3, 3, 3, 3, 3],
280
+ max_seq_len=6),
281
+ dict(
282
+ testcase_name="make_count_less_freq_categorical_2",
283
+ program=lib.make_count_less_freq(n=2),
284
+ vocab={"a", "b", "c", "d"},
285
+ test_input=["a", "a", "c", "b", "b", "c"],
286
+ expected_output=[6, 6, 6, 6, 6, 6],
287
+ max_seq_len=6),
288
+ dict(
289
+ testcase_name="make_count_less_freq_numerical_1",
290
+ program=rasp.numerical(lib.make_count_less_freq(n=2)),
291
+ vocab={"a", "b", "c", "d"},
292
+ test_input=["a", "a", "a", "b", "b", "c"],
293
+ expected_output=[3, 3, 3, 3, 3, 3],
294
+ max_seq_len=6),
295
+ dict(
296
+ testcase_name="make_count_less_freq_numerical_2",
297
+ program=rasp.numerical(lib.make_count_less_freq(n=2)),
298
+ vocab={"a", "b", "c", "d"},
299
+ test_input=["a", "a", "c", "b", "b", "c"],
300
+ expected_output=[6, 6, 6, 6, 6, 6],
301
+ max_seq_len=6),
302
+ dict(
303
+ testcase_name="make_count_1",
304
+ program=lib.make_count(rasp.tokens, "a"),
305
+ vocab={"a", "b", "c"},
306
+ test_input=["a", "a", "a", "b", "b", "c"],
307
+ expected_output=[3, 3, 3, 3, 3, 3],
308
+ max_seq_len=8,
309
+ ),
310
+ dict(
311
+ testcase_name="make_count_2",
312
+ program=lib.make_count(rasp.tokens, "a"),
313
+ vocab={"a", "b", "c"},
314
+ test_input=["c", "a", "b", "c"],
315
+ expected_output=[1, 1, 1, 1],
316
+ max_seq_len=8,
317
+ ),
318
+ dict(
319
+ testcase_name="make_count_3",
320
+ program=lib.make_count(rasp.tokens, "a"),
321
+ vocab={"a", "b", "c"},
322
+ test_input=["b", "b", "c"],
323
+ expected_output=[0, 0, 0],
324
+ max_seq_len=8,
325
+ ),
326
+ dict(
327
+ testcase_name="make_nary_sequencemap_1",
328
+ program=lib.make_nary_sequencemap(
329
+ lambda x, y, z: x + y - z, rasp.tokens, rasp.tokens, rasp.indices),
330
+ vocab={1, 2, 3},
331
+ test_input=[1, 2, 3],
332
+ expected_output=[2, 3, 4],
333
+ max_seq_len=5,
334
+ ),
335
+ dict(
336
+ testcase_name="make_nary_sequencemap_2",
337
+ program=lib.make_nary_sequencemap(
338
+ lambda x, y, z: x * y / z, rasp.indices, rasp.indices, rasp.tokens),
339
+ vocab={1, 2, 3},
340
+ test_input=[1, 2, 3],
341
+ expected_output=[0, 1 / 2, 4 / 3],
342
+ max_seq_len=3,
343
+ )
344
+ ]
345
+
346
+ # make_nary_sequencemap(f, *sops)
347
+
348
+ CAUSAL_TEST_CASES = UNIVERSAL_TEST_CASES + [
349
+ dict(
350
+ testcase_name="selector_width",
351
+ program=rasp.SelectorWidth(
352
+ rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)),
353
+ vocab={"a", "b", "c", "d"},
354
+ test_input=list("abcd"),
355
+ expected_output=[1, 2, 3, 4],
356
+ max_seq_len=5),
357
+ ]
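Each entry above is a plain dict of keyword arguments (plus a `testcase_name`) meant to be unpacked by absl's `parameterized.named_parameters`, as the integration tests do. The test class and assertion below are made up purely to show the consumption pattern:

from absl.testing import absltest
from absl.testing import parameterized
from tracr.compiler import test_cases


class TestCaseShapeTest(parameterized.TestCase):

  @parameterized.named_parameters(*test_cases.TEST_CASES)
  def test_input_and_output_have_same_length(self, program, vocab, test_input,
                                             expected_output, max_seq_len,
                                             **kwargs):
    del program, vocab, max_seq_len, kwargs
    self.assertLen(test_input, len(expected_output))


if __name__ == "__main__":
  absltest.main()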
craft/bases.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Vectors and bases."""
16
+
17
+ import dataclasses
18
+ from typing import Sequence, Union, Optional, Iterable
19
+
20
+ import numpy as np
21
+
22
+ Name = Union[int, str]
23
+ Value = Union[int, float, bool, str, tuple]
24
+
25
+
26
+ @dataclasses.dataclass(frozen=True)
27
+ class BasisDirection:
28
+ """Represents a basis direction (no magnitude) in a vector space.
29
+
30
+ Attributes:
31
+ name: a unique name for this direction.
32
+ value: used to hold a value one-hot-encoded by this direction. e.g.,
33
+ [BasisDirection("vs_1", True), BasisDirection("vs_1", False)] would be
34
+ basis directions of a subspace called "vs_1" which one-hot-encodes the
35
+ values True and False. If provided, considered part of the name for the
36
+ purpose of disambiguating directions.
37
+ """
38
+ name: Name
39
+ value: Optional[Value] = None
40
+
41
+ def __str__(self):
42
+ if self.value is None:
43
+ return str(self.name)
44
+ return f"{self.name}:{self.value}"
45
+
46
+ def __lt__(self, other: "BasisDirection") -> bool:
47
+ try:
48
+ return (self.name, self.value) < (other.name, other.value)
49
+ except TypeError:
50
+ return str(self) < str(other)
51
+
52
+
53
+ @dataclasses.dataclass
54
+ class VectorInBasis:
55
+ """A vector (or array of vectors) in a given basis.
56
+
57
+ When magnitudes are 1-d, this is a vector.
58
+ When magnitudes are (n+1)-d, this is an array of vectors,
59
+ where the -1th dimension is the basis dimension.
60
+ """
61
+ basis_directions: Sequence[BasisDirection]
62
+ magnitudes: np.ndarray
63
+
64
+ def __post_init__(self):
65
+ """Sort basis directions."""
66
+ if len(self.basis_directions) != self.magnitudes.shape[-1]:
67
+ raise ValueError(
68
+ "Last dimension of magnitudes must be the same as number "
69
+ f"of basis directions. Was {len(self.basis_directions)} "
70
+ f"and {self.magnitudes.shape[-1]}.")
71
+
72
+ sort_idx = np.argsort(self.basis_directions)
73
+ self.basis_directions = [self.basis_directions[i] for i in sort_idx]
74
+ self.magnitudes = np.take(self.magnitudes, sort_idx, -1)
75
+
76
+ def __add__(self, other: "VectorInBasis") -> "VectorInBasis":
77
+ if self.basis_directions != other.basis_directions:
78
+ raise TypeError(f"Adding incompatible bases: {self} + {other}")
79
+ magnitudes = self.magnitudes + other.magnitudes
80
+ return VectorInBasis(self.basis_directions, magnitudes)
81
+
82
+ def __radd__(self, other: "VectorInBasis") -> "VectorInBasis":
83
+ if self.basis_directions != other.basis_directions:
84
+ raise TypeError(f"Adding incompatible bases: {other} + {self}")
85
+ return self + other
86
+
87
+ def __sub__(self, other: "VectorInBasis") -> "VectorInBasis":
88
+ if self.basis_directions != other.basis_directions:
89
+ raise TypeError(f"Subtracting incompatible bases: {self} - {other}")
90
+ magnitudes = self.magnitudes - other.magnitudes
91
+ return VectorInBasis(self.basis_directions, magnitudes)
92
+
93
+ def __rsub__(self, other: "VectorInBasis") -> "VectorInBasis":
94
+ if self.basis_directions != other.basis_directions:
95
+ raise TypeError(f"Subtracting incompatible bases: {other} - {self}")
96
+ magnitudes = other.magnitudes - self.magnitudes
97
+ return VectorInBasis(self.basis_directions, magnitudes)
98
+
99
+ def __mul__(self, scalar: float) -> "VectorInBasis":
100
+ return VectorInBasis(self.basis_directions, self.magnitudes * scalar)
101
+
102
+ def __rmul__(self, scalar: float) -> "VectorInBasis":
103
+ return self * scalar
104
+
105
+ def __truediv__(self, scalar: float) -> "VectorInBasis":
106
+ return VectorInBasis(self.basis_directions, self.magnitudes / scalar)
107
+
108
+ def __neg__(self) -> "VectorInBasis":
109
+ return (-1) * self
110
+
111
+ def __eq__(self, other: "VectorInBasis") -> bool:
112
+ return ((self.basis_directions == other.basis_directions) and
113
+ (self.magnitudes.shape == other.magnitudes.shape) and
114
+ (np.all(self.magnitudes == other.magnitudes)))
115
+
116
+ @classmethod
117
+ def sum(cls, vectors: Sequence["VectorInBasis"]) -> "VectorInBasis":
118
+ return cls(vectors[0].basis_directions,
119
+ np.sum([x.magnitudes for x in vectors], axis=0))
120
+
121
+ @classmethod
122
+ def stack(cls,
123
+ vectors: Sequence["VectorInBasis"],
124
+ axis: int = 0) -> "VectorInBasis":
125
+ for v in vectors[1:]:
126
+ if v.basis_directions != vectors[0].basis_directions:
127
+ raise TypeError(f"Stacking incompatible bases: {vectors[0]} + {v}")
128
+ return cls(vectors[0].basis_directions,
129
+ np.stack([v.magnitudes for v in vectors], axis=axis))
130
+
131
+ def project(
132
+ self, basis: Union["VectorSpaceWithBasis", Sequence[BasisDirection]]
133
+ ) -> "VectorInBasis":
134
+ """Projects to the basis."""
135
+ if isinstance(basis, VectorSpaceWithBasis):
136
+ basis = basis.basis
137
+ components = []
138
+ for direction in basis:
139
+ if direction in self.basis_directions:
140
+ components.append(
141
+ self.magnitudes[..., self.basis_directions.index(direction)])
142
+ else:
143
+ components.append(np.zeros_like(self.magnitudes[..., 0]))
144
+ return VectorInBasis(list(basis), np.stack(components, axis=-1))
145
+
146
+
147
+ @dataclasses.dataclass
148
+ class VectorSpaceWithBasis:
149
+ """A vector subspace in a given basis."""
150
+ basis: Sequence[BasisDirection]
151
+
152
+ def __post_init__(self):
153
+ """Keep basis directions sorted."""
154
+ self.basis = sorted(self.basis)
155
+
156
+ @property
157
+ def num_dims(self) -> int:
158
+ return len(self.basis)
159
+
160
+ def __contains__(self, item: Union[VectorInBasis, BasisDirection]) -> bool:
161
+ if isinstance(item, BasisDirection):
162
+ return item in self.basis
163
+
164
+ return set(self.basis) == set(item.basis_directions)
165
+
166
+ def issubspace(self, other: "VectorSpaceWithBasis") -> bool:
167
+ return set(self.basis).issubset(set(other.basis))
168
+
169
+ def basis_vectors(self) -> Sequence[VectorInBasis]:
170
+ basis_vector_magnitudes = list(np.eye(self.num_dims))
171
+ return [VectorInBasis(self.basis, m) for m in basis_vector_magnitudes]
172
+
173
+ def vector_from_basis_direction(
174
+ self, basis_direction: BasisDirection) -> VectorInBasis:
175
+ i = self.basis.index(basis_direction)
176
+ return VectorInBasis(self.basis, np.eye(self.num_dims)[i])
177
+
178
+ def null_vector(self) -> VectorInBasis:
179
+ return VectorInBasis(self.basis, np.zeros(self.num_dims))
180
+
181
+ @classmethod
182
+ def from_names(cls, names: Sequence[Name]) -> "VectorSpaceWithBasis":
183
+ """Creates a VectorSpace from a list of names for its basis directions."""
184
+ return cls([BasisDirection(n) for n in names])
185
+
186
+ @classmethod
187
+ def from_values(
188
+ cls,
189
+ name: Name,
190
+ values: Iterable[Value],
191
+ ) -> "VectorSpaceWithBasis":
192
+ """Creates a VectorSpace from a list of values for its basis directions."""
193
+ return cls([BasisDirection(name, v) for v in values])
194
+
195
+
196
+ def direct_sum(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis:
197
+ """Create a direct sum of the vector spaces.
198
+
199
+ Assumes the basis elements of all input vector spaces are
200
+ orthogonal to each other. Maintains the order of the bases.
201
+
202
+ Args:
203
+ *vs: the vector spaces to sum.
204
+
205
+ Returns:
206
+ the combined vector space.
207
+
208
+ Raises:
209
+ Value error in case of overlapping bases.
210
+ """
211
+ # Take the union of all the bases:
212
+ total_basis = sum([v.basis for v in vs], [])
213
+
214
+ if len(total_basis) != len(set(total_basis)):
215
+ raise ValueError("Overlapping bases!")
216
+
217
+ return VectorSpaceWithBasis(total_basis)
218
+
219
+
220
+ def join_vector_spaces(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis:
221
+ """Joins a set of vector spaces allowing them to overlap.
222
+
223
+ Assumes the basis elements of all input vector spaces are
224
+ orthogonal to each other. Does not maintain the order of the bases but
225
+ sorts them.
226
+
227
+ Args:
228
+ *vs: the vector spaces to sum.
229
+
230
+ Returns:
231
+ the combined vector space.
232
+ """
233
+ # Take the union of all the bases:
234
+ total_basis = list(set().union(*[set(v.basis) for v in vs]))
235
+ total_basis = sorted(total_basis)
236
+ return VectorSpaceWithBasis(total_basis)
237
+
238
+
239
+ def ensure_dims(
240
+ vs: VectorSpaceWithBasis,
241
+ num_dims: int,
242
+ name: str = "vector space",
243
+ ) -> None:
244
+ """Raises ValueError if vs has the wrong number of dimensions."""
245
+ if vs.num_dims != num_dims:
246
+ raise ValueError(f"{name} must have {num_dims=}, "
247
+ f"but got {vs.num_dims}: {vs.basis}")
craft/bases_test.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for bases."""
16
+
17
+ from absl.testing import absltest
18
+ import numpy as np
19
+ from tracr.craft import bases
20
+ from tracr.craft import tests_common
21
+
22
+
23
+ class VectorInBasisTest(tests_common.VectorFnTestCase):
24
+
25
+ def test_shape_mismatch_raises_value_error(self):
26
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
27
+ regex = (r"^.*Last dimension of magnitudes must be the same as number of "
28
+ r"basis directions.*$")
29
+ with self.assertRaisesRegex(ValueError, regex):
30
+ bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
31
+ with self.assertRaisesRegex(ValueError, regex):
32
+ bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
33
+
34
+ def test_equal(self):
35
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
36
+ v1 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
37
+ v2 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
38
+ self.assertEqual(v1, v2)
39
+ self.assertEqual(v2, v1)
40
+ v3 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
41
+ v4 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
42
+ self.assertEqual(v3, v4)
43
+ self.assertEqual(v4, v3)
44
+ v5 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
45
+ v6 = bases.VectorInBasis(vs1.basis, np.array([1, 1, 1, 1]))
46
+ self.assertNotEqual(v5, v6)
47
+ self.assertNotEqual(v6, v5)
48
+ v7 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
49
+ v8 = bases.VectorInBasis(vs1.basis, np.array([[1, 2, 3, 4], [1, 1, 1, 1]]))
50
+ self.assertNotEqual(v7, v8)
51
+ self.assertNotEqual(v8, v7)
52
+ vs2 = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"])
53
+ v9 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
54
+ v10 = bases.VectorInBasis(vs2.basis, np.array([1, 2, 3, 4]))
55
+ self.assertNotEqual(v9, v10)
56
+ self.assertNotEqual(v10, v9)
57
+
58
+ def test_dunders(self):
59
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
60
+ v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2]))
61
+ three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3]))
62
+ five = bases.VectorInBasis(vs1.basis, np.array([5, 5, 5]))
63
+ v_times_5 = bases.VectorInBasis(vs1.basis, np.array([0, 5, 10]))
64
+ self.assertEqual(5 * v, v_times_5)
65
+ self.assertEqual(v * 5, v_times_5)
66
+ self.assertEqual(5.0 * v, v_times_5)
67
+ self.assertEqual(v * 5.0, v_times_5)
68
+ v_by_2 = bases.VectorInBasis(vs1.basis, np.array([0, 0.5, 1]))
69
+ self.assertEqual(v / 2, v_by_2)
70
+ self.assertEqual(v / 2.0, v_by_2)
71
+ self.assertEqual(1 / 2 * v, v_by_2)
72
+ v_plus_3 = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5]))
73
+ self.assertEqual(v + three, v_plus_3)
74
+ self.assertEqual(three + v, v_plus_3)
75
+ v_minus_5 = bases.VectorInBasis(vs1.basis, np.array([-5, -4, -3]))
76
+ self.assertEqual(v - five, v_minus_5)
77
+ minus_v = bases.VectorInBasis(vs1.basis, np.array([0, -1, -2]))
78
+ self.assertEqual(-v, minus_v)
79
+
80
+
81
+ class ProjectionTest(tests_common.VectorFnTestCase):
82
+
83
+ def test_direct_sum_produces_expected_result(self):
84
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
85
+ vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
86
+ vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"])
87
+ self.assertEqual(bases.direct_sum(vs1, vs2), vs3)
88
+
89
+ def test_join_vector_spaces_produces_expected_result(self):
90
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
91
+ vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
92
+ vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
93
+ self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)
94
+
95
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
96
+ vs2 = bases.VectorSpaceWithBasis.from_names(["b", "d", "c"])
97
+ vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
98
+ self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)
99
+
100
+ def test_compare_vectors_with_differently_ordered_basis_vectors(self):
101
+ basis1 = ["a", "b", "c", "d"]
102
+ basis1 = [bases.BasisDirection(x) for x in basis1]
103
+ basis2 = ["b", "d", "a", "c"]
104
+ basis2 = [bases.BasisDirection(x) for x in basis2]
105
+ vs1 = bases.VectorSpaceWithBasis(basis1)
106
+ vs2 = bases.VectorSpaceWithBasis(basis2)
107
+ v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4]))
108
+ v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3]))
109
+ self.assertEqual(v1, v2)
110
+ self.assertEqual(v1 - v2, vs1.null_vector())
111
+ self.assertEqual(v1 - v2, vs2.null_vector())
112
+ self.assertEqual(v1 + v2, 2 * v2)
113
+ self.assertIn(v1, vs1)
114
+ self.assertIn(v1, vs2)
115
+ self.assertIn(v2, vs1)
116
+ self.assertIn(v2, vs2)
117
+
118
+ def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self):
119
+ basis1 = ["a", "b", "c", "d"]
120
+ basis1 = [bases.BasisDirection(x) for x in basis1]
121
+ basis2 = ["b", "d", "a", "c"]
122
+ basis2 = [bases.BasisDirection(x) for x in basis2]
123
+ vs1 = bases.VectorSpaceWithBasis(basis1)
124
+ vs2 = bases.VectorSpaceWithBasis(basis2)
125
+ v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]]))
126
+ v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]]))
127
+ null_vec = bases.VectorInBasis.stack([vs1.null_vector(), vs2.null_vector()])
128
+ self.assertEqual(v1, v2)
129
+ self.assertEqual(v1 - v2, null_vec)
130
+ self.assertEqual(v1 + v2, 2 * v2)
131
+ self.assertIn(v1, vs1)
132
+ self.assertIn(v1, vs2)
133
+ self.assertIn(v2, vs1)
134
+ self.assertIn(v2, vs2)
135
+
136
+ def test_projection_to_larger_space(self):
137
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
138
+ vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
139
+ a1, b1 = vs1.basis_vectors()
140
+ a2, b2, _, _ = vs2.basis_vectors()
141
+
142
+ self.assertEqual(a1.project(vs2), a2)
143
+ self.assertEqual(b1.project(vs2), b2)
144
+
145
+ def test_projection_to_smaller_space(self):
146
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
147
+ vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
148
+ a1, b1, c1, d1 = vs1.basis_vectors()
149
+ a2, b2 = vs2.basis_vectors()
150
+
151
+ self.assertEqual(a1.project(vs2), a2)
152
+ self.assertEqual(b1.project(vs2), b2)
153
+ self.assertEqual(c1.project(vs2), vs2.null_vector())
154
+ self.assertEqual(d1.project(vs2), vs2.null_vector())
155
+
156
+
157
+ if __name__ == "__main__":
158
+ absltest.main()
craft/chamber/categorical_attn.py ADDED
@@ -0,0 +1,167 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Attention head for categorical inputs."""
16
+
17
+ from typing import Optional, Protocol
18
+
19
+ from tracr.craft import bases
20
+ from tracr.craft import transformers
21
+ from tracr.craft import vectorspace_fns
22
+
23
+
24
+ class QueryKeyToAttnLogit(Protocol):
25
+
26
+ def __call__(self, query: bases.BasisDirection,
27
+ key: bases.BasisDirection) -> bool:
28
+ pass
29
+
30
+
31
+ def categorical_attn(
32
+ query_space: bases.VectorSpaceWithBasis,
33
+ key_space: bases.VectorSpaceWithBasis,
34
+ value_space: bases.VectorSpaceWithBasis,
35
+ output_space: bases.VectorSpaceWithBasis,
36
+ bos_space: bases.VectorSpaceWithBasis,
37
+ one_space: bases.VectorSpaceWithBasis,
38
+ attn_fn: QueryKeyToAttnLogit,
39
+ default_output: Optional[bases.VectorInBasis] = None,
40
+ causal: bool = False,
41
+ always_attend_to_bos: bool = False,
42
+ use_bos_for_default_output: bool = True,
43
+ softmax_coldness: float = 100.,
44
+ ) -> transformers.AttentionHead:
45
+ """Returns an attention head for categorical inputs.
46
+
47
+ Assumes the existence of a beginning of sequence token and attends to it
48
+ always with strength 0.5*softmax_coldness. This allows to implement an
49
+ arbitrary default value for rows in the attention pattern that are all-zero.
50
+
51
+ Attends to the BOS token if all other key-query pairs have zero attention.
52
+ Hence, the first value in the value sequence will be the default output for
53
+ such cases.
54
+
55
+ Args:
56
+ query_space: Vector space containing (categorical) query input.
57
+ key_space: Vector space containing (categorical) key input.
58
+ value_space: Vector space containing (numerical) value input.
59
+ output_space: Vector space which will contain (numerical) output.
60
+ bos_space: 1-d space used to identify the beginning of sequence token.
61
+ one_space: 1-d space which contains 1 at every position.
62
+ attn_fn: A selector function f(query, key) operating on the query/key basis
63
+ directions that defines the attention pattern.
64
+ default_output: Output to return if attention pattern is all zero.
65
+ causal: If True, use masked attention.
66
+ always_attend_to_bos: If True, always attend to the BOS token. If False,
67
+ only attend to BOS when attending to nothing else.
68
+ use_bos_for_default_output: If True, assume BOS is not in the value space
69
+ and output a default value when attending to BOS. If False, assume BOS is
70
+ in the value space, and map it to the output space like any other token.
71
+ softmax_coldness: The inverse temperature of the softmax. Default value is
72
+ high which makes the attention close to a hard maximum.
73
+ """
74
+ bases.ensure_dims(bos_space, num_dims=1, name="bos_space")
75
+ bases.ensure_dims(one_space, num_dims=1, name="one_space")
76
+ bos_direction = bos_space.basis[0]
77
+ one_direction = one_space.basis[0]
78
+
79
+ # Add bos direction to query, key, and value spaces in case it is missing
80
+ query_space = bases.join_vector_spaces(query_space, bos_space, one_space)
81
+ key_space = bases.join_vector_spaces(key_space, bos_space)
82
+ value_space = bases.join_vector_spaces(value_space, bos_space)
83
+
84
+ if always_attend_to_bos:
85
+ value_basis = value_space.basis
86
+ else:
87
+ value_basis = [v for v in value_space.basis if v != bos_direction]
88
+ assert len(value_basis) == output_space.num_dims
89
+ value_to_output = dict(zip(value_basis, output_space.basis))
90
+
91
+ if default_output is None:
92
+ default_output = output_space.null_vector()
93
+ assert default_output in output_space
94
+
95
+ def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float:
96
+
97
+ # We want to enforce the following property on our attention patterns:
98
+ # - if nothing else is attended to, attend to the BOS token.
99
+ # - otherwise, don't attend to the BOS token.
100
+ #
101
+ # We assume that the BOS position always only contains the vector bos + one,
102
+ # and that any other position has bos coefficient 0.
103
+ #
104
+ # We do this as follows:
105
+ # Let Q and K be subspaces of V containing the query and key vectors,
106
+ # both disjoint with the BOS space {bos} or the one space {one}.
107
+ # Suppose we have an attn_fn which defines a bilinear W_QK: V x V -> ℝ,
108
+ # s.t. W_QK(q, k) = 0 whenever either q or k are bos or one.
109
+ #
110
+ # Then define W_new: V x V -> ℝ st:
111
+ # W_new(one, bos) = 0.5, otherwise 0.
112
+ #
113
+ # Now set W_QK' = W_QK + W_new.
114
+ #
115
+ # To evaluate the attention to the BOS position:
116
+ # W_QK'(q, bos + one)
117
+ # = W_QK'(q, bos) + W_QK'(q, one)
118
+ # = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one)
119
+ # = 0 + 0 + W_new(q, bos) + W_new(q, one)
120
+ # = W_new(q, bos) + W_new(q, one)
121
+ # = W_new(q' + one, bos) + W_new(q' + one, one) where q = one + q'
122
+ # = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one)
123
+ # = 0 + 0.5 + 0 + 0
124
+ # = 0.5
125
+ #
126
+ # To evaluate the attention to a non-BOS position:
127
+ # W_QK'(0 * bos + q, 0 * bos + k) # s.t. q ∈ Q+{one}, k ∈ K+{one}
128
+ # = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k)
129
+ # = W_QK'(q, 0*bos + k)
130
+ # = 0*W_QK'(q, bos) + W_QK'(q, k)
131
+ # = W_QK'(q, k)
132
+ # = W_QK(q, k) since W_QK' = W_QK on inputs not containing bos.
133
+ # = W_QK(q', k') since W_QK(x, y) = 0 whenever x or y are one.
134
+ #
135
+ # Since W_QK(q, k) takes values in {0, 1}, a sufficiently high softmax
136
+ # coldness will give us the desired property. QED
137
+ #
138
+ # The following implements this idea.
139
+ # By replacing 0.5 with 1, we can instead enforce a different property: that
140
+ # the BOS token is always attended to in addition to whatever else.
141
+
142
+ if key == bos_direction and query == one_direction:
143
+ c = 1. if always_attend_to_bos else 0.5
144
+ return c * softmax_coldness
145
+ elif {key, query}.intersection({one_direction, bos_direction}):
146
+ return 0
147
+
148
+ return softmax_coldness * attn_fn(query, key)
149
+
150
+ w_qk = vectorspace_fns.ScalarBilinear.from_action(
151
+ query_space,
152
+ key_space,
153
+ qk_fun,
154
+ )
155
+
156
+ def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis:
157
+ if use_bos_for_default_output and input_dir == bos_direction:
158
+ return default_output
159
+ return output_space.vector_from_basis_direction(value_to_output[input_dir])
160
+
161
+ w_ov = vectorspace_fns.Linear.from_action(
162
+ value_space,
163
+ output_space,
164
+ ov_fun,
165
+ )
166
+
167
+ return transformers.AttentionHead(w_qk, w_ov, causal=causal)
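
The BOS trick derived in the `qk_fun` comments above can be checked with plain numpy, outside the craft framework. The sketch below is an illustrative toy, not part of the library (the helper name `attn_weights` and the `match_flags` encoding are assumptions made here): a real match scores `softmax_coldness`, the BOS key scores half of that, so a cold softmax attends to BOS only when nothing else matches.

```python
import numpy as np

def attn_weights(match_flags, coldness=100.0, always_attend_to_bos=False):
  """Toy attention row: position 0 is BOS, match_flags[i] is 1.0 if key i+1 matches."""
  bos_score = (1.0 if always_attend_to_bos else 0.5) * coldness
  logits = np.concatenate([[bos_score], coldness * np.asarray(match_flags, float)])
  logits -= logits.max()                    # numerical stability
  weights = np.exp(logits)
  return weights / weights.sum()

print(attn_weights([0, 0, 0]).round(3))   # [1. 0. 0. 0.]   -> BOS wins, default output is used
print(attn_weights([0, 1, 1]).round(3))   # [0. 0. 0.5 0.5] -> matches crowd out BOS
print(attn_weights([0, 1, 1], always_attend_to_bos=True).round(3))  # BOS shares the weight
```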
craft/chamber/categorical_attn_test.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for chamber.categorical_attn."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import numpy as np
20
+ from tracr.craft import bases
21
+ from tracr.craft import tests_common
22
+ from tracr.craft.chamber import categorical_attn
23
+
24
+
25
+ class CategoricalAttnTest(tests_common.VectorFnTestCase):
26
+
27
+ @parameterized.parameters([
28
+ dict(causal=False, input_seq=[1, 2, 3, 4, 5], result_seq=[3, 3, 3, 3, 3]),
29
+ dict(
30
+ causal=True,
31
+ input_seq=[1, 2, 3, 4, 5],
32
+ result_seq=[1, 1.5, 2, 2.5, 3]),
33
+ dict(causal=False, input_seq=[10], result_seq=[10]),
34
+ dict(causal=True, input_seq=[10], result_seq=[10]),
35
+ dict(causal=False, input_seq=[-1, 0, 1], result_seq=[0, 0, 0]),
36
+ dict(causal=True, input_seq=[-1, 0, 1], result_seq=[-1, -0.5, 0]),
37
+ ])
38
+ def test_categorical_attn_can_implement_select_all(self, causal, input_seq,
39
+ result_seq):
40
+ vocab = range(-20, 20)
41
+ input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
42
+
43
+ output_dir = bases.BasisDirection("output")
44
+ output_space = bases.VectorSpaceWithBasis([output_dir])
45
+ output_vec = output_space.vector_from_basis_direction(output_dir)
46
+
47
+ bos_dir = bases.BasisDirection("bos_dimension")
48
+ bos_space = bases.VectorSpaceWithBasis([bos_dir])
49
+
50
+ one_dir = bases.BasisDirection("one")
51
+ one_space = bases.VectorSpaceWithBasis([one_dir])
52
+
53
+ value_dir = bases.BasisDirection("value")
54
+ value_space = bases.VectorSpaceWithBasis([value_dir])
55
+
56
+ input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
57
+ value_space = bases.join_vector_spaces(value_space, bos_space)
58
+ residual_space = bases.join_vector_spaces(input_space, value_space,
59
+ output_space)
60
+ one_vec = residual_space.vector_from_basis_direction(one_dir)
61
+ bos_vec = residual_space.vector_from_basis_direction(bos_dir)
62
+ value_vec = residual_space.vector_from_basis_direction(value_dir)
63
+
64
+ attn = categorical_attn.categorical_attn(
65
+ key_space=input_space,
66
+ query_space=input_space,
67
+ value_space=value_space,
68
+ output_space=output_space,
69
+ bos_space=bos_space,
70
+ one_space=one_space,
71
+ attn_fn=lambda x, y: True,
72
+ causal=causal)
73
+
74
+ test_inputs = [bos_vec + one_vec]
75
+ for x in input_seq:
76
+ test_inputs.append(
77
+ residual_space.vector_from_basis_direction(
78
+ bases.BasisDirection("input", x)) + x * value_vec)
79
+ test_inputs = bases.VectorInBasis.stack(test_inputs)
80
+
81
+ # Expect the average of all (previous) tokens
82
+ expected_results = [x * output_vec for x in result_seq]
83
+ expected_results = bases.VectorInBasis.stack(expected_results)
84
+
85
+ test_outputs = attn.apply(test_inputs).project(output_space)
86
+
87
+ self.assertVectorAllClose(
88
+ tests_common.strip_bos_token(test_outputs), expected_results)
89
+
90
+ @parameterized.parameters([
91
+ dict(causal=False, input_seq=[1, 2, 3, 4, 5], default=0),
92
+ dict(causal=True, input_seq=[1, 2, 3, 4, 5], default=1),
93
+ dict(causal=False, input_seq=[10], default=2),
94
+ dict(causal=True, input_seq=[10], default=-3),
95
+ dict(causal=False, input_seq=[-1, 0, 1], default=-2),
96
+ dict(causal=True, input_seq=[-1, 0, 1], default=-1),
97
+ ])
98
+ def test_categorical_attn_can_implement_select_none(self, causal, input_seq,
99
+ default):
100
+ vocab = range(-20, 20)
101
+ input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
102
+
103
+ output_dir = bases.BasisDirection("output")
104
+ output_space = bases.VectorSpaceWithBasis([output_dir])
105
+ default_vec = default * output_space.vector_from_basis_direction(output_dir)
106
+
107
+ bos_dir = bases.BasisDirection("bos_dimension")
108
+ bos_space = bases.VectorSpaceWithBasis([bos_dir])
109
+
110
+ one_dir = bases.BasisDirection("one")
111
+ one_space = bases.VectorSpaceWithBasis([one_dir])
112
+
113
+ value_dir = bases.BasisDirection("value")
114
+ value_space = bases.VectorSpaceWithBasis([value_dir])
115
+
116
+ input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
117
+ value_space = bases.join_vector_spaces(value_space, bos_space)
118
+ residual_space = bases.join_vector_spaces(input_space, value_space,
119
+ output_space)
120
+ value_vec = residual_space.vector_from_basis_direction(value_dir)
121
+ bos_vec = residual_space.vector_from_basis_direction(bos_dir)
122
+ one_vec = residual_space.vector_from_basis_direction(one_dir)
123
+
124
+ attn = categorical_attn.categorical_attn(
125
+ key_space=input_space,
126
+ query_space=input_space,
127
+ value_space=value_space,
128
+ output_space=output_space,
129
+ bos_space=bos_space,
130
+ one_space=one_space,
131
+ attn_fn=lambda x, y: False,
132
+ default_output=default_vec,
133
+ causal=causal,
134
+ always_attend_to_bos=False,
135
+ use_bos_for_default_output=True)
136
+
137
+ def make_input(x):
138
+ return (one_vec + x * value_vec +
139
+ residual_space.vector_from_basis_direction(
140
+ bases.BasisDirection("input", x)))
141
+
142
+ test_inputs = bases.VectorInBasis.stack([bos_vec + one_vec] +
143
+ [make_input(x) for x in input_seq])
144
+
145
+ # Expect the default value
146
+ expected_results = [default_vec for x in input_seq]
147
+ expected_results = bases.VectorInBasis.stack(expected_results)
148
+
149
+ test_outputs = attn.apply(test_inputs).project(output_space)
150
+
151
+ self.assertVectorAllClose(
152
+ tests_common.strip_bos_token(test_outputs), expected_results)
153
+
154
+ @parameterized.parameters([
155
+ dict(num_counts=5, input_seq=[1, 4, 3, 2], n=1, result=[4, 3, 2, 1]),
156
+ dict(num_counts=10, input_seq=[5, 8, 9, 2], n=3, result=[2, 5, 8, 9])
157
+ ])
158
+ def test_categorical_attn_can_implement_shift_by_n(self, num_counts,
159
+ input_seq, n, result):
160
+ query_prefix = "prefix1"
161
+ key_prefix = "prefix2"
162
+ agg_input_prefix = "prefix3"
163
+ output_prefix = "prefix4"
164
+
165
+ bos_direction = bases.BasisDirection("bos")
166
+ one_direction = bases.BasisDirection("one")
167
+ query_space = bases.VectorSpaceWithBasis.from_values(
168
+ query_prefix, range(num_counts))
169
+ key_space = bases.VectorSpaceWithBasis.from_values(key_prefix,
170
+ range(num_counts))
171
+ bos_space = bases.VectorSpaceWithBasis([bos_direction])
172
+ one_space = bases.VectorSpaceWithBasis([one_direction])
173
+ key_space = bases.join_vector_spaces(key_space, bos_space)
174
+
175
+ agg_input_space = bases.VectorSpaceWithBasis.from_values(
176
+ agg_input_prefix, range(num_counts))
177
+ agg_input_space = bases.join_vector_spaces(agg_input_space, bos_space)
178
+ output_space = bases.VectorSpaceWithBasis.from_values(
179
+ output_prefix, range(num_counts))
180
+
181
+ attn = categorical_attn.categorical_attn(
182
+ query_space=query_space,
183
+ key_space=key_space,
184
+ value_space=agg_input_space,
185
+ output_space=output_space,
186
+ bos_space=bos_space,
187
+ one_space=one_space,
188
+ attn_fn=lambda q, k: q.value == k.value,
189
+ default_output=None,
190
+ always_attend_to_bos=False,
191
+ use_bos_for_default_output=True,
192
+ causal=False)
193
+
194
+ residual_space = bases.join_vector_spaces(key_space, query_space,
195
+ agg_input_space, output_space,
196
+ one_space)
197
+
198
+ seq_len = len(input_seq)
199
+ query_seq = np.arange(n, seq_len + n) % seq_len
200
+ key_seq = np.arange(seq_len)
201
+
202
+ bos_vec = residual_space.vector_from_basis_direction(bos_direction)
203
+ one_vec = residual_space.vector_from_basis_direction(one_direction)
204
+
205
+ test_inputs = [bos_vec + one_vec]
206
+ expected_results = []
207
+ for i in range(seq_len):
208
+ test_inputs.append(
209
+ residual_space.vector_from_basis_direction(
210
+ bases.BasisDirection(query_prefix, query_seq[i])) +
211
+ residual_space.vector_from_basis_direction(
212
+ bases.BasisDirection(key_prefix, key_seq[i])) +
213
+ residual_space.vector_from_basis_direction(
214
+ bases.BasisDirection(agg_input_prefix, input_seq[i])))
215
+ expected_results.append(
216
+ residual_space.vector_from_basis_direction(
217
+ bases.BasisDirection(output_prefix, result[i])))
218
+
219
+ test_inputs = bases.VectorInBasis.stack(test_inputs)
220
+ expected_results = bases.VectorInBasis.stack(expected_results)
221
+
222
+ test_outputs = attn.apply(test_inputs)
223
+
224
+ self.assertVectorAllClose(
225
+ tests_common.strip_bos_token(test_outputs), expected_results)
226
+
227
+
228
+ if __name__ == "__main__":
229
+ absltest.main()
craft/chamber/categorical_mlp.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """MLP to compute basic linear functions of one-hot encoded integers."""
16
+
17
+ from typing import Callable
18
+
19
+ import numpy as np
20
+
21
+ from tracr.craft import bases
22
+ from tracr.craft import transformers
23
+ from tracr.craft import vectorspace_fns
24
+
25
+ _ONE_SPACE = bases.VectorSpaceWithBasis.from_names(["one"])
26
+
27
+
28
+ def map_categorical_mlp(
29
+ input_space: bases.VectorSpaceWithBasis,
30
+ output_space: bases.VectorSpaceWithBasis,
31
+ operation: Callable[[bases.BasisDirection], bases.BasisDirection],
32
+ ) -> transformers.MLP:
33
+ """Returns an MLP that encodes any categorical function of a single variable f(x).
34
+
35
+ The hidden layer is the identity, and the output layer combines this with a lookup table:
36
+ output_k = sum(f(i)*input_i for all i in input space)
37
+
38
+ Args:
39
+ input_space: space containing the input x.
40
+ output_space: space containing possible outputs.
41
+ operation: A function operating on basis directions.
42
+ """
43
+
44
+ def operation_fn(direction):
45
+ if direction in input_space:
46
+ output_direction = operation(direction)
47
+ if output_direction in output_space:
48
+ return output_space.vector_from_basis_direction(output_direction)
49
+ return output_space.null_vector()
50
+
51
+ first_layer = vectorspace_fns.Linear.from_action(input_space, output_space,
52
+ operation_fn)
53
+
54
+ second_layer = vectorspace_fns.project(output_space, output_space)
55
+
56
+ return transformers.MLP(first_layer, second_layer)
57
+
58
+
59
+ def map_categorical_to_numerical_mlp(
60
+ input_space: bases.VectorSpaceWithBasis,
61
+ output_space: bases.VectorSpaceWithBasis,
62
+ operation: Callable[[bases.Value], float],
63
+ ) -> transformers.MLP:
64
+ """Returns an MLP to compute f(x) from a categorical to a numerical variable.
65
+
66
+ The hidden layer is the identity, and the output layer combines this with a lookup table:
67
+ output = sum(f(i)*input_i for all i in input space)
68
+
69
+ Args:
70
+ input_space: Vector space containing the input x.
71
+ output_space: Vector space to write the numerical output to.
72
+ operation: A function operating on basis directions.
73
+ """
74
+ bases.ensure_dims(output_space, num_dims=1, name="output_space")
75
+ out_vec = output_space.vector_from_basis_direction(output_space.basis[0])
76
+
77
+ def operation_fn(direction):
78
+ if direction in input_space:
79
+ return operation(direction.value) * out_vec
80
+ return output_space.null_vector()
81
+
82
+ first_layer = vectorspace_fns.Linear.from_action(input_space, output_space,
83
+ operation_fn)
84
+
85
+ second_layer = vectorspace_fns.project(output_space, output_space)
86
+
87
+ return transformers.MLP(first_layer, second_layer)
88
+
89
+
90
+ def sequence_map_categorical_mlp(
91
+ input1_space: bases.VectorSpaceWithBasis,
92
+ input2_space: bases.VectorSpaceWithBasis,
93
+ output_space: bases.VectorSpaceWithBasis,
94
+ operation: Callable[[bases.BasisDirection, bases.BasisDirection],
95
+ bases.BasisDirection],
96
+ one_space: bases.VectorSpaceWithBasis = _ONE_SPACE,
97
+ hidden_name: bases.Name = "__hidden__",
98
+ ) -> transformers.MLP:
99
+ """Returns an MLP that encodes a categorical function of two variables f(x, y).
100
+
101
+ The hidden layer of the MLP computes the logical AND of all pairs of input directions:
102
+ hidden_i_j = ReLU(x_i + y_j - 1)
103
+
104
+ And the output combines this with a lookup table
105
+ output_k = sum(f(i, j)*hidden_i_j for all i,j in input space)
106
+
107
+ Args:
108
+ input1_space: Vector space containing the input x.
109
+ input2_space: Vector space containing the input y.
110
+ output_space: Vector space to write outputs to.
111
+ operation: A function operating on basis directions.
112
+ one_space: a reserved 1-d space that always contains a 1.
113
+ hidden_name: Name for hidden dimensions.
114
+ """
115
+ bases.ensure_dims(one_space, num_dims=1, name="one_space")
116
+
117
+ if not set(input1_space.basis).isdisjoint(input2_space.basis):
118
+ raise ValueError("Input spaces to a SequenceMap must be disjoint. "
119
+ "If input spaces are the same, use Map instead!")
120
+
121
+ input_space = bases.direct_sum(input1_space, input2_space, one_space)
122
+
123
+ def to_hidden(x, y):
124
+ return bases.BasisDirection(hidden_name, (x.name, x.value, y.name, y.value))
125
+
126
+ def from_hidden(h):
127
+ x_name, x_value, y_name, y_value = h.value
128
+ x_dir = bases.BasisDirection(x_name, x_value)
129
+ y_dir = bases.BasisDirection(y_name, y_value)
130
+ return x_dir, y_dir
131
+
132
+ hidden_dir = []
133
+ for dir1 in input1_space.basis:
134
+ for dir2 in input2_space.basis:
135
+ hidden_dir.append(to_hidden(dir1, dir2))
136
+ hidden_space = bases.VectorSpaceWithBasis(hidden_dir)
137
+
138
+ def logical_and(direction):
139
+ if direction in one_space:
140
+ out = bases.VectorInBasis(hidden_space.basis,
141
+ -np.ones(hidden_space.num_dims))
142
+ elif direction in input1_space:
143
+ dir1 = direction
144
+ out = hidden_space.null_vector()
145
+ for dir2 in input2_space.basis:
146
+ out += hidden_space.vector_from_basis_direction(to_hidden(dir1, dir2))
147
+ else:
148
+ dir2 = direction
149
+ out = hidden_space.null_vector()
150
+ for dir1 in input1_space.basis:
151
+ out += hidden_space.vector_from_basis_direction(to_hidden(dir1, dir2))
152
+ return out
153
+
154
+ first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
155
+ logical_and)
156
+
157
+ def operation_fn(direction):
158
+ dir1, dir2 = from_hidden(direction)
159
+ output_direction = operation(dir1, dir2)
160
+ if output_direction in output_space:
161
+ return output_space.vector_from_basis_direction(output_direction)
162
+ else:
163
+ return output_space.null_vector()
164
+
165
+ second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
166
+ operation_fn)
167
+
168
+ return transformers.MLP(first_layer, second_layer)
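
The two-layer construction in `sequence_map_categorical_mlp` can be replayed on plain one-hot vectors. The following is a minimal numpy sketch of the same idea (the names `seq_map_lookup` and `table` are illustrative assumptions, not library API): the hidden unit for the pair (i, j) fires only when both x_i and y_j are active, and the second layer reads the answer out of a lookup table.

```python
import numpy as np

def seq_map_lookup(x_idx, y_idx, table):
  num_x, num_y = table.shape
  x = np.eye(num_x)[x_idx]                  # one-hot encoding of x
  y = np.eye(num_y)[y_idx]                  # one-hot encoding of y
  # Hidden layer: ReLU(x_i + y_j - 1) is 1 exactly when x_i AND y_j are both 1.
  # The constant -1 is what the reserved one_space dimension contributes.
  hidden = np.maximum(x[:, None] + y[None, :] - 1.0, 0.0)
  # Output layer: lookup table over the single active (i, j) pair.
  return (table * hidden).sum()

table = np.add.outer(np.arange(4), np.arange(4))   # f(i, j) = i + j
print(seq_map_lookup(1, 2, table))                 # 3.0
```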
craft/chamber/categorical_mlp_test.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for chamber.categorical_mlp."""
16
+
17
+ import math
18
+ from absl.testing import absltest
19
+ from absl.testing import parameterized
20
+
21
+ from tracr.craft import bases
22
+ from tracr.craft import tests_common
23
+ from tracr.craft.chamber import categorical_mlp
24
+
25
+
26
+ class CategoricalInputMlpTest(tests_common.VectorFnTestCase):
27
+
28
+ @parameterized.parameters([
29
+ dict(num_counts=4, x=1, y=2, fun=lambda x, y: x + y, result=3),
30
+ dict(num_counts=4, x=1, y=0, fun=lambda x, y: x + y + 1, result=2),
31
+ dict(num_counts=5, x=2, y=1, fun=math.pow, result=2),
32
+ dict(num_counts=5, x=2, y=2, fun=math.pow, result=4),
33
+ ])
34
+ def test_seq_map_categorical_mlp_produces_expected_outcome(
35
+ self, num_counts, x, y, fun, result):
36
+ input1_name = "in1"
37
+ input2_name = "in2"
38
+ output_name = "out"
39
+ one_name = "one_dimension"
40
+
41
+ in1_space = bases.VectorSpaceWithBasis.from_values(input1_name,
42
+ range(num_counts + 1))
43
+ in2_space = bases.VectorSpaceWithBasis.from_values(input2_name,
44
+ range(num_counts + 1))
45
+ out_space = bases.VectorSpaceWithBasis.from_values(output_name,
46
+ range(num_counts + 1))
47
+
48
+ def operation(in1, in2):
49
+ out_val = fun(int(in1.value), int(in2.value))
50
+ return bases.BasisDirection(output_name, out_val)
51
+
52
+ mlp = categorical_mlp.sequence_map_categorical_mlp(
53
+ input1_space=in1_space,
54
+ input2_space=in2_space,
55
+ output_space=out_space,
56
+ operation=operation,
57
+ one_space=bases.VectorSpaceWithBasis.from_names([one_name]))
58
+
59
+ test_inputs = (
60
+ mlp.residual_space.vector_from_basis_direction(
61
+ bases.BasisDirection(one_name)) +
62
+ mlp.residual_space.vector_from_basis_direction(
63
+ bases.BasisDirection(input1_name, x)) +
64
+ mlp.residual_space.vector_from_basis_direction(
65
+ bases.BasisDirection(input2_name, y)))
66
+
67
+ expected_results = mlp.residual_space.vector_from_basis_direction(
68
+ bases.BasisDirection(output_name, result))
69
+
70
+ test_outputs = mlp.apply(test_inputs)
71
+
72
+ self.assertVectorAllClose(test_outputs, expected_results)
73
+
74
+ def test_seq_map_categorical_mlp_raises_error_with_overlapping_inputs(self):
75
+ input_name = "in"
76
+ output_name = "out"
77
+ one_name = "one_dimension"
78
+
79
+ in1_space = bases.VectorSpaceWithBasis.from_values(input_name, range(5))
80
+ in2_space = bases.VectorSpaceWithBasis.from_values(input_name, range(3, 10))
81
+ out_space = bases.VectorSpaceWithBasis.from_values(output_name, range(5))
82
+
83
+ with self.assertRaisesRegex(
84
+ ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"):
85
+ categorical_mlp.sequence_map_categorical_mlp(
86
+ input1_space=in1_space,
87
+ input2_space=in1_space,
88
+ output_space=out_space,
89
+ operation=lambda x, y: bases.BasisDirection(output_name, 0),
90
+ one_space=bases.VectorSpaceWithBasis.from_names([one_name]))
91
+
92
+ with self.assertRaisesRegex(
93
+ ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"):
94
+ categorical_mlp.sequence_map_categorical_mlp(
95
+ input1_space=in1_space,
96
+ input2_space=in2_space,
97
+ output_space=out_space,
98
+ operation=lambda x, y: bases.BasisDirection(output_name, 0),
99
+ one_space=bases.VectorSpaceWithBasis.from_names([one_name]))
100
+
101
+ @parameterized.parameters([
102
+ dict(num_counts=5, x=2, fun=lambda x: x, result=2),
103
+ dict(num_counts=5, x=2, fun=lambda x: math.pow(x, int(2)), result=4),
104
+ dict(num_counts=5, x=-2, fun=lambda x: math.pow(x, int(2)), result=4),
105
+ dict(num_counts=5, x=-1, fun=lambda x: math.pow(x, int(3)), result=-1),
106
+ ])
107
+ def test_map_categorical_mlp_produces_expected_outcome_computing_powers(
108
+ self, num_counts, x, fun, result):
109
+ input_name = "in"
110
+ output_name = "out"
111
+
112
+ in_space = bases.VectorSpaceWithBasis.from_values(
113
+ input_name, range(-num_counts, num_counts + 1))
114
+ out_space = bases.VectorSpaceWithBasis.from_values(
115
+ output_name, range(-num_counts, num_counts + 1))
116
+
117
+ def operation(direction):
118
+ out_val = fun(int(direction.value))
119
+ return bases.BasisDirection(output_name, out_val)
120
+
121
+ mlp = categorical_mlp.map_categorical_mlp(
122
+ input_space=in_space, output_space=out_space, operation=operation)
123
+
124
+ test_inputs = mlp.residual_space.vector_from_basis_direction(
125
+ bases.BasisDirection(input_name, x))
126
+
127
+ expected_results = mlp.residual_space.vector_from_basis_direction(
128
+ bases.BasisDirection(output_name, result))
129
+
130
+ test_outputs = mlp.apply(test_inputs)
131
+
132
+ self.assertVectorAllClose(test_outputs, expected_results)
133
+
134
+ @parameterized.parameters([
135
+ dict(x=2, fun=lambda x: x, result=2),
136
+ dict(x=2, fun=lambda x: math.pow(x, int(2)), result=4),
137
+ dict(x=1, fun=lambda x: 1 / (x + 1), result=0.5),
138
+ dict(x=3, fun=lambda x: 1 / (x + 1), result=0.25),
139
+ ])
140
+ def test_map_categorical_to_numerical_mlp_produces_expected_outcome(
141
+ self, x, fun, result):
142
+
143
+ in_space = bases.VectorSpaceWithBasis.from_values("in", range(6))
144
+ out_space = bases.VectorSpaceWithBasis.from_names(["out"])
145
+
146
+ mlp = categorical_mlp.map_categorical_to_numerical_mlp(
147
+ input_space=in_space,
148
+ output_space=out_space,
149
+ operation=fun,
150
+ )
151
+
152
+ test_inputs = mlp.residual_space.vector_from_basis_direction(
153
+ bases.BasisDirection("in", x))
154
+
155
+ expected_results = result * mlp.residual_space.vector_from_basis_direction(
156
+ bases.BasisDirection("out"))
157
+
158
+ test_outputs = mlp.apply(test_inputs)
159
+
160
+ self.assertVectorAllClose(test_outputs, expected_results)
161
+
162
+
163
+ if __name__ == "__main__":
164
+ absltest.main()
craft/chamber/numerical_mlp.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """MLPs to compute arbitrary numerical functions by discretising."""
16
+
17
+ import dataclasses
18
+
19
+ from typing import Callable, Iterable
20
+
21
+ from tracr.craft import bases
22
+ from tracr.craft import transformers
23
+ from tracr.craft import vectorspace_fns
24
+ from tracr.utils import errors
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class DiscretisingLayerMaterials:
29
+ """Provides components for a hidden layer that discretises the input.
30
+
31
+ Attributes:
32
+ action: Function acting on basis directions that defines the computation.
33
+ hidden_space: Vector space of the hidden representation of the layer.
34
+ output_values: Set of output values that correspond to the discretisation.
35
+ """
36
+ action: Callable[[bases.BasisDirection], bases.VectorInBasis]
37
+ hidden_space: bases.VectorSpaceWithBasis
38
+ output_values: list[float]
39
+
40
+
41
+ def _get_discretising_layer(input_value_set: Iterable[float],
42
+ f: Callable[[float],
43
+ float], hidden_name: bases.Name,
44
+ one_direction: bases.BasisDirection,
45
+ large_number: float) -> DiscretisingLayerMaterials:
46
+ """Creates a hidden layer that discretises the input of f(x) into a value set.
47
+
48
+ The input is split up into a distinct region around each value in
49
+ `input_value_set`:
50
+
51
+ elements of value set: v0 | v1 | v2 | v3 | v4 | ...
52
+ thresholds: t0 t1 t2 t3 t4
53
+
54
+ The hidden layer has two activations per threshold:
55
+ hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
56
+ hidden_k_1 = ReLU(L * (x - threshold[k]))
57
+
58
+ Note that hidden_k_0 - hidden_k_1 is:
59
+ 1 if x >= threshold[k] + 1/L
60
+ 0 if x <= threshold[k]
61
+ between 0 and 1 if threshold[k] < x < threshold[k] + 1/L
62
+
63
+ So as long as we choose L a big enough number, we have
64
+ hidden_k_0 - hidden_k_1 = 1 if x >= threshold[k].
65
+ i.e. we know in which region the input value is.
66
+
67
+ Args:
68
+ input_value_set: Set of discrete input values.
69
+ f: Function to approximate.
70
+ hidden_name: Name for hidden dimensions.
71
+ one_direction: Auxiliary dimension that must contain 1 in the input.
72
+ large_number: Large number L that determines accuracy of the computation.
73
+
74
+ Returns:
75
+ DiscretisingLayerMaterials containing all components for the layer.
76
+ """
77
+ output_values, sorted_values = [], []
78
+ for x in sorted(input_value_set):
79
+ res = errors.ignoring_arithmetic_errors(f)(x)
80
+ if res is not None:
81
+ output_values.append(res)
82
+ sorted_values.append(x)
83
+
84
+ num_vals = len(sorted_values)
85
+ value_thresholds = [
86
+ (sorted_values[i] + sorted_values[i + 1]) / 2 for i in range(num_vals - 1)
87
+ ]
88
+
89
+ hidden_directions = [bases.BasisDirection(f"{hidden_name}start")]
90
+ for k in range(1, num_vals):
91
+ dir0 = bases.BasisDirection(hidden_name, (k, 0))
92
+ dir1 = bases.BasisDirection(hidden_name, (k, 1))
93
+ hidden_directions.extend([dir0, dir1])
94
+ hidden_space = bases.VectorSpaceWithBasis(hidden_directions)
95
+
96
+ def action(direction: bases.BasisDirection) -> bases.VectorInBasis:
97
+ # hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
98
+ # hidden_k_1 = ReLU(L * (x - threshold[k]))
99
+ if direction == one_direction:
100
+ hidden = hidden_space.vector_from_basis_direction(
101
+ bases.BasisDirection(f"{hidden_name}start"))
102
+ else:
103
+ hidden = hidden_space.null_vector()
104
+ for k in range(1, num_vals):
105
+ vec0 = hidden_space.vector_from_basis_direction(
106
+ bases.BasisDirection(hidden_name, (k, 0)))
107
+ vec1 = hidden_space.vector_from_basis_direction(
108
+ bases.BasisDirection(hidden_name, (k, 1)))
109
+ if direction == one_direction:
110
+ hidden += (1 - large_number * value_thresholds[k - 1]) * vec0
111
+ hidden -= large_number * value_thresholds[k - 1] * vec1
112
+ else:
113
+ hidden += large_number * vec0 + large_number * vec1
114
+ return hidden
115
+
116
+ return DiscretisingLayerMaterials(
117
+ action=action, hidden_space=hidden_space, output_values=output_values)
118
+
119
+
120
+ def map_numerical_mlp(
121
+ f: Callable[[float], float],
122
+ input_space: bases.VectorSpaceWithBasis,
123
+ output_space: bases.VectorSpaceWithBasis,
124
+ input_value_set: Iterable[float],
125
+ one_space: bases.VectorSpaceWithBasis,
126
+ large_number: float = 100,
127
+ hidden_name: bases.Name = "__hidden__",
128
+ ) -> transformers.MLP:
129
+ """Returns an MLP that encodes any function of a single variable f(x).
130
+
131
+ This is implemented by discretising the input according to input_value_set
132
+ and defining thresholds that determine which part of the input range
133
+ is allocated to which value in input_value_set.
134
+
135
+ elements of value set: v0 | v1 | v2 | v3 | v4 | ...
136
+ thresholds: t0 t1 t2 t3 t4
137
+
138
+ The MLP computes two hidden activations per threshold:
139
+ hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
140
+ hidden_k_1 = ReLU(L * (x - threshold[k]))
141
+
142
+ Note that hidden_k_0 - hidden_k_1 is:
143
+ 1 if x >= threshold[k] + 1/L
144
+ 0 if x <= threshold[k]
145
+ between 0 and 1 if threshold[k] < x < threshold[k] + 1/L
146
+
147
+ So as long as we choose L a big enough number, we have
148
+ hidden_k_0 - hidden_k_1 = 1 if x >= threshold[k].
149
+
150
+ The MLP then computes the output as:
151
+ output = f(input[0]) +
152
+ sum((hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1]))
153
+ for all k=0,1,...)
154
+
155
+ This sum will be (by a telescoping sums argument)
156
+ f(input[0]) if x <= threshold[0]
157
+ f(input[k]) if threshold[k-1] < x <= threshold[k] for some other k
158
+ f(input[-1]) if x > threshold[-1]
159
+ which approximates f() up to an accuracy given by input_value_set and L.
160
+
161
+ Args:
162
+ f: Function to approximate.
163
+ input_space: 1-d vector space that encodes the input x.
164
+ output_space: 1-d vector space to write the output to.
165
+ input_value_set: Set of values the input can take.
166
+ one_space: Auxiliary 1-d vector space that must contain 1 in the input.
167
+ large_number: Large number L that determines accuracy of the computation.
168
+ Note that too large values of L can lead to numerical issues, particularly
169
+ during inference on GPU/TPU.
170
+ hidden_name: Name for hidden dimensions.
171
+ """
172
+ bases.ensure_dims(input_space, num_dims=1, name="input_space")
173
+ bases.ensure_dims(output_space, num_dims=1, name="output_space")
174
+ bases.ensure_dims(one_space, num_dims=1, name="one_space")
175
+
176
+ input_space = bases.join_vector_spaces(input_space, one_space)
177
+ out_vec = output_space.vector_from_basis_direction(output_space.basis[0])
178
+
179
+ discretising_layer = _get_discretising_layer(
180
+ input_value_set=input_value_set,
181
+ f=f,
182
+ hidden_name=hidden_name,
183
+ one_direction=one_space.basis[0],
184
+ large_number=large_number)
185
+ first_layer = vectorspace_fns.Linear.from_action(
186
+ input_space, discretising_layer.hidden_space, discretising_layer.action)
187
+
188
+ def second_layer_action(
189
+ direction: bases.BasisDirection) -> bases.VectorInBasis:
190
+ # output = sum(
191
+ # (hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1]))
192
+ # for all k)
193
+ if direction.name == f"{hidden_name}start":
194
+ return discretising_layer.output_values[0] * out_vec
195
+ k, i = direction.value
196
+ # add hidden_k_0 and subtract hidden_k_1
197
+ sign = {0: 1, 1: -1}[i]
198
+ return sign * (discretising_layer.output_values[k] -
199
+ discretising_layer.output_values[k - 1]) * out_vec
200
+
201
+ second_layer = vectorspace_fns.Linear.from_action(
202
+ discretising_layer.hidden_space, output_space, second_layer_action)
203
+
204
+ return transformers.MLP(first_layer, second_layer)
205
+
206
+
207
+ def map_numerical_to_categorical_mlp(
208
+ f: Callable[[float], float],
209
+ input_space: bases.VectorSpaceWithBasis,
210
+ output_space: bases.VectorSpaceWithBasis,
211
+ input_value_set: Iterable[float],
212
+ one_space: bases.VectorSpaceWithBasis,
213
+ large_number: float = 100,
214
+ hidden_name: bases.Name = "__hidden__",
215
+ ) -> transformers.MLP:
216
+ """Returns an MLP to compute f(x) from a numerical to a categorical variable.
217
+
218
+ Uses a set of possible output values, and rounds f(x) to the closest value
219
+ in this set to create a categorical output variable.
220
+
221
+ The output is discretised the same way as in `map_numerical_mlp`.
222
+
223
+ Args:
224
+ f: Function to approximate.
225
+ input_space: 1-d vector space that encodes the input x.
226
+ output_space: n-d vector space to write categorical output to. The output
227
+ directions need to encode the possible output values.
228
+ input_value_set: Set of values the input can take.
229
+ one_space: Auxiliary 1-d space that must contain 1 in the input.
230
+ large_number: Large number L that determines accuracy of the computation.
231
+ hidden_name: Name for hidden dimensions.
232
+ """
233
+ bases.ensure_dims(input_space, num_dims=1, name="input_space")
234
+ bases.ensure_dims(one_space, num_dims=1, name="one_space")
235
+
236
+ input_space = bases.join_vector_spaces(input_space, one_space)
237
+
238
+ vec_by_out_val = dict()
239
+ for d in output_space.basis:
240
+ # TODO(b/255937603): Do a similar assert in other places where we expect
241
+ # categorical basis directions to encode values.
242
+ assert d.value is not None, ("output directions need to encode "
243
+ "possible output values")
244
+ vec_by_out_val[d.value] = output_space.vector_from_basis_direction(d)
245
+
246
+ discretising_layer = _get_discretising_layer(
247
+ input_value_set=input_value_set,
248
+ f=f,
249
+ hidden_name=hidden_name,
250
+ one_direction=one_space.basis[0],
251
+ large_number=large_number)
252
+
253
+ assert set(discretising_layer.output_values).issubset(
254
+ set(vec_by_out_val.keys()))
255
+
256
+ first_layer = vectorspace_fns.Linear.from_action(
257
+ input_space, discretising_layer.hidden_space, discretising_layer.action)
258
+
259
+ def second_layer_action(
260
+ direction: bases.BasisDirection) -> bases.VectorInBasis:
261
+ """Computes output value and returns corresponding output direction."""
262
+ if direction.name == f"{hidden_name}start":
263
+ return vec_by_out_val[discretising_layer.output_values[0]]
264
+ else:
265
+ k, i = direction.value
266
+ # add hidden_k_0 and subtract hidden_k_1
267
+ sign = {0: 1, 1: -1}[i]
268
+ out_k = discretising_layer.output_values[k]
269
+ out_k_m_1 = discretising_layer.output_values[k - 1]
270
+ return sign * (vec_by_out_val[out_k] - vec_by_out_val[out_k_m_1])
271
+
272
+ second_layer = vectorspace_fns.Linear.from_action(
273
+ discretising_layer.hidden_space, output_space, second_layer_action)
274
+
275
+ return transformers.MLP(first_layer, second_layer)
276
+
277
+
278
+ def linear_sequence_map_numerical_mlp(
279
+ input1_basis_direction: bases.BasisDirection,
280
+ input2_basis_direction: bases.BasisDirection,
281
+ output_basis_direction: bases.BasisDirection,
282
+ input1_factor: float,
283
+ input2_factor: float,
284
+ hidden_name: bases.Name = "__hidden__",
285
+ ) -> transformers.MLP:
286
+ """Returns an MLP that encodes a linear function f(x, y) = a*x + b*y.
287
+
288
+ Args:
289
+ input1_basis_direction: Basis direction that encodes the input x.
290
+ input2_basis_direction: Basis direction that encodes the input y.
291
+ output_basis_direction: Basis direction to write the output to.
292
+ input1_factor: Linear factor a for input x.
293
+ input2_factor: Linear factor b for input y.
294
+ hidden_name: Name for hidden dimensions.
295
+ """
296
+ input_space = bases.VectorSpaceWithBasis(
297
+ [input1_basis_direction, input2_basis_direction])
298
+ output_space = bases.VectorSpaceWithBasis([output_basis_direction])
299
+ out_vec = output_space.vector_from_basis_direction(output_basis_direction)
300
+
301
+ hidden_directions = [
302
+ bases.BasisDirection(f"{hidden_name}x", 1),
303
+ bases.BasisDirection(f"{hidden_name}x", -1),
304
+ bases.BasisDirection(f"{hidden_name}y", 1),
305
+ bases.BasisDirection(f"{hidden_name}y", -1)
306
+ ]
307
+ hidden_space = bases.VectorSpaceWithBasis(hidden_directions)
308
+ x_pos_vec, x_neg_vec, y_pos_vec, y_neg_vec = (
309
+ hidden_space.vector_from_basis_direction(d) for d in hidden_directions)
310
+
311
+ def first_layer_action(
312
+ direction: bases.BasisDirection) -> bases.VectorInBasis:
313
+ output = hidden_space.null_vector()
314
+ if direction == input1_basis_direction:
315
+ output += x_pos_vec - x_neg_vec
316
+ if direction == input2_basis_direction:
317
+ output += y_pos_vec - y_neg_vec
318
+ return output
319
+
320
+ first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
321
+ first_layer_action)
322
+
323
+ def second_layer_action(
324
+ direction: bases.BasisDirection) -> bases.VectorInBasis:
325
+ if direction.name == f"{hidden_name}x":
326
+ return input1_factor * direction.value * out_vec
327
+ if direction.name == f"{hidden_name}y":
328
+ return input2_factor * direction.value * out_vec
329
+ return output_space.null_vector()
330
+
331
+ second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
332
+ second_layer_action)
333
+
334
+ return transformers.MLP(first_layer, second_layer)
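
The discretise-and-telescope construction behind `map_numerical_mlp` is easy to reproduce with plain numpy. Below is a minimal sketch under the same assumptions (the function name `approx` is made up for illustration): two ReLUs per threshold form a step function, and a telescoping sum of output differences turns the steps into f evaluated at the nearest discretised input.

```python
import numpy as np

def approx(f, input_value_set, x, large_number=100.0):
  values = sorted(input_value_set)
  outputs = np.array([f(v) for v in values])
  thresholds = np.array([(a + b) / 2 for a, b in zip(values, values[1:])])
  # Two ReLU activations per threshold; their difference is ~1 above the
  # threshold and ~0 below it (a ramp of width 1/large_number in between).
  h0 = np.maximum(large_number * (x - thresholds) + 1.0, 0.0)
  h1 = np.maximum(large_number * (x - thresholds), 0.0)
  steps = h0 - h1
  # Telescoping sum: start at f(values[0]) and add the jumps f(v_k) - f(v_{k-1}).
  return outputs[0] + np.sum(steps * np.diff(outputs))

print(approx(lambda v: v**2, {-2, -1, 0, 1, 2, 3}, x=2))     # ~4.0
print(approx(lambda v: v**2, {-2, -1, 0, 1, 2, 3}, x=-1.9))  # ~4.0 (snaps to v=-2)
```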
craft/chamber/numerical_mlp_test.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for chamber.numerical_mlp."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import numpy as np
20
+ from tracr.craft import bases
21
+ from tracr.craft import tests_common
22
+ from tracr.craft.chamber import numerical_mlp
23
+ from tracr.utils import errors
24
+
25
+
26
+ class NumericalMlpTest(tests_common.VectorFnTestCase):
27
+
28
+ @parameterized.parameters([
29
+ dict(
30
+ in_value_set={-2, -2, -1, 0, 1, 2, 3},
31
+ x=2,
32
+ function=lambda x: x,
33
+ result=2),
34
+ dict(
35
+ in_value_set={-2, -2, -1, 0, 1, 2, 3},
36
+ x=2,
37
+ function=lambda x: x**2,
38
+ result=4),
39
+ dict(
40
+ in_value_set={-2, -2, -1, 0, 1, 2, 3},
41
+ x=2,
42
+ function=lambda x: x**3,
43
+ result=8),
44
+ dict(
45
+ in_value_set={-2, -2, -1, 0, 1, 2, 3},
46
+ x=-2,
47
+ function=lambda x: x,
48
+ result=-2),
49
+ dict(
50
+ in_value_set={-2, -2, -1, 0, 1, 2, 3},
51
+ x=-2,
52
+ function=lambda x: x**2,
53
+ result=4),
54
+ dict(
55
+ in_value_set={-2, -2, -1, 0, 1, 2, 3},
56
+ x=-2,
57
+ function=lambda x: x**3,
58
+ result=-8),
59
+ ])
60
+ def test_map_numerical_mlp_produces_expected_outcome(self, in_value_set, x,
61
+ function, result):
62
+
63
+ input_dir = bases.BasisDirection("input")
64
+ output_dir = bases.BasisDirection("output")
65
+ one_dir = bases.BasisDirection("one")
66
+ input_space = bases.VectorSpaceWithBasis([input_dir])
67
+ output_space = bases.VectorSpaceWithBasis([output_dir])
68
+ one_space = bases.VectorSpaceWithBasis([one_dir])
69
+
70
+ mlp = numerical_mlp.map_numerical_mlp(
71
+ f=function,
72
+ input_space=input_space,
73
+ output_space=output_space,
74
+ one_space=one_space,
75
+ input_value_set=in_value_set,
76
+ )
77
+
78
+ test_inputs = bases.VectorInBasis(
79
+ basis_directions=[input_dir, output_dir, one_dir],
80
+ magnitudes=np.array([x, 0, 1]))
81
+
82
+ expected_results = bases.VectorInBasis(
83
+ basis_directions=[input_dir, output_dir, one_dir],
84
+ magnitudes=np.array([0, result, 0]))
85
+
86
+ test_outputs = mlp.apply(test_inputs)
87
+
88
+ self.assertVectorAllClose(test_outputs, expected_results)
89
+
90
+ @parameterized.parameters([
91
+ dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1),
92
+ dict(
93
+ in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5),
94
+ dict(
95
+ in_value_set={0, 1, 2, 3},
96
+ x=3,
97
+ function=lambda x: 1 / x,
98
+ result=1 / 3),
99
+ ])
100
+ def test_map_numerical_mlp_logs_warning_and_produces_expected_outcome(
101
+ self, in_value_set, x, function, result):
102
+
103
+ input_dir = bases.BasisDirection("input")
104
+ output_dir = bases.BasisDirection("output")
105
+ one_dir = bases.BasisDirection("one")
106
+ input_space = bases.VectorSpaceWithBasis([input_dir])
107
+ output_space = bases.VectorSpaceWithBasis([output_dir])
108
+ one_space = bases.VectorSpaceWithBasis([one_dir])
109
+
110
+ with self.assertLogs(level="WARNING"):
111
+ mlp = numerical_mlp.map_numerical_mlp(
112
+ f=function,
113
+ input_space=input_space,
114
+ output_space=output_space,
115
+ one_space=one_space,
116
+ input_value_set=in_value_set,
117
+ )
118
+
119
+ test_inputs = bases.VectorInBasis(
120
+ basis_directions=[input_dir, output_dir, one_dir],
121
+ magnitudes=np.array([x, 0, 1]))
122
+
123
+ expected_results = bases.VectorInBasis(
124
+ basis_directions=[input_dir, output_dir, one_dir],
125
+ magnitudes=np.array([0, result, 0]))
126
+
127
+ test_outputs = mlp.apply(test_inputs)
128
+
129
+ self.assertVectorAllClose(test_outputs, expected_results)
130
+
131
+ @parameterized.parameters([
132
+ dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1),
133
+ dict(
134
+ in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5),
135
+ dict(
136
+ in_value_set={0, 1, 2, 3},
137
+ x=3,
138
+ function=lambda x: 1 / x,
139
+ result=1 / 3),
140
+ ])
141
+ def test_map_numerical_to_categorical_mlp_logs_warning_and_produces_expected_outcome(
142
+ self, in_value_set, x, function, result):
143
+
144
+ f_ign = errors.ignoring_arithmetic_errors(function)
145
+ out_value_set = {f_ign(x) for x in in_value_set if f_ign(x) is not None}
146
+
147
+ in_space = bases.VectorSpaceWithBasis.from_names(["input"])
148
+ out_space = bases.VectorSpaceWithBasis.from_values("output", out_value_set)
149
+ one_space = bases.VectorSpaceWithBasis.from_names(["one"])
150
+
151
+ residual_space = bases.join_vector_spaces(in_space, one_space, out_space)
152
+ in_vec = residual_space.vector_from_basis_direction(in_space.basis[0])
153
+ one_vec = residual_space.vector_from_basis_direction(one_space.basis[0])
154
+
155
+ with self.assertLogs(level="WARNING"):
156
+ mlp = numerical_mlp.map_numerical_to_categorical_mlp(
157
+ f=function,
158
+ input_space=in_space,
159
+ output_space=out_space,
160
+ input_value_set=in_value_set,
161
+ one_space=one_space,
162
+ )
163
+
164
+ test_inputs = x * in_vec + one_vec
165
+ expected_results = out_space.vector_from_basis_direction(
166
+ bases.BasisDirection("output", result))
167
+ test_outputs = mlp.apply(test_inputs).project(out_space)
168
+ self.assertVectorAllClose(test_outputs, expected_results)
169
+
170
+ @parameterized.parameters([
171
+ dict(x_factor=1, y_factor=2, x=1, y=1, result=3),
172
+ dict(x_factor=1, y_factor=2, x=1, y=-1, result=-1),
173
+ dict(x_factor=1, y_factor=-1, x=1, y=1, result=0),
174
+ dict(x_factor=1, y_factor=1, x=3, y=5, result=8),
175
+ dict(x_factor=-2, y_factor=-0.5, x=4, y=1, result=-8.5),
176
+ ])
177
+ def test_linear_sequence_map_produces_expected_result(self, x_factor,
178
+ y_factor, x, y, result):
179
+
180
+ input1_dir = bases.BasisDirection("input1")
181
+ input2_dir = bases.BasisDirection("input2")
182
+ output_dir = bases.BasisDirection("output")
183
+
184
+ mlp = numerical_mlp.linear_sequence_map_numerical_mlp(
185
+ input1_basis_direction=input1_dir,
186
+ input2_basis_direction=input2_dir,
187
+ output_basis_direction=output_dir,
188
+ input1_factor=x_factor,
189
+ input2_factor=y_factor)
190
+
191
+ test_inputs = bases.VectorInBasis(
192
+ basis_directions=[input1_dir, input2_dir, output_dir],
193
+ magnitudes=np.array([x, y, 0]))
194
+
195
+ expected_results = bases.VectorInBasis(
196
+ basis_directions=[input1_dir, input2_dir, output_dir],
197
+ magnitudes=np.array([0, 0, result]))
198
+
199
+ test_outputs = mlp.apply(test_inputs)
200
+
201
+ self.assertVectorAllClose(test_outputs, expected_results)
202
+
203
+ @parameterized.parameters([
204
+ dict(x_factor=1, y_factor=2, x=1, result=3),
205
+ dict(x_factor=1, y_factor=-1, x=1, result=0),
206
+ ])
207
+ def test_linear_sequence_map_produces_expected_result_with_same_inputs(
208
+ self, x_factor, y_factor, x, result):
209
+
210
+ input_dir = bases.BasisDirection("input")
211
+ output_dir = bases.BasisDirection("output")
212
+
213
+ mlp = numerical_mlp.linear_sequence_map_numerical_mlp(
214
+ input1_basis_direction=input_dir,
215
+ input2_basis_direction=input_dir,
216
+ output_basis_direction=output_dir,
217
+ input1_factor=x_factor,
218
+ input2_factor=y_factor)
219
+
220
+ test_inputs = bases.VectorInBasis(
221
+ basis_directions=[input_dir, output_dir], magnitudes=np.array([x, 0]))
222
+
223
+ expected_results = bases.VectorInBasis(
224
+ basis_directions=[input_dir, output_dir],
225
+ magnitudes=np.array([0, result]))
226
+
227
+ test_outputs = mlp.apply(test_inputs)
228
+
229
+ self.assertVectorAllClose(test_outputs, expected_results)
230
+
231
+
232
+ if __name__ == "__main__":
233
+ absltest.main()
craft/chamber/selector_width.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """SelectorWidth component consisting of an attention head and an MLP."""
16
+
17
+ from typing import Iterable
18
+ from tracr.craft import bases
19
+ from tracr.craft import transformers
20
+ from tracr.craft import vectorspace_fns
21
+ from tracr.craft.chamber import categorical_attn
22
+ from tracr.craft.chamber import numerical_mlp
23
+
24
+
25
+ def selector_width(
26
+ query_space: bases.VectorSpaceWithBasis,
27
+ key_space: bases.VectorSpaceWithBasis,
28
+ output_space: bases.VectorSpaceWithBasis,
29
+ bos_space: bases.VectorSpaceWithBasis,
30
+ one_space: bases.VectorSpaceWithBasis,
31
+ attn_fn: categorical_attn.QueryKeyToAttnLogit,
32
+ out_value_set: Iterable[float],
33
+ categorical_output: bool,
34
+ causal: bool = False,
35
+ softmax_coldness: float = 100.,
36
+ mlp_large_number: float = 100.,
37
+ label: str = "",
38
+ ) -> transformers.SeriesWithResiduals:
39
+ """Returns a craft block implementing RASP's SelectorWidth primitive.
40
+
41
+ The block consists of one attention head and one MLP.
42
+
43
+ The attention head implements the attention pattern (attn_fn or key=bos) and
44
+ aggregates the bos dimension over this pattern. The output of this will be
45
+ 1/(d+1) in every position, where d is the "width" of the attention pattern,
46
+ i.e. the number of 1s in a row.
47
+
48
+ The MLP then computes d from the previous output in all positions except for
49
+ the first BOS position. In the BOS position the MLP removes the output of the
50
+ attention head, to ensure it only contains the encoding of the BOS token
51
+ which is expected by all other model components.
52
+
53
+ Args:
54
+ query_space: Vector space containing (categorical) query input.
55
+ key_space: Vector space containing (categorical) key input.
56
+ output_space: Vector space which will contain (numerical or categorical)
57
+ output.
58
+ bos_space: 1-d space used to identify the beginning of sequence token.
59
+ one_space: Auxiliary 1-d vector space that must contain 1 in the input.
60
+ attn_fn: A selector function f(query, key) operating on the query/key basis
61
+ directions that defines the attention pattern to compute the width of.
62
+ out_value_set: Set of possible output values of this SelectorWidth.
63
+ categorical_output: If True, encode the output as a categorical variable.
64
+ causal: If True, use masked attention.
65
+ softmax_coldness: The inverse temperature of the softmax. Default value is
66
+ high, which makes the attention close to a hard maximum.
67
+ mlp_large_number: A larger number makes the MLP more accurate.
68
+ label: A name for this block, used to label auxiliary dimensions.
69
+ """
70
+ assert output_space.num_dims == 1 or categorical_output
71
+
72
+ attn_out_dir = bases.BasisDirection(f"{label}_selector_width_attn_output")
73
+ attn_out_space = bases.VectorSpaceWithBasis([attn_out_dir])
74
+ attn_out_vec = attn_out_space.vector_from_basis_direction(attn_out_dir)
75
+
76
+ attn = categorical_attn.categorical_attn(
77
+ query_space=query_space,
78
+ key_space=key_space,
79
+ value_space=bos_space,
80
+ output_space=attn_out_space,
81
+ bos_space=bos_space,
82
+ one_space=one_space,
83
+ attn_fn=attn_fn,
84
+ default_output=attn_out_space.null_vector(),
85
+ causal=causal,
86
+ always_attend_to_bos=True,
87
+ use_bos_for_default_output=False,
88
+ softmax_coldness=softmax_coldness)
89
+
90
+ fun = lambda x: (1 / x) - 1
91
+ in_value_set = {1 / (x + 1) for x in out_value_set}
92
+ if categorical_output:
93
+ mlp = numerical_mlp.map_numerical_to_categorical_mlp(
94
+ f=fun,
95
+ input_space=attn_out_space,
96
+ output_space=output_space,
97
+ input_value_set=in_value_set,
98
+ one_space=one_space,
99
+ hidden_name=f"_hidden_{label}_",
100
+ large_number=mlp_large_number)
101
+ else:
102
+ mlp = numerical_mlp.map_numerical_mlp(
103
+ f=fun,
104
+ input_space=attn_out_space,
105
+ output_space=output_space,
106
+ input_value_set=in_value_set,
107
+ one_space=one_space,
108
+ hidden_name=f"_hidden_{label}_",
109
+ large_number=mlp_large_number)
110
+
111
+ # This implementation of selector width writes at each position including
112
+ # the BOS. To ensure that the BOS token position does not contain
113
+ # additional values, we add an mlp to subtract the output of both layers.
114
+ clean_bos_out_space = bases.join_vector_spaces(attn_out_space, output_space)
115
+ vec_to_subtract_from_bos = attn_out_vec.project(clean_bos_out_space)
116
+
117
+ if categorical_output:
118
+ # Add the one-hot encoding of the zero value to the vector
119
+ # which will get scrubbed from the BOS position.
120
+ zero_dir = [d for d in output_space.basis if d.value == 0][0]
121
+ zero_vec = clean_bos_out_space.vector_from_basis_direction(zero_dir)
122
+ vec_to_subtract_from_bos += zero_vec
123
+
124
+ # Construct an MLP that subtracts vec_to_subtract_from_bos * bos
125
+ # from the residual stream which is vec_to_subtract_from_bos in the
126
+ # bos position and 0 else. vec_to_subtract_from_bos contains what the
127
+ # attention head writes to the bos position.
128
+
129
+ hidden_dir = bases.BasisDirection("_hidden_clean_bos_")
130
+ hidden_space = bases.VectorSpaceWithBasis([hidden_dir])
131
+ hidden_vec = hidden_space.vector_from_basis_direction(hidden_dir)
132
+
133
+ # It's okay to use the local variables because they are only used within
134
+ # the same loop iteration to create the MLP.
135
+ # pylint: disable=cell-var-from-loop
136
+ first_layer = vectorspace_fns.Linear.from_action(bos_space, hidden_space,
137
+ lambda x: hidden_vec)
138
+ second_layer = vectorspace_fns.Linear.from_action(
139
+ hidden_space, clean_bos_out_space, lambda x: -vec_to_subtract_from_bos)
140
+ # pylint: enable=cell-var-from-loop
141
+ clean_bos_mlp = transformers.MLP(first_layer, second_layer)
142
+
143
+ mlp = transformers.MLP.combine_in_parallel([mlp, clean_bos_mlp])
144
+ return transformers.SeriesWithResiduals([attn, mlp])
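
The division of labour in this block can also be checked numerically. The sketch below (plain numpy; the helper name `selector_width_row` is an illustrative assumption) mimics one attention row with `always_attend_to_bos=True`: the BOS weight comes out as 1/(d+1), where d is the number of selected positions, and the MLP's job is then just to apply a -> 1/a - 1 (computed exactly here rather than via the discretising MLP).

```python
import numpy as np

def selector_width_row(match_flags, coldness=100.0):
  # BOS and every matching key all score `coldness`, so a cold softmax
  # spreads the weight evenly over BOS plus the d matching positions.
  logits = np.concatenate([[coldness], coldness * np.asarray(match_flags, float)])
  logits -= logits.max()
  weights = np.exp(logits) / np.exp(logits).sum()
  attn_output = weights[0]          # aggregated BOS indicator = 1 / (d + 1)
  return 1.0 / attn_output - 1.0    # what the MLP recovers from the head's output

print(round(selector_width_row([1, 0, 1, 1])))  # 3
print(round(selector_width_row([0, 0, 0, 0])))  # 0
```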
craft/chamber/selector_width_test.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for selector_width."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ from tracr.craft import bases
20
+ from tracr.craft import tests_common
21
+ from tracr.craft.chamber import selector_width
22
+
23
+
24
+ class SelectorWidthTest(tests_common.VectorFnTestCase):
25
+
26
+ @parameterized.product(
27
+ causal=[False, True],
28
+ categorical_output=[False, True],
29
+ input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
30
+ def test_selector_width_of_select_all_is_length(self, causal,
31
+ categorical_output,
32
+ input_seq):
33
+ vocab = range(-20, 20)
34
+ input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
35
+
36
+ if categorical_output:
37
+ output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
38
+ else:
39
+ output_space = bases.VectorSpaceWithBasis(
40
+ [bases.BasisDirection("output")])
41
+
42
+ bos_dir = bases.BasisDirection("bos_dimension")
43
+ bos_space = bases.VectorSpaceWithBasis([bos_dir])
44
+
45
+ one_dir = bases.BasisDirection("one_dimension")
46
+ one_space = bases.VectorSpaceWithBasis([one_dir])
47
+
48
+ input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
49
+ residual_space = bases.join_vector_spaces(input_space, output_space)
50
+ bos_vec = residual_space.vector_from_basis_direction(bos_dir)
51
+ one_vec = residual_space.vector_from_basis_direction(one_dir)
52
+
53
+ block = selector_width.selector_width(
54
+ query_space=input_space,
55
+ key_space=input_space,
56
+ output_space=output_space,
57
+ bos_space=bos_space,
58
+ one_space=one_space,
59
+ attn_fn=lambda x, y: True,
60
+ out_value_set=set(range(len(input_seq) + 1)),
61
+ categorical_output=categorical_output,
62
+ causal=causal,
63
+ label="select_all")
64
+
65
+ test_inputs = [bos_vec + one_vec]
66
+ for x in input_seq:
67
+ test_inputs.append(
68
+ residual_space.vector_from_basis_direction(
69
+ bases.BasisDirection("input", x)) + one_vec)
70
+ test_inputs = bases.VectorInBasis.stack(test_inputs)
71
+
72
+ # Expect length of the input sequence
73
+ if causal:
74
+ expected_results = list(range(1, len(input_seq) + 1))
75
+ else:
76
+ expected_results = [len(input_seq) for _ in input_seq]
77
+
78
+ if categorical_output:
79
+ expected_results = [
80
+ output_space.vector_from_basis_direction(
81
+ bases.BasisDirection("output", x)) for x in expected_results
82
+ ]
83
+ else:
84
+ output_vec = output_space.vector_from_basis_direction(
85
+ bases.BasisDirection("output"))
86
+ expected_results = [x * output_vec for x in expected_results]
87
+
88
+ expected_results = bases.VectorInBasis.stack(expected_results)
89
+
90
+ test_outputs = block.apply(test_inputs).project(output_space)
91
+ self.assertVectorAllClose(
92
+ tests_common.strip_bos_token(test_outputs), expected_results)
93
+
94
+ @parameterized.product(
95
+ causal=[False, True],
96
+ categorical_output=[False, True],
97
+ input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
98
+ def test_selector_width_of_select_none_is_zero(self, causal,
99
+ categorical_output, input_seq):
100
+ vocab = range(-20, 20)
101
+ input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
102
+
103
+ if categorical_output:
104
+ output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
105
+ else:
106
+ output_space = bases.VectorSpaceWithBasis(
107
+ [bases.BasisDirection("output")])
108
+
109
+ bos_dir = bases.BasisDirection("bos_dimension")
110
+ bos_space = bases.VectorSpaceWithBasis([bos_dir])
111
+
112
+ one_dir = bases.BasisDirection("one_dimension")
113
+ one_space = bases.VectorSpaceWithBasis([one_dir])
114
+
115
+ input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
116
+ residual_space = bases.join_vector_spaces(input_space, output_space)
117
+ bos_vec = residual_space.vector_from_basis_direction(bos_dir)
118
+ one_vec = residual_space.vector_from_basis_direction(one_dir)
119
+
120
+ block = selector_width.selector_width(
121
+ query_space=input_space,
122
+ key_space=input_space,
123
+ output_space=output_space,
124
+ bos_space=bos_space,
125
+ one_space=one_space,
126
+ attn_fn=lambda x, y: False,
127
+ out_value_set=set(range(len(input_seq) + 1)),
128
+ categorical_output=categorical_output,
129
+ causal=causal,
130
+ label="select_all")
131
+
132
+ test_inputs = [bos_vec + one_vec]
133
+ for x in input_seq:
134
+ test_inputs.append(
135
+ residual_space.vector_from_basis_direction(
136
+ bases.BasisDirection("input", x)) + one_vec)
137
+ test_inputs = bases.VectorInBasis.stack(test_inputs)
138
+
139
+ # Expect zero output
140
+ if categorical_output:
141
+ expected_results = [
142
+ output_space.vector_from_basis_direction(
143
+ bases.BasisDirection("output", 0)) for _ in input_seq
144
+ ]
145
+ else:
146
+ expected_results = [output_space.null_vector() for _ in input_seq]
147
+ expected_results = bases.VectorInBasis.stack(expected_results)
148
+
149
+ test_outputs = block.apply(test_inputs).project(output_space)
150
+ self.assertVectorAllClose(
151
+ tests_common.strip_bos_token(test_outputs), expected_results)
152
+
153
+
154
+ if __name__ == "__main__":
155
+ absltest.main()
craft/tests_common.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Helper functions for tests."""
16
+
17
+ from absl.testing import parameterized
18
+ import numpy as np
19
+ from tracr.craft import bases
20
+
21
+
22
+ def strip_bos_token(vector: bases.VectorInBasis) -> bases.VectorInBasis:
23
+ """Removes BOS token of a vector."""
24
+ return bases.VectorInBasis(vector.basis_directions, vector.magnitudes[1:])
25
+
26
+
27
+ class VectorFnTestCase(parameterized.TestCase):
28
+ """Asserts for vectors."""
29
+
30
+ def assertVectorAllClose(self, v1: bases.VectorInBasis,
31
+ v2: bases.VectorInBasis):
32
+ self.assertEqual(v1.basis_directions, v2.basis_directions)
33
+ np.testing.assert_allclose(v1.magnitudes, v2.magnitudes, atol=1e-7)
craft/transformers.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Pieces for making transformers."""
16
+
17
+ import abc
18
+ import dataclasses
19
+ from typing import Iterable, Optional, Sequence, Union
20
+
21
+ import numpy as np
22
+
23
+ from tracr.craft import bases
24
+ from tracr.craft import vectorspace_fns
25
+
26
+ project = vectorspace_fns.project
27
+
28
+
29
+ def _np_softmax(x, axis=-1):
30
+ x_max = np.max(x, axis=axis, keepdims=True)
31
+ return np.exp(x - x_max) / np.sum(np.exp(x - x_max), axis=axis, keepdims=True)
32
+
33
+
34
+ def _np_relu(x):
35
+ return np.where(x > 0, x, 0)
36
+
37
+
38
+ def relu(x: bases.VectorInBasis) -> bases.VectorInBasis:
39
+ return bases.VectorInBasis(x.basis_directions, _np_relu(x.magnitudes))
40
+
41
+
42
+ class Block(abc.ABC):
43
+ """Transformer block, acting on a sequence of vector space elements.
44
+
45
+ Attributes:
46
+ residual_space: Vector space that contains all subspaces the Block interacts
47
+ with. This can be either the full residual space of a model or a subspace.
48
+ """
49
+ residual_space: bases.VectorSpaceWithBasis
50
+
51
+ @abc.abstractmethod
52
+ def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
53
+ """Applies self to an input."""
54
+
55
+
56
+ @dataclasses.dataclass
57
+ class AttentionHead(Block):
58
+ """A transformer attention head."""
59
+ w_qk: vectorspace_fns.ScalarBilinear
60
+ w_ov: vectorspace_fns.Linear
61
+ residual_space: Optional[bases.VectorSpaceWithBasis] = None
62
+ causal: bool = False
63
+
64
+ def __post_init__(self):
65
+ """Infer residual stream and typecheck subspaces."""
66
+ if self.residual_space is None:
67
+ self.residual_space = bases.join_vector_spaces(self.w_qk.left_space,
68
+ self.w_qk.right_space,
69
+ self.w_ov.input_space,
70
+ self.w_ov.output_space)
71
+
72
+ assert self.w_qk.left_space.issubspace(self.residual_space)
73
+ assert self.w_qk.right_space.issubspace(self.residual_space)
74
+ assert self.w_ov.input_space.issubspace(self.residual_space)
75
+ assert self.w_ov.output_space.issubspace(self.residual_space)
76
+
77
+ def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
78
+ assert x in self.residual_space
79
+ # seq_len x query_space
80
+ queries = x.project(self.w_qk.left_space)
81
+ # seq_len x key_space
82
+ keys = x.project(self.w_qk.right_space)
83
+
84
+ attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T
85
+
86
+ if self.causal:
87
+ # The 1 gives us the matrix above the diagonal.
88
+ mask = np.triu(np.full_like(attn_matrix, -np.inf), 1)
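+ # For example, for a length-3 sequence this mask is
+ # [[0, -inf, -inf], [0, 0, -inf], [0, 0, 0]],
+ # so each position attends only to itself and earlier positions.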
89
+ attn_matrix = attn_matrix + mask
90
+
91
+ attn_weights = _np_softmax(attn_matrix) # seq_len_from, seq_len_to
92
+ values = self.w_ov_residual(x).magnitudes # seq_len_to, d_model
93
+
94
+ magnitudes = attn_weights @ values # seq_len_from, d_model
95
+ return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes)
96
+
97
+ def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
98
+ """Wov but acting on the residual space."""
99
+ x = project(self.residual_space, self.w_ov.input_space)(x)
100
+ out = self.w_ov(x)
101
+ return project(self.w_ov.output_space, self.residual_space)(out)
102
+
103
+ @property
104
+ def num_heads(self) -> int:
105
+ return 1
106
+
107
+ def as_multi(self) -> "MultiAttentionHead":
108
+ return MultiAttentionHead([self])
109
+
110
+
111
+ @dataclasses.dataclass
112
+ class MultiAttentionHead(Block):
113
+ """Applies attention heads in parallel."""
114
+ sub_blocks: list[Union[AttentionHead, "MultiAttentionHead"]]
115
+
116
+ def __post_init__(self):
117
+ spaces = [block.residual_space for block in self.sub_blocks]
118
+ self.residual_space, *others = spaces
119
+ assert all(s == self.residual_space for s in others)
120
+
121
+ def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
122
+ # each element is seq_len x embedding
123
+ outs = [block.apply(x) for block in self.sub_blocks]
124
+ return bases.VectorInBasis.sum(outs) # seq_len x embedding
125
+
126
+ @property
127
+ def num_heads(self) -> int:
128
+ return sum(sub_block.num_heads for sub_block in self.sub_blocks)
129
+
130
+ def heads(self) -> Iterable[AttentionHead]:
131
+ for sub_block in self.sub_blocks:
132
+ if isinstance(sub_block, AttentionHead):
133
+ yield sub_block
134
+ elif isinstance(sub_block, MultiAttentionHead):
135
+ yield from sub_block.heads()
136
+ else:
137
+ raise NotImplementedError()
138
+
139
+ def as_multi(self) -> "MultiAttentionHead":
140
+ return self
141
+
142
+
143
+ @dataclasses.dataclass
144
+ class MLP(Block):
145
+ """A transformer MLP block."""
146
+ fst: vectorspace_fns.Linear
147
+ snd: vectorspace_fns.Linear
148
+ residual_space: Optional[bases.VectorSpaceWithBasis] = None
149
+
150
+ def __post_init__(self):
151
+ """Typecheck subspaces."""
152
+ if self.residual_space is None:
153
+ self.residual_space = bases.join_vector_spaces(self.fst.input_space,
154
+ self.snd.output_space)
155
+
156
+ assert self.fst.output_space == self.snd.input_space
157
+ assert self.fst.input_space.issubspace(self.residual_space)
158
+ assert self.snd.output_space.issubspace(self.residual_space)
159
+
160
+ def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
161
+ assert x in self.residual_space
162
+
163
+ x = project(self.residual_space, self.fst.input_space)(x)
164
+ hidden = self.fst(x)
165
+ hidden = relu(hidden)
166
+ out = self.snd(hidden)
167
+ return project(self.snd.output_space, self.residual_space)(out)
168
+
169
+ @classmethod
170
+ def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP":
171
+ fst = vectorspace_fns.Linear.combine_in_parallel(
172
+ [block.fst for block in mlps])
173
+ snd = vectorspace_fns.Linear.combine_in_parallel(
174
+ [block.snd for block in mlps])
175
+ return cls(fst=fst, snd=snd, residual_space=None)
176
+
177
+
178
+ # Block that fits into a half-layer, without residual connections.
179
+ HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]
180
+
181
+
182
+ @dataclasses.dataclass
183
+ class SeriesWithResiduals(Block):
184
+ """A series of blocks with residual connections."""
185
+ blocks: list[HalfLayerBlock]
186
+
187
+ def __post_init__(self):
188
+ spaces = [block.residual_space for block in self.blocks]
189
+ self.residual_space = bases.join_vector_spaces(*spaces)
190
+
191
+ def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
192
+ x = x.project(self.residual_space)
193
+ for block in self.blocks:
194
+ x_in = x.project(block.residual_space)
195
+ x_out = block.apply(x_in).project(self.residual_space)
196
+ x = x + x_out
197
+ return x
craft/transformers_test.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for transformers."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import numpy as np
20
+ from tracr.craft import bases
21
+ from tracr.craft import tests_common
22
+ from tracr.craft import transformers
23
+ from tracr.craft import vectorspace_fns as vs_fns
24
+
25
+ # This makes it easier to use comments to annotate dimensions in arrays
26
+ # pylint: disable=g-no-space-after-comment
27
+
28
+
29
+ class AttentionHeadTest(tests_common.VectorFnTestCase):
30
+
31
+ @parameterized.parameters([
32
+ dict(with_residual_stream=False),
33
+ dict(with_residual_stream=True),
34
+ ])
35
+ def test_attention_head(self, with_residual_stream):
36
+ i = bases.VectorSpaceWithBasis.from_values("i", [1, 2])
37
+ o = bases.VectorSpaceWithBasis.from_values("o", [1, 2])
38
+ q = bases.VectorSpaceWithBasis.from_values("q", [1, 2])
39
+ k = bases.VectorSpaceWithBasis.from_values("p", [1, 2])
40
+ rs = bases.direct_sum(i, o, q, k)
41
+
42
+ seq = bases.VectorInBasis(
43
+ rs.basis,
44
+ np.array([
45
+ #i1 i2 o1 o2 q1 q2 p1 p2
46
+ [1, 0, 0, 0, 1, 0, 1, 0],
47
+ [0, 1, 0, 0, 0, 1, 0, 1],
48
+ ]))
49
+
50
+ head = transformers.AttentionHead(
51
+ w_qk=vs_fns.ScalarBilinear(q, k,
52
+ np.eye(2) * 100),
53
+ w_ov=vs_fns.Linear(i, o, np.eye(2)),
54
+ residual_space=rs if with_residual_stream else None,
55
+ causal=False,
56
+ )
57
+
58
+ self.assertVectorAllClose(
59
+ head.apply(seq),
60
+ bases.VectorInBasis(
61
+ rs.basis,
62
+ np.array([
63
+ #i1 i2 o1 o2 q1 q2 p1 p2
64
+ [0, 0, 1, 0, 0, 0, 0, 0],
65
+ [0, 0, 0, 1, 0, 0, 0, 0],
66
+ ])),
67
+ )
68
+
69
+
70
+ class MLPTest(tests_common.VectorFnTestCase):
71
+
72
+ @parameterized.parameters([
73
+ dict(with_residual_stream=False, same_in_out=False),
74
+ dict(with_residual_stream=False, same_in_out=True),
75
+ dict(with_residual_stream=True, same_in_out=False),
76
+ dict(with_residual_stream=True, same_in_out=True),
77
+ ])
78
+ def test_mlp(self, with_residual_stream, same_in_out):
79
+ i = bases.VectorSpaceWithBasis.from_values("i", [1, 2])
80
+ if same_in_out:
81
+ o, rs = i, i
82
+ expected_result = np.array([
83
+ #o1 o2
84
+ [1, 0],
85
+ [0, 1],
86
+ ])
87
+ else:
88
+ o = bases.VectorSpaceWithBasis.from_values("o", [1, 2])
89
+ rs = bases.direct_sum(i, o)
90
+ expected_result = np.array([
91
+ #i1 i2 o1 o2
92
+ [0, 0, 1, 0],
93
+ [0, 0, 0, 1],
94
+ ])
95
+ h = bases.VectorSpaceWithBasis.from_values("p", [1, 2])
96
+
97
+ seq = bases.VectorInBasis(
98
+ i.basis,
99
+ np.array([
100
+ #i1 i2
101
+ [1, -1],
102
+ [-1, 1],
103
+ ])).project(rs)
104
+
105
+ mlp = transformers.MLP(
106
+ fst=vs_fns.Linear(i, h, np.eye(2)),
107
+ snd=vs_fns.Linear(h, o, np.eye(2)),
108
+ residual_space=rs if with_residual_stream else None,
109
+ )
110
+
111
+ self.assertEqual(
112
+ mlp.apply(seq),
113
+ bases.VectorInBasis(rs.basis, expected_result),
114
+ )
115
+
116
+ def test_combining_mlps(self):
117
+ in12 = bases.VectorSpaceWithBasis.from_values("in", [1, 2])
118
+ in34 = bases.VectorSpaceWithBasis.from_values("in", [3, 4])
119
+ out12 = bases.VectorSpaceWithBasis.from_values("out", [1, 2])
120
+ residual_space = bases.join_vector_spaces(in12, in34, out12)
121
+
122
+ h1 = bases.VectorSpaceWithBasis.from_values("h", [1])
123
+ h2 = bases.VectorSpaceWithBasis.from_values("h", [2])
124
+
125
+ # MLP1 maps in2 -> h1 -> out1
126
+ mlp1 = transformers.MLP(
127
+ fst=vs_fns.Linear(in12, h1, np.array([[0], [1]])),
128
+ snd=vs_fns.Linear(h1, out12, np.array([[1, 0]])))
129
+
130
+ # MLP2 maps in3 -> h2 -> out2
131
+ mlp2 = transformers.MLP(
132
+ fst=vs_fns.Linear(in34, h2, np.array([[1], [0]])),
133
+ snd=vs_fns.Linear(h2, out12, np.array([[0, 1]])))
134
+
135
+ mlp = transformers.MLP.combine_in_parallel([mlp1, mlp2])
136
+
137
+ seq = bases.VectorInBasis(
138
+ bases.direct_sum(in12, in34).basis,
139
+ np.array([
140
+ #i1 i2 i3 i4
141
+ [1, 2, 0, 0],
142
+ [0, 2, 3, 4],
143
+ ])).project(residual_space)
144
+
145
+ expected_result = bases.VectorInBasis(
146
+ out12.basis,
147
+ np.array([
148
+ #o1 o2
149
+ [2, 0],
150
+ [2, 3],
151
+ ]))
152
+
153
+ self.assertEqual(
154
+ mlp.apply(seq).project(out12),
155
+ expected_result,
156
+ )
157
+
158
+
159
+ if __name__ == "__main__":
160
+ absltest.main()
craft/vectorspace_fns.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Functions on vector spaces."""
16
+
17
+ import abc
18
+ import dataclasses
19
+ from typing import Callable, Sequence
20
+
21
+ import numpy as np
22
+
23
+ from tracr.craft import bases
24
+
25
+ VectorSpaceWithBasis = bases.VectorSpaceWithBasis
26
+ VectorInBasis = bases.VectorInBasis
27
+ BasisDirection = bases.BasisDirection
28
+
29
+
30
+ class VectorFunction(abc.ABC):
31
+ """A function that acts on vectors."""
32
+
33
+ input_space: VectorSpaceWithBasis
34
+ output_space: VectorSpaceWithBasis
35
+
36
+ @abc.abstractmethod
37
+ def __call__(self, x: VectorInBasis) -> VectorInBasis:
38
+ """Evaluates the function."""
39
+
40
+
41
+ class Linear(VectorFunction):
42
+ """A linear function."""
43
+
44
+ def __init__(
45
+ self,
46
+ input_space: VectorSpaceWithBasis,
47
+ output_space: VectorSpaceWithBasis,
48
+ matrix: np.ndarray,
49
+ ):
50
+ """Initialises.
51
+
52
+ Args:
53
+ input_space: The input vector space.
54
+ output_space: The output vector space.
55
+ matrix: a [input, output] matrix acting in a (sorted) basis.
56
+ """
57
+ self.input_space = input_space
58
+ self.output_space = output_space
59
+ self.matrix = matrix
60
+
61
+ def __post_init__(self) -> None:
62
+ input_size, output_size = self.matrix.shape
63
+ assert input_size == self.input_space.num_dims
64
+ assert output_size == self.output_space.num_dims
65
+
66
+ def __call__(self, x: VectorInBasis) -> VectorInBasis:
67
+ if x not in self.input_space:
68
+ raise TypeError(f"{x=} not in {self.input_space=}.")
69
+ return VectorInBasis(
70
+ basis_directions=sorted(self.output_space.basis),
71
+ magnitudes=x.magnitudes @ self.matrix,
72
+ )
73
+
74
+ @classmethod
75
+ def from_action(
76
+ cls,
77
+ input_space: VectorSpaceWithBasis,
78
+ output_space: VectorSpaceWithBasis,
79
+ action: Callable[[BasisDirection], VectorInBasis],
80
+ ) -> "Linear":
81
+ """from_action(i, o)(action) creates a Linear."""
82
+
83
+ matrix = np.zeros((input_space.num_dims, output_space.num_dims))
84
+ for i, direction in enumerate(input_space.basis):
85
+ out_vector = action(direction)
86
+ if out_vector not in output_space:
87
+ raise TypeError(f"image of {direction} from {input_space=} "
88
+ f"is not in {output_space=}")
89
+ matrix[i, :] = out_vector.magnitudes
90
+
91
+ return Linear(input_space, output_space, matrix)
92
+
93
+ @classmethod
94
+ def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear":
95
+ """Combines multiple parallel linear functions into a single one."""
96
+ joint_input_space = bases.join_vector_spaces(
97
+ *[fn.input_space for fn in fns])
98
+ joint_output_space = bases.join_vector_spaces(
99
+ *[fn.output_space for fn in fns])
100
+
101
+ def action(x: bases.BasisDirection) -> bases.VectorInBasis:
102
+ out = joint_output_space.null_vector()
103
+ for fn in fns:
104
+ if x in fn.input_space:
105
+ x_vec = fn.input_space.vector_from_basis_direction(x)
106
+ out += fn(x_vec).project(joint_output_space)
107
+ return out
108
+
109
+ return cls.from_action(joint_input_space, joint_output_space, action)
110
+
111
+
112
+ def project(
113
+ from_space: VectorSpaceWithBasis,
114
+ to_space: VectorSpaceWithBasis,
115
+ ) -> Linear:
116
+ """Creates a projection."""
117
+
118
+ def action(direction: bases.BasisDirection) -> VectorInBasis:
119
+ if direction in to_space:
120
+ return to_space.vector_from_basis_direction(direction)
121
+ else:
122
+ return to_space.null_vector()
123
+
124
+ return Linear.from_action(from_space, to_space, action=action)
125
+
126
+
127
+ @dataclasses.dataclass
128
+ class ScalarBilinear:
129
+ """A scalar-valued bilinear operator."""
130
+ left_space: VectorSpaceWithBasis
131
+ right_space: VectorSpaceWithBasis
132
+ matrix: np.ndarray
133
+
134
+ def __post_init__(self):
135
+ """Ensure matrix acts in sorted bases and typecheck sizes."""
136
+ left_size, right_size = self.matrix.shape
137
+ assert left_size == self.left_space.num_dims
138
+ assert right_size == self.right_space.num_dims
139
+
140
+ def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
141
+ """Describes the action of the operator on vectors."""
142
+ if x not in self.left_space:
143
+ raise TypeError(f"{x=} not in {self.left_space=}.")
144
+ if y not in self.right_space:
145
+ raise TypeError(f"{y=} not in {self.right_space=}.")
146
+ return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()
147
+
148
+ @classmethod
149
+ def from_action(
150
+ cls,
151
+ left_space: VectorSpaceWithBasis,
152
+ right_space: VectorSpaceWithBasis,
153
+ action: Callable[[BasisDirection, BasisDirection], float],
154
+ ) -> "ScalarBilinear":
155
+ """from_action(l, r)(action) creates a ScalarBilinear."""
156
+
157
+ matrix = np.zeros((left_space.num_dims, right_space.num_dims))
158
+ for i, left_direction in enumerate(left_space.basis):
159
+ for j, right_direction in enumerate(right_space.basis):
160
+ matrix[i, j] = action(left_direction, right_direction)
161
+
162
+ return ScalarBilinear(left_space, right_space, matrix)
craft/vectorspace_fns_test.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for vectorspace_fns."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import numpy as np
20
+ from tracr.craft import bases
21
+ from tracr.craft import tests_common
22
+ from tracr.craft import vectorspace_fns as vs_fns
23
+
24
+
25
+ class LinearTest(tests_common.VectorFnTestCase):
26
+
27
+ def test_identity_from_matrix(self):
28
+ vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
29
+ f = vs_fns.Linear(vs, vs, np.eye(3))
30
+ for v in vs.basis_vectors():
31
+ self.assertEqual(f(v), v)
32
+
33
+ def test_identity_from_action(self):
34
+ vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
35
+ f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction)
36
+ for v in vs.basis_vectors():
37
+ self.assertEqual(f(v), v)
38
+
39
+ def test_nonidentity(self):
40
+ vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
41
+ a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
42
+ b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
43
+
44
+ f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]]))
45
+
46
+ self.assertEqual(
47
+ f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7])))
48
+ self.assertEqual(
49
+ f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1])))
50
+
51
+ def test_different_vector_spaces(self):
52
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
53
+ vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
54
+ a, b = vs1.basis_vectors()
55
+ c, d = vs2.basis_vectors()
56
+
57
+ f = vs_fns.Linear(vs1, vs2, np.eye(2))
58
+
59
+ self.assertEqual(f(a), c)
60
+ self.assertEqual(f(b), d)
61
+
62
+ def test_combining_linear_functions_with_different_input(self):
63
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
64
+ vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
65
+ vs = bases.direct_sum(vs1, vs2)
66
+ a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
67
+ b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
68
+ c = vs.vector_from_basis_direction(bases.BasisDirection("c"))
69
+ d = vs.vector_from_basis_direction(bases.BasisDirection("d"))
70
+
71
+ f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]]))
72
+ f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]]))
73
+ f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
74
+
75
+ self.assertEqual(
76
+ f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0])))
77
+ self.assertEqual(
78
+ f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0])))
79
+ self.assertEqual(
80
+ f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0])))
81
+ self.assertEqual(
82
+ f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0])))
83
+
84
+ def test_combining_linear_functions_with_same_input(self):
85
+ vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
86
+ a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
87
+ b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
88
+
89
+ f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]]))
90
+ f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]]))
91
+ f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
92
+
93
+ self.assertEqual(
94
+ f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1])))
95
+ self.assertEqual(
96
+ f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0])))
97
+ self.assertEqual(f3(a), f1(a) + f2(a))
98
+ self.assertEqual(f3(b), f1(b) + f2(b))
99
+
100
+
101
+ class ProjectionTest(tests_common.VectorFnTestCase):
102
+
103
+ def test_projection_to_larger_space(self):
104
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
105
+ vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
106
+ a1, b1 = vs1.basis_vectors()
107
+ a2, b2, _, _ = vs2.basis_vectors()
108
+
109
+ f = vs_fns.project(vs1, vs2)
110
+
111
+ self.assertEqual(f(a1), a2)
112
+ self.assertEqual(f(b1), b2)
113
+
114
+ def test_projection_to_smaller_space(self):
115
+ vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
116
+ vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
117
+ a1, b1, c1, d1 = vs1.basis_vectors()
118
+ a2, b2 = vs2.basis_vectors()
119
+
120
+ f = vs_fns.project(vs1, vs2)
121
+
122
+ self.assertEqual(f(a1), a2)
123
+ self.assertEqual(f(b1), b2)
124
+ self.assertEqual(f(c1), vs2.null_vector())
125
+ self.assertEqual(f(d1), vs2.null_vector())
126
+
127
+
128
+ class ScalarBilinearTest(parameterized.TestCase):
129
+
130
+ def test_identity_matrix(self):
131
+ vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
132
+ a, b = vs.basis_vectors()
133
+
134
+ f = vs_fns.ScalarBilinear(vs, vs, np.eye(2))
135
+
136
+ self.assertEqual(f(a, a), 1)
137
+ self.assertEqual(f(a, b), 0)
138
+ self.assertEqual(f(b, a), 0)
139
+ self.assertEqual(f(b, b), 1)
140
+
141
+ def test_identity_from_action(self):
142
+ vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
143
+ a, b = vs.basis_vectors()
144
+
145
+ f = vs_fns.ScalarBilinear.from_action(vs, vs, lambda x, y: int(x == y))
146
+
147
+ self.assertEqual(f(a, a), 1)
148
+ self.assertEqual(f(a, b), 0)
149
+ self.assertEqual(f(b, a), 0)
150
+ self.assertEqual(f(b, b), 1)
151
+
152
+ def test_non_identity(self):
153
+ vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
154
+ a, b = vs.basis_vectors()
155
+
156
+ f = vs_fns.ScalarBilinear.from_action(vs, vs,
157
+ lambda x, y: int(x.name == "a"))
158
+
159
+ self.assertEqual(f(a, a), 1)
160
+ self.assertEqual(f(a, b), 1)
161
+ self.assertEqual(f(b, a), 0)
162
+ self.assertEqual(f(b, b), 0)
163
+
164
+
165
+ if __name__ == "__main__":
166
+ absltest.main()
examples/Visualize_Tracr_Models.ipynb ADDED
@@ -0,0 +1,262 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "99FBiGH7bsfn"
7
+ },
8
+ "source": [
9
+ "# Compiling \u0026 Visualizing Tracr Models\n",
10
+ "\n",
11
+ "This notebook demonstrates how to compile a tracr model and provides some tools visualize the model's residual stream or layer outputs for a given input sequence."
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {
18
+ "id": "qm-PM1PEawCx"
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "#@title Imports\n",
23
+ "import jax\n",
24
+ "import numpy as np\n",
25
+ "import matplotlib.pyplot as plt\n",
26
+ "\n",
27
+ "# The default of float16 can lead to discrepancies between outputs of\n",
28
+ "# the compiled model and the RASP program.\n",
29
+ "jax.config.update('jax_default_matmul_precision', 'float32')\n",
30
+ "\n",
31
+ "from tracr.compiler import compiling\n",
32
+ "from tracr.compiler import lib\n",
33
+ "from tracr.rasp import rasp"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {
40
+ "id": "HtOAc_yWawFR"
41
+ },
42
+ "outputs": [],
43
+ "source": [
44
+ "#@title Plotting functions\n",
45
+ "def tidy_label(label, value_width=5):\n",
46
+ " if ':' in label:\n",
47
+ " label, value = label.split(':')\n",
48
+ " else:\n",
49
+ " value = ''\n",
50
+ " return label + f\":{value:\u003e{value_width}}\"\n",
51
+ "\n",
52
+ "\n",
53
+ "def add_residual_ticks(model, value_width=5, x=False, y=True):\n",
54
+ " if y:\n",
55
+ " plt.yticks(\n",
56
+ " np.arange(len(model.residual_labels))+0.5, \n",
57
+ " [tidy_label(l, value_width=value_width)\n",
58
+ " for l in model.residual_labels], \n",
59
+ " family='monospace',\n",
60
+ " fontsize=20,\n",
61
+ " )\n",
62
+ " if x:\n",
63
+ " plt.xticks(\n",
64
+ " np.arange(len(model.residual_labels))+0.5, \n",
65
+ " [tidy_label(l, value_width=value_width)\n",
66
+ " for l in model.residual_labels], \n",
67
+ " family='monospace',\n",
68
+ " rotation=90,\n",
69
+ " fontsize=20,\n",
70
+ " )\n",
71
+ "\n",
72
+ "\n",
73
+ "def plot_computation_trace(model,\n",
74
+ " input_labels,\n",
75
+ " residuals_or_outputs,\n",
76
+ " add_input_layer=False,\n",
77
+ " figsize=(12, 9)):\n",
78
+ " fig, axes = plt.subplots(nrows=1, ncols=len(residuals_or_outputs), figsize=figsize, sharey=True)\n",
79
+ " value_width = max(map(len, map(str, input_labels))) + 1\n",
80
+ "\n",
81
+ " for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes)):\n",
82
+ " plt.sca(ax)\n",
83
+ " plt.pcolormesh(layer[0].T, vmin=0, vmax=1)\n",
84
+ " if i == 0:\n",
85
+ " add_residual_ticks(model, value_width=value_width)\n",
86
+ " plt.xticks(\n",
87
+ " np.arange(len(input_labels))+0.5,\n",
88
+ " input_labels,\n",
89
+ " rotation=90,\n",
90
+ " fontsize=20,\n",
91
+ " )\n",
92
+ " if add_input_layer and i == 0:\n",
93
+ " title = 'Input'\n",
94
+ " else:\n",
95
+ " layer_no = i - 1 if add_input_layer else i\n",
96
+ " layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'\n",
97
+ " title = f'{layer_type} {layer_no // 2 + 1}'\n",
98
+ " plt.title(title, fontsize=20)\n",
99
+ "\n",
100
+ "\n",
101
+ "def plot_residuals_and_input(model, inputs, figsize=(12, 9)):\n",
102
+ " \"\"\"Applies model to inputs, and plots the residual stream at each layer.\"\"\"\n",
103
+ " model_out = assembled_model.apply(inputs)\n",
104
+ " residuals = np.concatenate([model_out.input_embeddings[None, ...],\n",
105
+ " model_out.residuals], axis=0)\n",
106
+ " plot_computation_trace(\n",
107
+ " model=model,\n",
108
+ " input_labels=inputs,\n",
109
+ " residuals_or_outputs=residuals,\n",
110
+ " add_input_layer=True,\n",
111
+ " figsize=figsize)\n",
112
+ "\n",
113
+ "\n",
114
+ "def plot_layer_outputs(model, inputs, figsize=(12, 9)):\n",
115
+ " \"\"\"Applies model to inputs, and plots the outputs of each layer.\"\"\"\n",
116
+ " model_out = assembled_model.apply(inputs)\n",
117
+ " plot_computation_trace(\n",
118
+ " model=model,\n",
119
+ " input_labels=inputs,\n",
120
+ " residuals_or_outputs=model_out.layer_outputs,\n",
121
+ " add_input_layer=False,\n",
122
+ " figsize=figsize)\n"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {
129
+ "cellView": "form",
130
+ "id": "8hV0nv_ISmhM"
131
+ },
132
+ "outputs": [],
133
+ "source": [
134
+ "#@title Define RASP programs\n",
135
+ "def get_program(program_name, max_seq_len):\n",
136
+ " \"\"\"Returns RASP program and corresponding token vocabulary.\"\"\"\n",
137
+ " if program_name == \"length\":\n",
138
+ " vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
139
+ " program = lib.make_length()\n",
140
+ " elif program_name == \"frac_prevs\":\n",
141
+ " vocab = {\"a\", \"b\", \"c\", \"x\"}\n",
142
+ " program = lib.make_frac_prevs((rasp.tokens == \"x\").named(\"is_x\"))\n",
143
+ " elif program_name == \"dyck-2\":\n",
144
+ " vocab = {\"(\", \")\", \"{\", \"}\"}\n",
145
+ " program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\"])\n",
146
+ " elif program_name == \"dyck-3\":\n",
147
+ " vocab = {\"(\", \")\", \"{\", \"}\", \"[\", \"]\"}\n",
148
+ " program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\", \"[]\"])\n",
149
+ " elif program_name == \"sort\":\n",
150
+ " vocab = {1, 2, 3, 4, 5}\n",
151
+ " program = lib.make_sort(\n",
152
+ " rasp.tokens, rasp.tokens, max_seq_len=max_seq_len, min_key=1)\n",
153
+ " elif program_name == \"sort_unique\":\n",
154
+ " vocab = {1, 2, 3, 4, 5}\n",
155
+ " program = lib.make_sort_unique(rasp.tokens, rasp.tokens)\n",
156
+ " elif program_name == \"hist\":\n",
157
+ " vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
158
+ " program = lib.make_hist()\n",
159
+ " elif program_name == \"sort_freq\":\n",
160
+ " vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
161
+ " program = lib.make_sort_freq(max_seq_len=max_seq_len)\n",
162
+ " elif program_name == \"pair_balance\":\n",
163
+ " vocab = {\"(\", \")\"}\n",
164
+ " program = lib.make_pair_balance(\n",
165
+ " sop=rasp.tokens, open_token=\"(\", close_token=\")\")\n",
166
+ " else:\n",
167
+ " raise NotImplementedError(f\"Program {program_name} not implemented.\")\n",
168
+ " return program, vocab"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {
175
+ "id": "L_m_ufaua9ri"
176
+ },
177
+ "outputs": [],
178
+ "source": [
179
+ "#@title: Assemble model\n",
180
+ "program_name = \"sort_unique\" #@param [\"length\", \"frac_prevs\", \"dyck-2\", \"dyck-3\", \"sort\", \"sort_unique\", \"hist\", \"sort_freq\", \"pair_balance\"]\n",
181
+ "max_seq_len = 5 #@param {label: \"Test\", type: \"integer\"}\n",
182
+ "\n",
183
+ "program, vocab = get_program(program_name=program_name,\n",
184
+ " max_seq_len=max_seq_len)\n",
185
+ "\n",
186
+ "print(f\"Compiling...\")\n",
187
+ "print(f\" Program: {program_name}\")\n",
188
+ "print(f\" Input vocabulary: {vocab}\")\n",
189
+ "print(f\" Context size: {max_seq_len}\")\n",
190
+ "\n",
191
+ "assembled_model = compiling.compile_rasp_to_model(\n",
192
+ " program=program,\n",
193
+ " vocab=vocab,\n",
194
+ " max_seq_len=max_seq_len,\n",
195
+ " causal=False,\n",
196
+ " compiler_bos=\"bos\",\n",
197
+ " compiler_pad=\"pad\",\n",
198
+ " mlp_exactness=100)\n",
199
+ "\n",
200
+ "print(\"Done.\")"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "metadata": {
207
+ "id": "wtwiE-JiXF3F"
208
+ },
209
+ "outputs": [],
210
+ "source": [
211
+ "#@title Forward pass\n",
212
+ "assembled_model.apply([\"bos\", 3, 4, 1]).decoded"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "metadata": {
219
+ "id": "RkEkVcEHa2gf"
220
+ },
221
+ "outputs": [],
222
+ "source": [
223
+ "#@title Plot residual stream\n",
224
+ "plot_residuals_and_input(\n",
225
+ " model=assembled_model,\n",
226
+ " inputs=[\"bos\", 3, 4, 1],\n",
227
+ " figsize=(10, 9)\n",
228
+ ")"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "metadata": {
235
+ "id": "8c4LakWHa4ey"
236
+ },
237
+ "outputs": [],
238
+ "source": [
239
+ "#@title Plot layer outputs\n",
240
+ "plot_layer_outputs(\n",
241
+ " model=assembled_model,\n",
242
+ " inputs = [\"bos\", 3, 4, 1],\n",
243
+ " figsize=(8, 9)\n",
244
+ ")"
245
+ ]
246
+ }
247
+ ],
248
+ "metadata": {
249
+ "colab": {
250
+ "private_outputs": true
251
+ },
252
+ "kernelspec": {
253
+ "display_name": "Python 3",
254
+ "name": "python3"
255
+ },
256
+ "language_info": {
257
+ "name": "python"
258
+ }
259
+ },
260
+ "nbformat": 4,
261
+ "nbformat_minor": 0
262
+ }
rasp/causal_eval.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """RASP Evaluator which applies causal masks to selectors."""
16
+
17
+ from typing import Sequence, Union
18
+
19
+ import numpy as np
20
+ from tracr.rasp import rasp
21
+
22
+
23
+ class CausalEvaluator(rasp.DefaultRASPEvaluator):
24
+ """Evaluates RASP with causal masking."""
25
+
26
+ def evaluate(
27
+ self, expr: rasp.RASPExpr, xs: Sequence[rasp.Value]
28
+ ) -> Union[Sequence[rasp.Value], rasp.SelectorValue]:
29
+ out = super().evaluate(expr, xs)
30
+
31
+ if not isinstance(expr, rasp.Selector):
32
+ return out
33
+
34
+ out = np.array(out)
35
+ causal_mask = np.tril(np.full(out.shape, 1))
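+ # For a length-3 input this mask is [[1, 0, 0], [1, 1, 0], [1, 1, 1]], so
+ # each query position can only select keys at or before its own position.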
36
+ return np.logical_and(causal_mask, out).tolist()
37
+
38
+
39
+ evaluate = CausalEvaluator().evaluate
rasp/causal_eval_test.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for causal_eval."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+
20
+ from tracr.rasp import causal_eval
21
+ from tracr.rasp import rasp
22
+
23
+
24
+ class CausalEvalTest(parameterized.TestCase):
25
+
26
+ @parameterized.named_parameters(
27
+ dict(
28
+ testcase_name="constant_selector_3x3_1",
29
+ program=rasp.ConstantSelector([
30
+ [True, True, True],
31
+ [True, True, True],
32
+ [True, True, True],
33
+ ]),
34
+ input_sequence=[True, True, True],
35
+ expected_output=[
36
+ [True, False, False],
37
+ [True, True, False],
38
+ [True, True, True],
39
+ ]),
40
+ dict(
41
+ testcase_name="constant_selector_3x3_2",
42
+ program=rasp.ConstantSelector([
43
+ [True, True, True],
44
+ [False, True, True],
45
+ [True, False, True],
46
+ ]),
47
+ input_sequence=[True, True, True],
48
+ expected_output=[
49
+ [True, False, False],
50
+ [False, True, False],
51
+ [True, False, True],
52
+ ]))
53
+ def test_evaluations(self, program, input_sequence, expected_output):
54
+ self.assertListEqual(
55
+ causal_eval.evaluate(program, input_sequence),
56
+ expected_output,
57
+ )
58
+
59
+
60
+ if __name__ == "__main__":
61
+ absltest.main()
rasp/rasp.py ADDED
@@ -0,0 +1,932 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """RASP program objects.
16
+
17
+ Every object in the RASP language is a function.
18
+
19
+ The most important type is S-Op, which is a function list[Value] -> list[Value].
20
+
21
+ An S-Op represents a state inside the residual stream of the transformer.
22
+ Therefore, any RASP program that represents a transformer computation must
23
+ define a final S-Op that represents the state of the residual stream at the
24
+ end of the computation. In particular, given an S-Op `x`,
25
+ `x([1, 2, 3])` represents something like the state of the residual stream
26
+ at location `x` when the transformer is fed [1, 2, 3] as input.
27
+
28
+ A secondary (but still important) type is Selector, which is a function
29
+ list[Value] -> list[list[bool]]. Given a Selector `sel`, sel([1, 2, 3])
30
+ represents something like an attention matrix in the transformer.
31
+
32
+ For a full reference on RASP, see https://arxiv.org/abs/2106.06981.
33
+ """
34
+
35
+ import abc
36
+ import collections.abc
37
+ import copy
38
+ import enum
39
+ import functools
40
+ import itertools
41
+ from typing import (Any, Callable, Generic, Mapping, Optional, Protocol,
42
+ Sequence, TypeVar, Union)
43
+ from absl import logging
44
+
45
+ import numpy as np
46
+
47
+ SelectorValue = list[list[bool]]
48
+ NumericValue = Union[int, float]
49
+ Value = Union[None, int, float, str, bool]
50
+ VT = TypeVar("VT", bound=Value)
51
+ RASPExprT = TypeVar("RASPExprT", bound="RASPExpr")
52
+ SOpT = TypeVar("SOpT", bound="SOp")
53
+ T = TypeVar("T")
54
+
55
+ _NAME_KEY = "name"
56
+ _ENCODING_KEY = "encoding"
57
+
58
+ # These are run on every expression when it's initialised.
59
+ # Add your own annotators to this dict to add custom default annotations.
60
+ #
61
+ # For example, DEFAULT_ANNOTATORS['foo'] will provide the default value for
62
+ # expr.annotations['foo]. The annotator will get called lazily the first time
63
+ # that key is accessed.
64
+ #
65
+ # See the `default_name` annotator for a full example.
66
+ DEFAULT_ANNOTATORS: dict[str, "Annotator"] = {}
67
+
68
+
69
+ class Annotator(Protocol):
70
+
71
+ def __call__(self, expr: "RASPExpr") -> Any:
72
+ """What annotation to add to `expr`."""
73
+
74
+
75
+ class _Annotations(collections.abc.Mapping):
76
+ """Holds the expression's annotations.
77
+
78
+ It's immutable to the user, but will attempt to generate default values
79
+ lazily when missing keys are requested.
80
+ """
81
+
82
+ def __init__(self, expr, **kwargs: Any):
83
+ self._expr = expr
84
+ self._inner_dict: dict[str, Any] = {**kwargs}
85
+
86
+ def __getitem__(self, key: str) -> Any:
87
+ if key not in self._inner_dict:
88
+ if key not in DEFAULT_ANNOTATORS:
89
+ raise KeyError(
90
+ f"No annotation exists for key '{key}'. "
91
+ f"Available keys: {list(*self.keys(), *DEFAULT_ANNOTATORS.keys())}")
92
+ self._inner_dict[key] = DEFAULT_ANNOTATORS[key](self._expr)
93
+
94
+ return self._inner_dict[key]
95
+
96
+ def __iter__(self):
97
+ return iter(self._inner_dict)
98
+
99
+ def __len__(self):
100
+ return len(self._inner_dict)
101
+
102
+
103
+ class RASPExpr(abc.ABC):
104
+ """A class distinguishing RASP expressions from other objects."""
105
+ _ids = itertools.count(1)
106
+
107
+ def __init__(self):
108
+ self._annotations: Mapping[str, Any] = _Annotations(self)
109
+
110
+ @abc.abstractmethod
111
+ def __call__(self,
112
+ xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
113
+ """Evaluates the RASPExpr using the standard evaluator."""
114
+
115
+ @property
116
+ def annotations(self) -> Mapping[str, Any]:
117
+ """The annotations of this expression instance."""
118
+ return self._annotations
119
+
120
+ @annotations.setter
121
+ def annotations(self, annotations: Mapping[str, Any]):
122
+ self._annotations = _Annotations(self, **annotations)
123
+
124
+ @property
125
+ def name(self) -> str:
126
+ """The name of this expression."""
127
+ return self.annotations[_NAME_KEY]
128
+
129
+ @property
130
+ @abc.abstractmethod
131
+ def children(self) -> Sequence["RASPExpr"]:
132
+ """Direct dependencies of this expression."""
133
+
134
+ @functools.cached_property
135
+ def unique_id(self):
136
+ """A unique id for every expression instance."""
137
+ return next(self._ids)
138
+
139
+ def copy(self: RASPExprT) -> RASPExprT:
140
+ """Returns a shallow copy of this RASPExpr with a new ID."""
141
+ return copy.copy(self)
142
+
143
+ @property
144
+ def label(self) -> str:
145
+ return f"{self.name}_{self.unique_id}"
146
+
147
+ def named(self: RASPExprT, name: str) -> RASPExprT:
148
+ """Convenience method for adding a name."""
149
+ return annotate(self, name=name)
150
+
151
+ def annotated(self: RASPExprT, **annotations) -> RASPExprT:
152
+ """Convenience method for adding annotations."""
153
+ return annotate(self, **annotations)
154
+
155
+
156
+ def annotate(expr: RASPExprT, **annotations) -> RASPExprT:
157
+ """Creates a new expr with added annotations."""
158
+ new = expr.copy()
159
+ # Note that new annotations will overwrite existing ones with matching keys.
160
+ new.annotations = {**expr.annotations, **annotations}
161
+ return new
162
+
163
+
164
+ ### S-Ops.
165
+
166
+
167
+ class SOp(RASPExpr):
168
+ """A Sequence Operation."""
169
+
170
+ def __call__(self, xs: Sequence[Value]) -> Sequence[Value]:
171
+ return evaluate(self, xs) # pytype: disable=bad-return-type
172
+
173
+ # Allow construction of SOps using numeric operators with constant values.
174
+ # Note: if inheriting SOp by a dataclass, make sure to disable eq and order,
175
+ # as they will override these.
176
+
177
+ def __lt__(self, other: Value) -> "SOp":
178
+ """self < other."""
179
+ return Map(lambda x: x < other, self)
180
+
181
+ def __le__(self, other: Value) -> "SOp":
182
+ """self <= other."""
183
+ return Map(lambda x: x <= other, self)
184
+
185
+ def __eq__(self, other: Value) -> "SOp":
186
+ """self == other."""
187
+ return Map(lambda x: x == other, self)
188
+
189
+ def __ne__(self, other: Value) -> "SOp":
190
+ """self != other."""
191
+ return Map(lambda x: x != other, self)
192
+
193
+ def __gt__(self, other: Value) -> "SOp":
194
+ """self > other."""
195
+ return Map(lambda x: x > other, self)
196
+
197
+ def __ge__(self, other: Value) -> "SOp":
198
+ """self >= other."""
199
+ return Map(lambda x: x >= other, self)
200
+
201
+ def __add__(self, other: Union["SOp", Value]) -> "SOp":
202
+ """self + other."""
203
+ if isinstance(other, SOp):
204
+ return SequenceMap(lambda x, y: x + y, self, other)
205
+ return Map(lambda x: x + other, self)
206
+
207
+ def __radd__(self, other: Union["SOp", Value]) -> "SOp":
208
+ """other + self."""
209
+ if isinstance(other, SOp):
210
+ return SequenceMap(lambda x, y: x + y, other, self)
211
+ return Map(lambda x: other + x, self)
212
+
213
+ def __sub__(self, other: Union["SOp", NumericValue]) -> "SOp":
214
+ """self - other."""
215
+ if isinstance(other, SOp):
216
+ return SequenceMap(lambda x, y: x - y, self, other)
217
+ return Map(lambda x: x - other, self)
218
+
219
+ def __rsub__(self, other: Union["SOp", NumericValue]) -> "SOp":
220
+ """other - self."""
221
+ if isinstance(other, SOp):
222
+ return SequenceMap(lambda x, y: x - y, other, self)
223
+ return Map(lambda x: other - x, self)
224
+
225
+ def __mul__(self, other: Union["SOp", NumericValue]) -> "SOp":
226
+ """self * other."""
227
+ if isinstance(other, SOp):
228
+ return SequenceMap(lambda x, y: x * y, self, other)
229
+ return Map(lambda x: x * other, self)
230
+
231
+ def __rmul__(self, other: Union["SOp", NumericValue]) -> "SOp":
232
+ """other * self."""
233
+ if isinstance(other, SOp):
234
+ return SequenceMap(lambda x, y: x * y, other, self)
235
+ return Map(lambda x: other * x, self)
236
+
237
+ def __truediv__(self, other: Union["SOp", NumericValue]) -> "SOp":
238
+ """self / other."""
239
+ if isinstance(other, SOp):
240
+ return SequenceMap(lambda x, y: x / y, self, other)
241
+ return Map(lambda x: x / other, self)
242
+
243
+ def __rtruediv__(self, other: Union["SOp", NumericValue]) -> "SOp":
244
+ """other / self."""
245
+ if isinstance(other, SOp):
246
+ return SequenceMap(lambda x, y: x / y, other, self)
247
+ return Map(lambda x: other / x, self)
248
+
249
+ def __invert__(self) -> "SOp":
250
+ return Map(lambda x: not x, self)
251
+
252
+ def __and__(self, other: Union["SOp", NumericValue]) -> "SOp":
253
+ """self & other."""
254
+ if isinstance(other, SOp):
255
+ return SequenceMap(lambda x, y: x and y, self, other)
256
+ return Map(lambda x: x and other, self)
257
+
258
+ def __or__(self, other: Union["SOp", NumericValue]) -> "SOp":
259
+ """self | other."""
260
+ if isinstance(other, SOp):
261
+ return SequenceMap(lambda x, y: x or y, self, other)
262
+ return Map(lambda x: x or other, self)
263
+
264
+ def __rand__(self, other: Union["SOp", NumericValue]) -> "SOp":
265
+ """other & self."""
266
+ if isinstance(other, SOp):
267
+ return SequenceMap(lambda x, y: x and y, other, self)
268
+ return Map(lambda x: other and x, self)
269
+
270
+ def __ror__(self, other: Union["SOp", NumericValue]) -> "SOp":
271
+ """other | self."""
272
+ if isinstance(other, SOp):
273
+ return SequenceMap(lambda x, y: x or y, other, self)
274
+ return Map(lambda x: x or other, self)
275
+
276
+
277
+ class TokensType(SOp):
278
+ """Primitive SOp returning the original input tokens."""
279
+
280
+ @property
281
+ def children(self) -> Sequence[RASPExpr]:
282
+ return []
283
+
284
+ @property
285
+ def label(self) -> str:
286
+ return "tokens"
287
+
288
+ def __repr__(self):
289
+ return "tokens"
290
+
291
+
292
+ class IndicesType(SOp):
293
+ """Primitive SOp returning the position index at each token."""
294
+
295
+ @property
296
+ def children(self) -> Sequence[RASPExpr]:
297
+ return []
298
+
299
+ @property
300
+ def label(self) -> str:
301
+ return "indices"
302
+
303
+ def __repr__(self):
304
+ return "indices"
305
+
306
+
307
+ class LengthType(SOp):
308
+ """Primitive SOp returning the total length of the input."""
309
+
310
+ @property
311
+ def children(self) -> Sequence[RASPExpr]:
312
+ return []
313
+
314
+ @property
315
+ def label(self) -> str:
316
+ return "length"
317
+
318
+ def __repr__(self):
319
+ return "length"
320
+
321
+
322
+ tokens = TokensType()
323
+ indices = IndicesType()
324
+ length = LengthType()
325
+
326
+
327
+ class Map(SOp):
328
+ """SOp that evaluates the function elementwise on the input SOp.
329
+
330
+ Map(lambda x: x + 1, tokens).eval([1, 2, 3]) == [2, 3, 4]
331
+ """
332
+
333
+ def __init__(self, f: Callable[[Value], Value], inner: SOp):
334
+ super().__init__()
335
+ self.f = f
336
+ self.inner = inner
337
+
338
+ assert isinstance(self.inner, SOp)
339
+ assert callable(self.f) and not isinstance(self.f, RASPExpr)
340
+
341
+ if isinstance(self.inner, Map):
342
+ # Combine the functions into just one.
343
+ inner_f = self.inner.f
344
+ self.f = lambda t: f(inner_f(t))
345
+ self.inner = self.inner.inner
346
+
347
+ @property
348
+ def children(self) -> Sequence[RASPExpr]:
349
+ return [self.inner]
350
+
351
+
352
+ class SequenceMap(SOp):
353
+ """SOp that evaluates the function elementwise on the two given SOp's.
354
+
355
+ SequenceMap(lambda x, y: x - y, length, tokens).eval([1, 2, 3]) == [2, 1, 0]
356
+ """
357
+
358
+ def __init__(self, f: Callable[[Value, Value], Value], fst: SOp, snd: SOp):
359
+ super().__init__()
360
+
361
+ if fst == snd:
362
+ logging.warning("Creating a SequenceMap with both inputs being the same "
363
+ "SOp is discouraged. You should use a Map instead.")
364
+
365
+ self.f = f
366
+ self.fst = fst
367
+ self.snd = snd
368
+ assert isinstance(self.fst, SOp)
369
+ assert isinstance(self.snd, SOp)
370
+ assert callable(self.f) and not isinstance(self.f, RASPExpr)
371
+
372
+ @property
373
+ def children(self) -> Sequence[RASPExpr]:
374
+ return [self.fst, self.snd]
375
+
376
+
377
+ class LinearSequenceMap(SequenceMap):
378
+ """SOp that evaluates a linear function elementwise on the two given SOp's."""
379
+
380
+ def __init__(self, fst: SOp, snd: SOp, fst_fac: float, snd_fac: float):
381
+ super().__init__(fst=fst, snd=snd, f=lambda x, y: fst_fac * x + snd_fac * y)
382
+ self.fst_fac = fst_fac
383
+ self.snd_fac = snd_fac
384
+
385
+
386
+ class Full(SOp):
387
+ """A SOp evaluating to [fill]*len(input_values)."""
388
+
389
+ def __init__(self, fill: Value):
390
+ super().__init__()
391
+ self.fill = fill
392
+
393
+ @property
394
+ def children(self) -> Sequence[RASPExpr]:
395
+ return []
396
+
397
+
398
+ def sop_not(sop: SOp) -> SOp:
399
+ return Map(lambda t: not t, sop)
400
+
401
+
402
+ class ConstantSOp(SOp, Generic[VT]):
403
+ """A constant S-Op for testing purposes."""
404
+
405
+ def __init__(self, value: Sequence[VT], check_length: bool = True):
406
+ super().__init__()
407
+ self.value = value
408
+ self.check_length = check_length
409
+
410
+ @property
411
+ def children(self) -> Sequence[RASPExpr]:
412
+ return []
413
+
414
+
415
+ ### Selectors.
416
+
417
+
418
+ class Predicate(Protocol):
419
+
420
+ def __call__(self, key: Value, query: Value) -> bool:
421
+ """Applies the predicate."""
422
+
423
+
424
+ class Comparison(enum.Enum):
425
+ """A two-place boolean comparison predicate for use in Select."""
426
+ EQ = "=="
427
+ LT = "<"
428
+ LEQ = "<="
429
+ GT = ">"
430
+ GEQ = ">="
431
+ NEQ = "!="
432
+ TRUE = "True"
433
+ FALSE = "False"
434
+
435
+ def __call__(self, key: Value, query: Value) -> bool:
436
+ if key is None:
437
+ raise ValueError("key is None!")
438
+ if query is None:
439
+ raise ValueError("query is None!")
440
+ return _comparison_table[self](key, query)
441
+
442
+
443
+ _comparison_table = {
444
+ Comparison.EQ: lambda key, query: key == query,
445
+ Comparison.LT: lambda key, query: key < query,
446
+ Comparison.LEQ: lambda key, query: key <= query,
447
+ Comparison.GT: lambda key, query: key > query,
448
+ Comparison.GEQ: lambda key, query: key >= query,
449
+ Comparison.NEQ: lambda key, query: key != query,
450
+ Comparison.TRUE: lambda key, query: True,
451
+ Comparison.FALSE: lambda key, query: False,
452
+ }
453
+
454
+
455
+ class Selector(RASPExpr):
456
+ """RASP Selector. Represents something like an attention head's weights."""
457
+
458
+ def __call__(self, xs: Sequence[Value]) -> SelectorValue:
459
+ return evaluate(self, xs) # pytype: disable=bad-return-type
460
+
461
+ # Allow construction of Selector combinations using Python logical operators.
462
+ def __and__(self, other: "Selector") -> "Selector":
463
+ """self & other."""
464
+ return selector_and(self, other)
465
+
466
+ def __rand__(self, other: "Selector") -> "Selector":
467
+ """other & self."""
468
+ return selector_and(other, self)
469
+
470
+ def __or__(self, other: "Selector") -> "Selector":
471
+ """self | other."""
472
+ return selector_or(self, other)
473
+
474
+ def __ror__(self, other: "Selector") -> "Selector":
475
+ """other | self."""
476
+ return selector_or(other, self)
477
+
478
+ def __invert__(self) -> "Selector":
479
+ """~self."""
480
+ return selector_not(self)
481
+
482
+
483
+ class Select(Selector):
484
+ """Primitive that creates a Selector."""
485
+
486
+ def __init__(self, keys: SOp, queries: SOp, predicate: Predicate):
487
+ super().__init__()
488
+ self.keys = keys
489
+ self.queries = queries
490
+ self.predicate = predicate
491
+ assert isinstance(self.keys, SOp)
492
+ assert isinstance(self.queries, SOp)
493
+
494
+ @property
495
+ def children(self) -> Sequence[RASPExpr]:
496
+ return [self.keys, self.queries]
497
+
498
+
499
+ class ConstantSelector(Selector):
500
+ """A constant selector for testing purposes."""
501
+
502
+ def __init__(self, value: SelectorValue, check_length: bool = True):
503
+ super().__init__()
504
+ self.value = value
505
+ self.check_length = check_length
506
+
507
+ @property
508
+ def children(self) -> Sequence[RASPExpr]:
509
+ return []
510
+
511
+
512
+ class SelectorWidth(SOp):
513
+ """SelectorWidth primitive."""
514
+
515
+ def __init__(self, selector: Selector):
516
+ super().__init__()
517
+ self.selector = selector
518
+ assert isinstance(self.selector, Selector)
519
+
520
+ @property
521
+ def children(self) -> Sequence[RASPExpr]:
522
+ return [self.selector]
523
+
524
+
525
+ class SelectorAnd(Selector):
526
+ """Implements elementwise `and` between selectors."""
527
+
528
+ def __init__(self, fst: Selector, snd: Selector):
529
+ super().__init__()
530
+ self.fst = fst
531
+ self.snd = snd
532
+ assert isinstance(self.fst, Selector)
533
+ assert isinstance(self.snd, Selector)
534
+
535
+ @property
536
+ def children(self) -> Sequence[RASPExpr]:
537
+ return [self.fst, self.snd]
538
+
539
+
540
+ class SelectorOr(Selector):
541
+ """Implements elementwise `or` between selectors."""
542
+
543
+ def __init__(self, fst: Selector, snd: Selector):
544
+ super().__init__()
545
+ self.fst = fst
546
+ self.snd = snd
547
+ assert isinstance(self.fst, Selector)
548
+ assert isinstance(self.snd, Selector)
549
+
550
+ @property
551
+ def children(self) -> Sequence[RASPExpr]:
552
+ return [self.fst, self.snd]
553
+
554
+
555
+ class SelectorNot(Selector):
556
+ """Implements elementwise `not` on a selector."""
557
+
558
+ def __init__(self, inner: Selector):
559
+ self.inner = inner
560
+ super().__init__()
561
+ assert isinstance(self.inner, Selector)
562
+
563
+ @property
564
+ def children(self) -> Sequence[RASPExpr]:
565
+ return [self.inner]
566
+
567
+
568
+ def selector_not(
569
+ inner: Selector,
570
+ simplify: bool = True,
571
+ ) -> Selector:
572
+ """Returns a SelectorNot, or a Select if simplifying is possible."""
573
+ if simplify and isinstance(inner, Select):
574
+ predicate = lambda k, q: not inner.predicate(k, q)
575
+ return Select(inner.keys, inner.queries, predicate=predicate)
576
+
577
+ return SelectorNot(inner)
578
+
579
+
580
+ def selector_and(
581
+ fst: Selector,
582
+ snd: Selector,
583
+ simplify: bool = True,
584
+ ) -> Selector:
585
+ """Returns a SelectorAnd, or a Select if simplifying is possible."""
586
+ if simplify and isinstance(fst, Select) and isinstance(snd, Select):
587
+ simplified = _attempt_simplify(fst, snd, lambda l, r: l and r)
588
+ if simplified:
589
+ return simplified
590
+
591
+ return SelectorAnd(fst, snd)
592
+
593
+
594
+ def selector_or(
595
+ fst: Selector,
596
+ snd: Selector,
597
+ simplify: bool = True,
598
+ ) -> Selector:
599
+ """Returns a SelectorOr, or a Select if simplifying is possible."""
600
+ if simplify and isinstance(fst, Select) and isinstance(snd, Select):
601
+ simplified = _attempt_simplify(fst, snd, lambda l, r: l or r)
602
+ if simplified:
603
+ return simplified
604
+
605
+ return SelectorOr(fst, snd)
606
+
607
+
608
+ def _attempt_simplify(
609
+ fst: Select,
610
+ snd: Select,
611
+ combine: Callable[[bool, bool], bool],
612
+ ) -> Optional[Select]:
613
+ """Simplifies two Selects if possible.
614
+
615
+ If two Selects in a compound Selector have matching keys and queries, they can
616
+ be simplified into one Select with a compound predicate:
617
+
618
+ lambda k,q: combine(fst.predicate(k,q), snd.predicate(k,q))
619
+
620
+ This function returns a Select with this predicate if possible,
621
+ and None otherwise.
622
+
623
+ A Full SOp in a key or query position is a special case that always matches
624
+ any SOp in the corresponding position in the other selector. In that case,
625
+ we bake the fill value into the corresponding Select's predicate before
626
+ combining. This allows us to use the other SOp as the input to the simplified
627
+ Select.
628
+
629
+ Args:
630
+ fst: the first Select.
631
+ snd: the second Select.
632
+ combine: how to combine the outputs of the individual predicates.
633
+
634
+ Returns:
635
+ A combined Select, if possible.
636
+ """
637
+ fst_predicate = fst.predicate
638
+ snd_predicate = snd.predicate
639
+ common_keys = None
640
+ common_queries = None
641
+
642
+ if isinstance(fst.keys, Full):
643
+ common_keys = snd.keys
644
+ # We pass the predicate in as a default arg to avoid unintended recursion.
645
+ fst_predicate = lambda key, query, p=fst_predicate: p(fst.keys.fill, query)
646
+ if isinstance(snd.keys, Full):
647
+ common_keys = fst.keys
648
+ snd_predicate = lambda key, query, p=snd_predicate: p(snd.keys.fill, query)
649
+ if isinstance(fst.queries, Full):
650
+ common_queries = snd.queries
651
+ fst_predicate = lambda key, query, p=fst_predicate: p(key, fst.queries.fill)
652
+ if isinstance(snd.queries, Full):
653
+ common_queries = fst.queries
654
+ snd_predicate = lambda key, query, p=snd_predicate: p(key, snd.queries.fill)
655
+ if fst.keys is snd.keys:
656
+ common_keys = fst.keys
657
+ if fst.queries is snd.queries:
658
+ common_queries = fst.queries
659
+
660
+ if not common_keys or not common_queries:
661
+ return None
662
+
663
+ def predicate(key, query):
664
+ return combine(fst_predicate(key, query), snd_predicate(key, query))
665
+
666
+ return Select(common_keys, common_queries, predicate=predicate)
667
+
668
+
669
+ class Aggregate(SOp, Generic[VT]):
670
+ """Aggregate primitive."""
671
+
672
+ def __init__(self,
673
+ selector: Selector,
674
+ sop: SOp,
675
+ default: Optional[VT] = None):
676
+ """Initialises. The default is used where nothing is selected."""
677
+ super().__init__()
678
+ self.selector = selector
679
+ self.sop = sop
680
+ self.default = default
681
+ assert isinstance(self.selector, Selector)
682
+ assert isinstance(self.sop, SOp)
683
+ assert (self.default is None or isinstance(self.default,
684
+ (str, float, bool, int)))
685
+
686
+ @property
687
+ def children(self) -> Sequence[RASPExpr]:
688
+ return [self.selector, self.sop]
689
+
690
+
691
+ ### SOp encodings.
692
+
693
+
694
+ class Encoding(enum.Enum):
695
+ """The encoding used by a SOp. Only number-valued SOps support numerical."""
696
+ CATEGORICAL = "categorical"
697
+ NUMERICAL = "numerical"
698
+
699
+
700
+ def numerical(sop: SOpT) -> SOpT:
701
+ return annotate(sop, encoding=Encoding.NUMERICAL)
702
+
703
+
704
+ def categorical(sop: SOpT) -> SOpT:
705
+ return annotate(sop, encoding=Encoding.CATEGORICAL)
706
+
707
+
708
+ def get_encoding(sop: SOp) -> Encoding:
709
+ return sop.annotations["encoding"]
710
+
711
+
712
+ def is_numerical(sop: SOp) -> bool:
713
+ """Check if the SOp is numerically encoded."""
714
+ return get_encoding(sop) == Encoding.NUMERICAL
715
+
716
+
717
+ def is_categorical(sop: SOp) -> bool:
718
+ """Check if the SOp is categorically encoded."""
719
+ return get_encoding(sop) == Encoding.CATEGORICAL
720
+
721
+
722
+ def default_encoding(expr: RASPExpr) -> Optional[Encoding]:
723
+ """Adds an 'encoding' annotation, default is Categorical."""
724
+ if not isinstance(expr, SOp):
725
+ raise TypeError(f"expr {expr} is not a SOp.")
726
+
727
+ return Encoding.CATEGORICAL
728
+
729
+
730
+ DEFAULT_ANNOTATORS[_ENCODING_KEY] = default_encoding
731
+
732
+ ### naming.
733
+
734
+ # Subclasses must appear here before superclasses in order for
735
+ # the most specific entry to be used.
736
+
737
+ _default_name_by_class = {
738
+ # Primitives
739
+ TokensType: "tokens",
740
+ IndicesType: "indices",
741
+ LengthType: "length",
742
+ # SOps
743
+ LinearSequenceMap: "linear_sequence_map",
744
+ SequenceMap: "sequence_map",
745
+ Map: "map",
746
+ Full: "full",
747
+ ConstantSOp: "constant_sop",
748
+ SelectorWidth: "selector_width",
749
+ Aggregate: "aggregate",
750
+ SOp: "sop",
751
+ # Selectors
752
+ Select: "select",
753
+ SelectorAnd: "selector_and",
754
+ SelectorOr: "selector_or",
755
+ SelectorNot: "selector_not",
756
+ ConstantSelector: "constant_selector",
757
+ Selector: "selector",
758
+ }
759
+
760
+
761
+ def default_name(expr: RASPExpr) -> str:
762
+ for cls, name in _default_name_by_class.items():
763
+ if isinstance(expr, cls):
764
+ return name
765
+
766
+ raise NotImplementedError(f"{expr} was not given a default name!")
767
+
768
+
769
+ DEFAULT_ANNOTATORS[_NAME_KEY] = default_name
770
+
771
+ ### evaluation.
772
+
773
+
774
+ class RASPEvaluator(abc.ABC):
775
+ """ABC for RASP evaluators."""
776
+
777
+ @abc.abstractmethod
778
+ def evaluate(self, expr: RASPExpr,
779
+ xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
780
+ """Evaluates the RASP expression on input `xs`."""
781
+
782
+
783
+ class DefaultRASPEvaluator(abc.ABC):
784
+ """Default evaluator for RASP."""
785
+
786
+ def evaluate(self, expr: RASPExpr,
787
+ xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
788
+ """Evaluates the RASP expression on input `xs`."""
789
+ return self._eval_fn_by_expr_type[type(expr)](expr, xs)
790
+
791
+ def __init__(self):
792
+ self._eval_fn_by_expr_type = {
793
+ # Primitives
794
+ TokensType: self.eval_tokens,
795
+ IndicesType: self.eval_indices,
796
+ LengthType: self.eval_length,
797
+ # SOps
798
+ LinearSequenceMap: self.eval_sequence_map,
799
+ SequenceMap: self.eval_sequence_map,
800
+ Map: self.eval_map,
801
+ Full: self.eval_full,
802
+ ConstantSOp: self.eval_constant_sop,
803
+ SelectorWidth: self.eval_selector_width,
804
+ Aggregate: self.eval_aggregate,
805
+ SOp: _raise_not_implemented,
806
+ # Selectors
807
+ Select: self.eval_select,
808
+ SelectorAnd: self.eval_selector_and,
809
+ SelectorOr: self.eval_selector_or,
810
+ SelectorNot: self.eval_selector_not,
811
+ ConstantSelector: self.eval_constant_selector,
812
+ Selector: _raise_not_implemented,
813
+ }
814
+
815
+ def eval_tokens(self, sop: TokensType,
816
+ xs: Sequence[Value]) -> Sequence[Value]:
817
+ del sop
818
+ return list(xs)
819
+
820
+ def eval_indices(self, sop: IndicesType,
821
+ xs: Sequence[Value]) -> Sequence[Value]:
822
+ del sop
823
+ return list(range(len(xs)))
824
+
825
+ def eval_length(self, sop: LengthType, xs: Sequence[Value]) -> Sequence[int]:
826
+ del sop
827
+ return [len(xs)] * len(xs)
828
+
829
+ def eval_sequence_map(self, sop: SequenceMap,
830
+ xs: Sequence[Value]) -> Sequence[Value]:
831
+ fst_values = self.evaluate(sop.fst, xs)
832
+ snd_values = self.evaluate(sop.snd, xs)
833
+ return [
834
+ sop.f(x, y) if None not in [x, y] else None
835
+ for x, y in zip(fst_values, snd_values)
836
+ ]
837
+
838
+ def eval_map(self, sop: Map, xs: Sequence[Value]) -> Sequence[Value]:
839
+ return [
840
+ sop.f(x) if x is not None else None
841
+ for x in self.evaluate(sop.inner, xs)
842
+ ]
843
+
844
+ def eval_full(self, sop: Full, xs: Sequence[Value]) -> Sequence[Value]:
845
+ return [sop.fill] * len(xs)
846
+
847
+ def eval_constant_sop(self, sop: ConstantSOp,
848
+ xs: Sequence[Value]) -> Sequence[Value]:
849
+ if sop.check_length and (len(xs) != len(sop.value)):
850
+ raise ValueError(
851
+ f"Constant len {len(sop.value)} doesn't match input len {len(xs)}.")
852
+ return sop.value
853
+
854
+ def eval_selector_width(self, sop: SelectorWidth,
855
+ xs: Sequence[Value]) -> Sequence[Value]:
856
+ selector_values = self.evaluate(sop.selector, xs)
857
+ return [sum(row) for row in selector_values]
858
+
859
+ def eval_aggregate(self, sop: Aggregate,
860
+ xs: Sequence[Value]) -> Sequence[Value]:
861
+ selector_value = self.evaluate(sop.selector, xs)
862
+ values = self.evaluate(sop.sop, xs)
863
+ default = sop.default
864
+
865
+ return [
866
+ _mean(_get_selected(row, values), default) for row in selector_value
867
+ ]
868
+
869
+ def eval_select(self, sel: Select, xs: Sequence[Value]) -> SelectorValue:
870
+ """Evaluates a Select on `xs`."""
871
+ key_values = self.evaluate(sel.keys, xs)
872
+ query_values = self.evaluate(sel.queries, xs)
873
+
874
+ key_len = len(key_values)
875
+ query_len = len(query_values)
876
+ out = np.zeros((query_len, key_len), dtype=bool).tolist()
877
+ for row, query in enumerate(query_values):
878
+ for col, key in enumerate(key_values):
879
+ out[row][col] = bool(sel.predicate(key, query))
880
+ return out
881
+
882
+ def eval_constant_selector(self, sel: ConstantSelector,
883
+ xs: Sequence[Value]) -> SelectorValue:
884
+ if sel.check_length and (len(xs) != len(sel.value)):
885
+ raise ValueError(
886
+ f"Constant len {len(xs)} doesn't match input len {len(sel.value)}.")
887
+ return sel.value
888
+
889
+ def eval_selector_and(self, sel: SelectorAnd,
890
+ xs: Sequence[Value]) -> SelectorValue:
891
+ fst_values = self.evaluate(sel.fst, xs)
892
+ snd_values = self.evaluate(sel.snd, xs)
893
+ return np.logical_and(np.array(fst_values), np.array(snd_values)).tolist()
894
+
895
+ def eval_selector_or(self, sel: SelectorOr,
896
+ xs: Sequence[Value]) -> SelectorValue:
897
+ fst_values = self.evaluate(sel.fst, xs)
898
+ snd_values = self.evaluate(sel.snd, xs)
899
+ return np.logical_or(np.array(fst_values), np.array(snd_values)).tolist()
900
+
901
+ def eval_selector_not(self, sel: SelectorNot,
902
+ xs: Sequence[Value]) -> SelectorValue:
903
+ values = self.evaluate(sel.inner, xs)
904
+ return np.logical_not(np.array(values)).tolist()
905
+
906
+
907
+ def _get_selected(
908
+ selector_row: list[bool],
909
+ values: Sequence[VT],
910
+ ) -> Sequence[VT]:
911
+ """Helper for aggregate. [T T F], [a b c] -> [a b]."""
912
+ return [v for s, v in zip(selector_row, values) if s]
913
+
914
+
915
+ def _mean(xs: Sequence[VT], default: VT) -> VT:
916
+ """Takes the mean for numbers and concats for strings."""
917
+ if not xs:
918
+ return default
919
+ exemplar = xs[0]
920
+ if isinstance(exemplar, (int, bool)):
921
+ return sum(xs) / len(xs)
922
+ elif len(xs) == 1:
923
+ return exemplar
924
+ else:
925
+ raise ValueError(f"Unsupported type for aggregation: {xs}")
926
+
927
+
928
+ def _raise_not_implemented(expr: RASPExpr, xs: Sequence[Value]):
929
+ raise NotImplementedError(f"Evaluation of {expr} is not defined.")
930
+
931
+
932
+ evaluate = DefaultRASPEvaluator().evaluate
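For orientation, here is a minimal sketch of how the S-Op and Selector primitives above compose into a runnable RASP program. The `reverse` program and the `opp_index` name are illustrative only and not part of this commit; everything used is defined in `rasp/rasp.py`.

```python
from tracr.rasp import rasp

# Attend from each query position i to the key position length - 1 - i,
# then aggregate the tokens found there. This reverses the input sequence.
opp_index = (rasp.length - rasp.indices - 1).named("opp_index")
reverse_selector = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(reverse_selector, rasp.tokens).named("reverse")

# SOp.__call__ dispatches to the DefaultRASPEvaluator defined above.
print(reverse("hello"))  # ['o', 'l', 'l', 'e', 'h']
```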
rasp/rasp_test.py ADDED
@@ -0,0 +1,580 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for rasp.rasp."""
16
+
17
+ import itertools
18
+
19
+ from absl.testing import absltest
20
+ from absl.testing import parameterized
21
+ import numpy as np
22
+ from tracr.rasp import rasp
23
+
24
+ # Note that the example text labels must match their default names.
25
+
26
+ _SOP_PRIMITIVE_EXAMPLES = lambda: [ # pylint: disable=g-long-lambda
27
+ ("tokens", rasp.tokens),
28
+ ("length", rasp.length),
29
+ ("indices", rasp.indices),
30
+ ]
31
+
32
+ _NONPRIMITIVE_SOP_EXAMPLES = lambda: [ # pylint: disable=g-long-lambda
33
+ ("map", rasp.Map(lambda x: x, rasp.tokens)),
34
+ (
35
+ "sequence_map",
36
+ rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens),
37
+ ),
38
+ (
39
+ "linear_sequence_map",
40
+ rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, 0.1, 0.2),
41
+ ),
42
+ (
43
+ "aggregate",
44
+ rasp.Aggregate(
45
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
46
+ rasp.tokens,
47
+ ),
48
+ ),
49
+ (
50
+ "selector_width",
51
+ rasp.SelectorWidth(
52
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)),
53
+ ),
54
+ ]
55
+
56
+ _SOP_EXAMPLES = lambda: _SOP_PRIMITIVE_EXAMPLES() + _NONPRIMITIVE_SOP_EXAMPLES()
57
+
58
+ _SELECTOR_EXAMPLES = lambda: [ # pylint: disable=g-long-lambda
59
+ ("select", rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)),
60
+ ("selector_and",
61
+ rasp.SelectorAnd(
62
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
63
+ rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
64
+ )),
65
+ ("selector_or",
66
+ rasp.SelectorOr(
67
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
68
+ rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
69
+ )),
70
+ ("selector_not",
71
+ rasp.SelectorNot(
72
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),)),
73
+ ]
74
+
75
+ _ALL_EXAMPLES = lambda: _SOP_EXAMPLES() + _SELECTOR_EXAMPLES()
76
+
77
+
78
+ class LabelTest(parameterized.TestCase):
79
+
80
+ def test_primitive_labels(self):
81
+ self.assertEqual(rasp.tokens.label, "tokens")
82
+ self.assertEqual(rasp.indices.label, "indices")
83
+ self.assertEqual(rasp.length.label, "length")
84
+
85
+ @parameterized.parameters(*_ALL_EXAMPLES())
86
+ def test_default_names(self, default_name: str, expr: rasp.RASPExpr):
87
+ self.assertEqual(expr.name, default_name)
88
+
89
+
90
+ class SOpTest(parameterized.TestCase):
91
+ """Tests for S-Ops."""
92
+
93
+ @parameterized.parameters(
94
+ ("hello", ["h", "e", "l", "l", "o"]),
95
+ ("h", ["h"]),
96
+ (["h", "e", "l", "l", "o"], ["h", "e", "l", "l", "o"]),
97
+ (["h"], ["h"]),
98
+ ([1, 2], [1, 2]),
99
+ ([0.1, 0.2], [0.1, 0.2]),
100
+ )
101
+ def test_tokens(self, input_sequence, expected):
102
+ self.assertEqual(rasp.tokens(input_sequence), expected)
103
+
104
+ @parameterized.parameters(
105
+ ("hello", [0, 1, 2, 3, 4]),
106
+ ("h", [0]),
107
+ (["h", "e", "l", "l", "o"], [0, 1, 2, 3, 4]),
108
+ (["h"], [0]),
109
+ ([1, 2], [0, 1]),
110
+ ([0.1, 0.2], [0, 1]),
111
+ )
112
+ def test_indices(self, input_sequence, expected):
113
+ self.assertEqual(rasp.indices(input_sequence), expected)
114
+
115
+ @parameterized.parameters(
116
+ ("hello", [5, 5, 5, 5, 5]),
117
+ ("h", [1]),
118
+ (["h", "e", "l", "l", "o"], [5, 5, 5, 5, 5]),
119
+ (["h"], [1]),
120
+ ([1, 2], [2, 2]),
121
+ ([0.1, 0.2], [2, 2]),
122
+ )
123
+ def test_length(self, input_sequence, expected):
124
+ self.assertEqual(rasp.length(input_sequence), expected)
125
+
126
+ def test_prims_are_sops(self):
127
+ self.assertIsInstance(rasp.tokens, rasp.SOp)
128
+ self.assertIsInstance(rasp.indices, rasp.SOp)
129
+ self.assertIsInstance(rasp.length, rasp.SOp)
130
+
131
+ def test_prims_are_raspexprs(self):
132
+ self.assertIsInstance(rasp.tokens, rasp.RASPExpr)
133
+ self.assertIsInstance(rasp.indices, rasp.RASPExpr)
134
+ self.assertIsInstance(rasp.length, rasp.RASPExpr)
135
+
136
+ @parameterized.parameters(
137
+ (lambda x: x + "a", "hello", ["ha", "ea", "la", "la", "oa"]),
138
+ (lambda x: x + "t", "h", ["ht"]),
139
+ (lambda x: x + 1, [1, 2], [2, 3]),
140
+ (lambda x: x / 2, [0.1, 0.2], [0.05, 0.1]),
141
+ )
142
+ def test_map(self, f, input_sequence, expected):
143
+ self.assertEqual(rasp.Map(f, rasp.tokens)(input_sequence), expected)
144
+
145
+ def test_nested_elementwise_ops_results_in_only_one_map_object(self):
146
+ map_sop = ((rasp.tokens * 2) + 2) / 2
147
+ self.assertEqual(map_sop.inner, rasp.tokens)
148
+ self.assertEqual(map_sop([1]), [2])
149
+
150
+ @parameterized.parameters(
151
+ (lambda x, y: x + y, "hello", ["hh", "ee", "ll", "ll", "oo"]),
152
+ (lambda x, y: x + y, "h", ["hh"]),
153
+ (lambda x, y: x + y, [1, 2], [2, 4]),
154
+ (lambda x, y: x * y, [1, 2], [1, 4]),
155
+ )
156
+ def test_sequence_map(self, f, input_sequence, expected):
157
+ self.assertEqual(
158
+ rasp.SequenceMap(f, rasp.tokens, rasp.tokens)(input_sequence), expected)
159
+
160
+ def test_sequence_map_with_same_inputs_logs_warning(self):
161
+ with self.assertLogs(level="WARNING"):
162
+ rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens)
163
+
164
+ @parameterized.parameters(
165
+ (1, 1, [1, 2], [2, 4]),
166
+ (1, -1, [1, 2], [0, 0]),
167
+ (1, -2, [1, 2], [-1, -2]),
168
+ )
169
+ def test_linear_sequence_map(self, fst_fac, snd_fac, input_sequence,
170
+ expected):
171
+ self.assertEqual(
172
+ rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, fst_fac,
173
+ snd_fac)(input_sequence), expected)
174
+
175
+ @parameterized.parameters(
176
+ ([5, 5, 5, 5, 5], "hello", [5, 5, 5, 5, 5]),
177
+ (["e"], "h", ["e"]),
178
+ ([1, 2, 3, 4, 5], ["h", "e", "l", "l", "o"], [1, 2, 3, 4, 5]),
179
+ ([2, 2], [1, 2], [2, 2]),
180
+ )
181
+ def test_constant(self, const, input_sequence, expected):
182
+ self.assertEqual(rasp.ConstantSOp(const)(input_sequence), expected)
183
+
184
+ def test_constant_complains_if_sizes_dont_match(self):
185
+ with self.assertRaisesRegex(
186
+ ValueError,
187
+ r"^.*Constant len .* doesn't match input len .*$",):
188
+ rasp.ConstantSOp([1, 2, 3])("longer string")
189
+
190
+ def test_can_turn_off_constant_complaints(self):
191
+ rasp.ConstantSOp([1, 2, 3], check_length=False)("longer string")
192
+
193
+ def test_numeric_dunders(self):
194
+ # We don't check all the cases here -- only a few representative ones.
195
+ self.assertEqual(
196
+ (rasp.tokens > 1)([0, 1, 2]),
197
+ [0, 0, 1],
198
+ )
199
+ self.assertEqual(
200
+ (1 < rasp.tokens)([0, 1, 2]),
201
+ [0, 0, 1],
202
+ )
203
+ self.assertEqual(
204
+ (rasp.tokens < 1)([0, 1, 2]),
205
+ [1, 0, 0],
206
+ )
207
+ self.assertEqual(
208
+ (1 > rasp.tokens)([0, 1, 2]),
209
+ [1, 0, 0],
210
+ )
211
+ self.assertEqual(
212
+ (rasp.tokens == 1)([0, 1, 2]),
213
+ [0, 1, 0],
214
+ )
215
+ self.assertEqual(
216
+ (rasp.tokens + 1)([0, 1, 2]),
217
+ [1, 2, 3],
218
+ )
219
+ self.assertEqual(
220
+ (1 + rasp.tokens)([0, 1, 2]),
221
+ [1, 2, 3],
222
+ )
223
+
224
+ def test_dunders_with_sop(self):
225
+ self.assertEqual(
226
+ (rasp.tokens + rasp.indices)([0, 1, 2]),
227
+ [0, 2, 4],
228
+ )
229
+ self.assertEqual(
230
+ (rasp.length - 1 - rasp.indices)([0, 1, 2]),
231
+ [2, 1, 0],
232
+ )
233
+ self.assertEqual(
234
+ (rasp.length * rasp.length)([0, 1, 2]),
235
+ [9, 9, 9],
236
+ )
237
+
238
+ def test_logical_dunders(self):
239
+ self.assertEqual(
240
+ (rasp.tokens & True)([True, False]),
241
+ [True, False],
242
+ )
243
+ self.assertEqual(
244
+ (rasp.tokens & False)([True, False]),
245
+ [False, False],
246
+ )
247
+ self.assertEqual(
248
+ (rasp.tokens | True)([True, False]),
249
+ [True, True],
250
+ )
251
+ self.assertEqual(
252
+ (rasp.tokens | False)([True, False]),
253
+ [True, False],
254
+ )
255
+ self.assertEqual(
256
+ (True & rasp.tokens)([True, False]),
257
+ [True, False],
258
+ )
259
+ self.assertEqual(
260
+ (False & rasp.tokens)([True, False]),
261
+ [False, False],
262
+ )
263
+ self.assertEqual(
264
+ (True | rasp.tokens)([True, False]),
265
+ [True, True],
266
+ )
267
+ self.assertEqual(
268
+ (False | rasp.tokens)([True, False]),
269
+ [True, False],
270
+ )
271
+
272
+ self.assertEqual(
273
+ (~rasp.tokens)([True, False]),
274
+ [False, True],
275
+ )
276
+
277
+ self.assertEqual(
278
+ (rasp.ConstantSOp([True, True, False, False])
279
+ & rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
280
+ [True, False, False, False],
281
+ )
282
+
283
+ self.assertEqual(
284
+ (rasp.ConstantSOp([True, True, False, False])
285
+ | rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
286
+ [True, True, True, False],
287
+ )
288
+
289
+
290
+ class EncodingTest(parameterized.TestCase):
291
+ """Tests for SOp encodings."""
292
+
293
+ @parameterized.named_parameters(*_SOP_EXAMPLES())
294
+ def test_all_sops_are_categorical_by_default(self, sop: rasp.SOp):
295
+ self.assertTrue(rasp.is_categorical(sop))
296
+
297
+ @parameterized.named_parameters(*_SOP_EXAMPLES())
298
+ def test_is_numerical(self, sop: rasp.SOp):
299
+ self.assertTrue(rasp.is_numerical(rasp.numerical(sop)))
300
+ self.assertFalse(rasp.is_numerical(rasp.categorical(sop)))
301
+
302
+ @parameterized.named_parameters(*_SOP_EXAMPLES())
303
+ def test_is_categorical(self, sop: rasp.SOp):
304
+ self.assertTrue(rasp.is_categorical(rasp.categorical(sop)))
305
+ self.assertFalse(rasp.is_categorical(rasp.numerical(sop)))
306
+
307
+ @parameterized.named_parameters(*_SOP_EXAMPLES())
308
+ def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp):
309
+ num_sop = rasp.numerical(sop)
310
+ cat_num_sop = rasp.categorical(num_sop)
311
+ self.assertTrue(rasp.is_numerical(num_sop))
312
+ self.assertTrue(rasp.is_categorical(cat_num_sop))
313
+
314
+
315
+ class SelectorTest(parameterized.TestCase):
316
+ """Tests for Selectors."""
317
+
318
+ def test_select_eq_has_correct_value(self):
319
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
320
+ self.assertEqual(
321
+ selector("hey"), [
322
+ [True, False, False],
323
+ [False, True, False],
324
+ [False, False, True],
325
+ ])
326
+
327
+ def test_select_lt_has_correct_value(self):
328
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LT)
329
+ self.assertEqual(selector([0, 1]), [
330
+ [False, False],
331
+ [True, False],
332
+ ])
333
+
334
+ def test_select_leq_has_correct_value(self):
335
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LEQ)
336
+ self.assertEqual(selector([0, 1]), [
337
+ [True, False],
338
+ [True, True],
339
+ ])
340
+
341
+ def test_select_gt_has_correct_value(self):
342
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GT)
343
+ self.assertEqual(selector([0, 1]), [
344
+ [False, True],
345
+ [False, False],
346
+ ])
347
+
348
+ def test_select_geq_has_correct_value(self):
349
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GEQ)
350
+ self.assertEqual(selector([0, 1]), [
351
+ [True, True],
352
+ [False, True],
353
+ ])
354
+
355
+ def test_select_neq_has_correct_value(self):
356
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.NEQ)
357
+ self.assertEqual(selector([0, 1]), [
358
+ [False, True],
359
+ [True, False],
360
+ ])
361
+
362
+ def test_select_true_has_correct_value(self):
363
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
364
+ self.assertEqual(selector([0, 1]), [
365
+ [True, True],
366
+ [True, True],
367
+ ])
368
+
369
+ def test_select_false_has_correct_value(self):
370
+ selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.FALSE)
371
+ self.assertEqual(selector([0, 1]), [
372
+ [False, False],
373
+ [False, False],
374
+ ])
375
+
376
+ def test_selector_and_gets_simplified_when_keys_and_queries_match(self):
377
+ selector = rasp.selector_and(
378
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
379
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ),
380
+ )
381
+ self.assertIsInstance(selector, rasp.Select)
382
+ self.assertIs(selector.keys, rasp.tokens)
383
+ self.assertIs(selector.queries, rasp.indices)
384
+
385
+ def test_selector_and_doesnt_get_simplified_when_keys_queries_different(self):
386
+ selector = rasp.selector_and(
387
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
388
+ rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
389
+ )
390
+ self.assertIsInstance(selector, rasp.SelectorAnd)
391
+
392
+ def test_selector_and_gets_simplified_when_keys_are_full(self):
393
+ selector = rasp.selector_and(
394
+ rasp.Select(rasp.Full(1), rasp.indices, rasp.Comparison.GEQ),
395
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ),
396
+ )
397
+ self.assertIsInstance(selector, rasp.Select)
398
+ self.assertIs(selector.keys, rasp.tokens)
399
+ self.assertIs(selector.queries, rasp.indices)
400
+
401
+ def test_selector_and_gets_simplified_when_queries_are_full(self):
402
+ selector = rasp.selector_and(
403
+ rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
404
+ rasp.Select(rasp.tokens, rasp.Full(1), rasp.Comparison.LEQ),
405
+ )
406
+ self.assertIsInstance(selector, rasp.Select)
407
+ self.assertIs(selector.keys, rasp.tokens)
408
+ self.assertIs(selector.queries, rasp.indices)
409
+
410
+ @parameterized.parameters(
411
+ itertools.product(
412
+ (rasp.tokens, rasp.indices, rasp.Full(1)),
413
+ (rasp.tokens, rasp.indices, rasp.Full(1)),
414
+ list(rasp.Comparison),
415
+ (rasp.tokens, rasp.indices, rasp.Full(1)),
416
+ (rasp.tokens, rasp.indices, rasp.Full(1)),
417
+ list(rasp.Comparison),
418
+ ))
419
+ def test_simplified_selector_and_works_the_same_way_as_not_simplified(
420
+ self, fst_k, fst_q, fst_p, snd_k, snd_q, snd_p):
421
+ fst = rasp.Select(fst_k, fst_q, fst_p)
422
+ snd = rasp.Select(snd_k, snd_q, snd_p)
423
+
424
+ simplified = rasp.selector_and(fst, snd)([0, 1, 2, 3])
425
+ not_simplified = rasp.selector_and(fst, snd, simplify=False)([0, 1, 2, 3])
426
+
427
+ np.testing.assert_array_equal(
428
+ np.array(simplified),
429
+ np.array(not_simplified),
430
+ )
431
+
432
+ def test_select_is_selector(self):
433
+ self.assertIsInstance(
434
+ rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
435
+ rasp.Selector,
436
+ )
437
+
438
+ def test_select_is_raspexpr(self):
439
+ self.assertIsInstance(
440
+ rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
441
+ rasp.RASPExpr,
442
+ )
443
+
444
+ def test_constant_selector(self):
445
+ self.assertEqual(
446
+ rasp.ConstantSelector([[True, True], [False, False]])([1, 2]),
447
+ [[True, True], [False, False]],
448
+ )
449
+
450
+
451
+ class CopyTest(parameterized.TestCase):
452
+
453
+ @parameterized.named_parameters(*_ALL_EXAMPLES())
454
+ def test_copy_preserves_name(self, expr: rasp.RASPExpr):
455
+ expr = expr.named("foo")
456
+ self.assertEqual(expr.copy().name, expr.name)
457
+
458
+ @parameterized.named_parameters(*_ALL_EXAMPLES())
459
+ def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr):
460
+ expr = expr.named("foo")
461
+ expr.copy().named("bar")
462
+ self.assertEqual(expr.name, "foo")
463
+
464
+ @parameterized.named_parameters(*_ALL_EXAMPLES())
465
+ def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr):
466
+ expr = expr.named("foo")
467
+ copy = expr.copy()
468
+ expr.named("bar")
469
+ self.assertEqual(copy.name, "foo")
470
+
471
+ @parameterized.named_parameters(*_ALL_EXAMPLES())
472
+ def test_copy_changes_id(self, expr: rasp.RASPExpr):
473
+ self.assertNotEqual(expr.copy().unique_id, expr.unique_id)
474
+
475
+ @parameterized.named_parameters(*_ALL_EXAMPLES())
476
+ def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr):
477
+ copy_child_ids = [c.unique_id for c in expr.copy().children]
478
+ child_ids = [c.unique_id for c in expr.children]
479
+ for child_id, copy_child_id in zip(child_ids, copy_child_ids):
480
+ self.assertEqual(child_id, copy_child_id)
481
+
482
+
483
+ class AggregateTest(parameterized.TestCase):
484
+ """Tests for Aggregate."""
485
+
486
+ @parameterized.parameters(
487
+ dict(
488
+ selector=rasp.ConstantSelector([
489
+ [True, False],
490
+ [False, True],
491
+ ]),
492
+ sop=rasp.ConstantSOp(["h", "e"]),
493
+ default=None,
494
+ expected_value=["h", "e"],
495
+ ),
496
+ dict(
497
+ selector=rasp.ConstantSelector([
498
+ [False, True],
499
+ [False, False],
500
+ ]),
501
+ sop=rasp.ConstantSOp(["h", "e"]),
502
+ default=None,
503
+ expected_value=["e", None],
504
+ ),
505
+ dict(
506
+ selector=rasp.ConstantSelector([
507
+ [True, False],
508
+ [False, False],
509
+ ]),
510
+ sop=rasp.ConstantSOp(["h", "e"]),
511
+ default=None,
512
+ expected_value=["h", None],
513
+ ),
514
+ dict(
515
+ selector=rasp.ConstantSelector([
516
+ [True, True],
517
+ [False, True],
518
+ ]),
519
+ sop=rasp.ConstantSOp([0, 1]),
520
+ default=0,
521
+ expected_value=[0.5, 1],
522
+ ),
523
+ dict(
524
+ selector=rasp.ConstantSelector([
525
+ [False, False],
526
+ [True, True],
527
+ ]),
528
+ sop=rasp.ConstantSOp([0, 1]),
529
+ default=0,
530
+ expected_value=[0, 0.5],
531
+ ),
532
+ dict(
533
+ selector=rasp.ConstantSelector([
534
+ [False, False],
535
+ [True, True],
536
+ ]),
537
+ sop=rasp.ConstantSOp([0, 1]),
538
+ default=None,
539
+ expected_value=[None, 0.5],
540
+ ),
541
+ )
542
+ def test_aggregate_on_size_2_inputs(self, selector, sop, default,
543
+ expected_value):
544
+ # The 0, 0 input is ignored as it's overridden by the constant SOps.
545
+ self.assertEqual(
546
+ rasp.Aggregate(selector, sop, default)([0, 0]),
547
+ expected_value,
548
+ )
549
+
550
+
551
+ class RaspProgramTest(parameterized.TestCase):
552
+ """Each testcase implements and tests a RASP program."""
553
+
554
+ def test_has_prev(self):
555
+
556
+ def has_prev(seq: rasp.SOp) -> rasp.SOp:
557
+ prev_copy = rasp.SelectorAnd(
558
+ rasp.Select(seq, seq, rasp.Comparison.EQ),
559
+ rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LT),
560
+ )
561
+ return rasp.Aggregate(prev_copy, rasp.Full(1), default=0) > 0
562
+
563
+ self.assertEqual(
564
+ has_prev(rasp.tokens)("hello"),
565
+ [0, 0, 0, 1, 0],
566
+ )
567
+
568
+ self.assertEqual(
569
+ has_prev(rasp.tokens)("helllo"),
570
+ [0, 0, 0, 1, 1, 0],
571
+ )
572
+
573
+ self.assertEqual(
574
+ has_prev(rasp.tokens)([0, 2, 3, 2, 1, 0, 2]),
575
+ [0, 0, 0, 1, 0, 1, 1],
576
+ )
577
+
578
+
579
+ if __name__ == "__main__":
580
+ absltest.main()
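The tests above also document two behaviours worth calling out: encoding annotations and selector simplification. A small standalone sketch, mirroring what EncodingTest and SelectorTest check (the variable names `frac` and `combined` are illustrative):

```python
from tracr.rasp import rasp

# Encoding annotations: SOps are categorical unless marked numerical.
frac = rasp.numerical(rasp.Map(lambda x: x / 2, rasp.tokens))
assert rasp.is_numerical(frac)
assert rasp.is_categorical(rasp.tokens)  # categorical by default

# Selector simplification: when keys and queries match, selector_and
# collapses two Selects into a single Select with a combined predicate.
combined = rasp.selector_and(
    rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
    rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ),
)
assert isinstance(combined, rasp.Select)
# GEQ and LEQ combined with `and` behave like EQ.
assert combined([0, 1, 2]) == rasp.Select(
    rasp.tokens, rasp.indices, rasp.Comparison.EQ)([0, 1, 2])
```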
transformer/attention.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Instrumented attention layer (forked from the Haiku library implementation).
16
+ """
17
+
18
+ from typing import Optional
19
+ import warnings
20
+
21
+ import chex
22
+ import haiku as hk
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+
28
+ @chex.dataclass
29
+ class AttentionOutput:
30
+ out: jax.Array # [..., T', D']
31
+ logits: jax.Array # [..., H, T', T]
32
+
33
+
34
+ class MultiHeadAttention(hk.Module):
35
+ """Multi-headed attention (MHA) module.
36
+
37
+ This module is intended for attending over sequences of vectors.
38
+
39
+ Rough sketch:
40
+ - Compute keys (K), queries (Q), and values (V) as projections of inputs.
41
+ - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)).
42
+ - Output is another projection of WV.
43
+
44
+ For more detail, see the original Transformer paper:
45
+ "Attention is all you need" https://arxiv.org/abs/1706.03762.
46
+
47
+ Glossary of shapes:
48
+ - T: Sequence length.
49
+ - D: Vector (embedding) size.
50
+ - H: Number of attention heads.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ num_heads: int,
56
+ key_size: int,
57
+ # TODO(b/240019186): Remove `w_init_scale`.
58
+ w_init_scale: Optional[float] = None,
59
+ *,
60
+ w_init: Optional[hk.initializers.Initializer] = None,
61
+ value_size: Optional[int] = None,
62
+ model_size: Optional[int] = None,
63
+ name: Optional[str] = None,
64
+ ):
65
+ """Initialises the module.
66
+
67
+ Args:
68
+ num_heads: Number of independent attention heads (H).
69
+ key_size: The size of keys (K) and queries used for attention.
70
+ w_init_scale: DEPRECATED. Please use w_init instead.
71
+ w_init: Initialiser for weights in the linear map.
72
+ value_size: Optional size of the value projection (V). If None, defaults
73
+ to the key size (K).
74
+ model_size: Optional size of the output embedding (D'). If None, defaults
75
+ to the key size multiplied by the number of heads (K * H).
76
+ name: Optional name for this module.
77
+ """
78
+ super().__init__(name=name)
79
+ self.num_heads = num_heads
80
+ self.key_size = key_size
81
+ self.value_size = value_size or key_size
82
+ self.model_size = model_size or key_size * num_heads
83
+
84
+ # Backwards-compatibility for w_init_scale.
85
+ if w_init_scale is not None:
86
+ warnings.warn(
87
+ "w_init_scale is deprecated; please pass an explicit weight "
88
+ "initialiser instead.", DeprecationWarning)
89
+ if w_init and w_init_scale:
90
+ raise ValueError("Please provide only `w_init`, not `w_init_scale`.")
91
+ if w_init is None and w_init_scale is None:
92
+ raise ValueError("Please provide a weight initializer: `w_init`.")
93
+ if w_init is None:
94
+ w_init = hk.initializers.VarianceScaling(w_init_scale)
95
+ self.w_init = w_init
96
+
97
+ def __call__(
98
+ self,
99
+ query: jnp.ndarray,
100
+ key: jnp.ndarray,
101
+ value: jnp.ndarray,
102
+ mask: Optional[jnp.ndarray] = None,
103
+ ) -> AttentionOutput:
104
+ """Computes (optionally masked) MHA with queries, keys & values.
105
+
106
+ This module broadcasts over zero or more 'batch-like' leading dimensions.
107
+
108
+ Args:
109
+ query: Embeddings sequence used to compute queries; shape [..., T', D_q].
110
+ key: Embeddings sequence used to compute keys; shape [..., T, D_k].
111
+ value: Embeddings sequence used to compute values; shape [..., T, D_v].
112
+ mask: Optional mask applied to attention weights; shape [..., H=1, T', T].
113
+
114
+ Returns:
115
+ A new sequence of embeddings, consisting of a projection of the
116
+ attention-weighted value projections; shape [..., T', D'].
117
+ """
118
+
119
+ # In shape hints below, we suppress the leading dims [...] for brevity.
120
+ # Hence e.g. [A, B] should be read in every case as [..., A, B].
121
+ *leading_dims, sequence_length, _ = query.shape
122
+ projection = self._linear_projection
123
+
124
+ # Compute key/query/values (overload K/Q/V to denote the respective sizes).
125
+ query_heads = projection(query, self.key_size, "query") # [T', H, Q=K]
126
+ key_heads = projection(key, self.key_size, "key") # [T, H, K]
127
+ value_heads = projection(value, self.value_size, "value") # [T, H, V]
128
+
129
+ # Compute attention weights.
130
+ attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
131
+ attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype)
132
+ if mask is not None:
133
+ if mask.ndim != attn_logits.ndim:
134
+ raise ValueError(
135
+ f"Mask dimensionality {mask.ndim} must match logits dimensionality "
136
+ f"{attn_logits.ndim}.")
137
+ attn_logits = jnp.where(mask, attn_logits, -1e30)
138
+ attn_weights = jax.nn.softmax(attn_logits) # [H, T', T]
139
+
140
+ # Weight the values by the attention and flatten the head vectors.
141
+ attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
142
+ attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1)) # [T', H*V]
143
+
144
+ # Apply another projection to get the final embeddings.
145
+ final_projection = hk.Linear(self.model_size, w_init=self.w_init)
146
+ return AttentionOutput(
147
+ out=final_projection(attn),
148
+ logits=attn_logits,
149
+ )
150
+
151
+ @hk.transparent
152
+ def _linear_projection(
153
+ self,
154
+ x: jnp.ndarray,
155
+ head_size: int,
156
+ name: Optional[str] = None,
157
+ ) -> jnp.ndarray:
158
+ y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, name=name)(x)
159
+ *leading_dims, _ = x.shape
160
+ return y.reshape((*leading_dims, self.num_heads, head_size))
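A minimal sketch of driving the instrumented attention layer under `hk.transform`; the shapes and the initializer are arbitrary choices for illustration. The only difference from stock Haiku usage is that the call returns an `AttentionOutput` carrying the pre-softmax logits alongside the output.

```python
import haiku as hk
import jax
import jax.numpy as jnp

from tracr.transformer import attention


@hk.transform
def forward(x):
  mha = attention.MultiHeadAttention(
      num_heads=2,
      key_size=4,
      w_init=hk.initializers.VarianceScaling(1.0),
  )
  out = mha(x, x, x)  # Self-attention; mask defaults to None.
  return out.out, out.logits


x = jnp.ones((1, 5, 8))  # [B, T, D]
params = forward.init(jax.random.PRNGKey(0), x)
out, logits = forward.apply(params, jax.random.PRNGKey(0), x)
print(out.shape, logits.shape)  # (1, 5, 8) and (1, 2, 5, 5)
```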
transformer/compressed_model.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Modified transformer to learn a linear compression of the residual stream.
16
+
17
+ CompressedTransformer adds two arguments compared to Transformer:
18
+ - embedding_size: the size of the compressed residual stream.
19
+ - unembed_at_every_layer: whether to apply the unembedding before applying
20
+ attention and MLP layers.
21
+ All intermediate activations (layer outputs, residuals and attention
22
+ logits) are always returned as part of the TransformerOutput.
23
+ """
24
+
25
+ import collections
26
+ import dataclasses
27
+ from typing import Optional
28
+
29
+ import haiku as hk
30
+ import jax
31
+ import numpy as np
32
+
33
+ from tracr.transformer import attention
34
+ from tracr.transformer import model
35
+
36
+
37
+ @dataclasses.dataclass
38
+ class CompressedTransformer(hk.Module):
39
+ """A transformer stack with linearly compressed residual stream."""
40
+
41
+ config: model.TransformerConfig
42
+ name: Optional[str] = None
43
+
44
+ def __call__(
45
+ self,
46
+ embeddings: jax.Array, # [B, T, D]
47
+ mask: jax.Array, # [B, T]
48
+ *,
49
+ use_dropout: bool = True,
50
+ embedding_size: Optional[int] = None,
51
+ unembed_at_every_layer: bool = False,
52
+ ) -> model.TransformerOutput: # [B, T, D]
53
+ """Transforms input embedding sequences to output embedding sequences.
54
+
55
+ Args:
56
+ embeddings: Input embeddings to pass through the model.
57
+ mask: Boolean mask to restrict the inputs the model uses.
58
+ use_dropout: Turns dropout on/off.
59
+ embedding_size: Dimension to compress the residual stream to.
60
+ unembed_at_every_layer: Whether to unembed the residual stream when
61
+ reading the input for every layer (keeping the layer input sizes) or to
62
+ only unembed before the model output (compressing the layer inputs).
63
+
64
+ Returns:
65
+ The outputs of the forward pass through the transformer.
66
+ """
67
+
68
+ def layer_norm(x: jax.Array) -> jax.Array:
69
+ """Applies a unique LayerNorm to x with default settings."""
70
+ if self.config.layer_norm:
71
+ return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
72
+ return x
73
+
74
+ initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers)
75
+ dropout_rate = self.config.dropout_rate if use_dropout else 0.
76
+ _, seq_len, model_size = embeddings.shape
77
+
78
+ # To compress the model, we multiply with a matrix W when reading from
79
+ # the residual stream, and with W^T when writing to the residual stream.
80
+ if embedding_size is not None:
81
+ # [to_size, from_size]
82
+ w_emb = hk.get_parameter(
83
+ "w_emb", (embedding_size, model_size),
84
+ init=hk.initializers.RandomNormal())
85
+
86
+ write_to_residual = lambda x: x @ w_emb.T
87
+ read_from_residual = lambda x: x @ w_emb
88
+
89
+ if not unembed_at_every_layer:
90
+ model_size = embedding_size
91
+ else:
92
+ write_to_residual = lambda x: x
93
+ read_from_residual = lambda x: x
94
+
95
+ # Compute causal mask for autoregressive sequence modelling.
96
+ mask = mask[:, None, None, :] # [B, H=1, T'=1, T]
97
+ mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T]
98
+
99
+ if self.config.causal:
100
+ causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T]
101
+ causal_mask = np.tril(causal_mask)
102
+ mask = mask * causal_mask # [B, H=1, T, T]
103
+
104
+ # Set up activation collection.
105
+ collected = collections.defaultdict(list)
106
+
107
+ def collect(**kwargs):
108
+ for k, v in kwargs.items():
109
+ collected[k].append(v)
110
+
111
+ residual = write_to_residual(embeddings)
112
+
113
+ for layer in range(self.config.num_layers):
114
+ with hk.experimental.name_scope(f"layer_{layer}"):
115
+ # First the attention block.
116
+ attn_block = attention.MultiHeadAttention(
117
+ num_heads=self.config.num_heads,
118
+ key_size=self.config.key_size,
119
+ model_size=model_size,
120
+ w_init=initializer,
121
+ name="attn")
122
+
123
+ attn_in = residual
124
+ if unembed_at_every_layer:
125
+ attn_in = read_from_residual(attn_in)
126
+ attn_in = layer_norm(attn_in)
127
+ attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask)
128
+ attn_out, attn_logits = attn_out.out, attn_out.logits
129
+ if dropout_rate > 0:
130
+ attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out)
131
+
132
+ if unembed_at_every_layer:
133
+ collect(layer_outputs=attn_out, attn_logits=attn_logits)
134
+ else:
135
+ collect(
136
+ layer_outputs=read_from_residual(attn_out),
137
+ attn_logits=attn_logits,
138
+ )
139
+
140
+ if unembed_at_every_layer:
141
+ attn_out = write_to_residual(attn_out)
142
+ residual = residual + attn_out
143
+
144
+ collect(residuals=residual)
145
+
146
+ # Then the dense block.
147
+ with hk.experimental.name_scope("mlp"):
148
+ dense_block = hk.Sequential([
149
+ hk.Linear(
150
+ self.config.mlp_hidden_size,
151
+ w_init=initializer,
152
+ name="linear_1"),
153
+ self.config.activation_function,
154
+ hk.Linear(model_size, w_init=initializer, name="linear_2"),
155
+ ])
156
+
157
+ dense_in = residual
158
+ if unembed_at_every_layer:
159
+ dense_in = read_from_residual(dense_in)
160
+ dense_in = layer_norm(dense_in)
161
+ dense_out = dense_block(dense_in)
162
+ if dropout_rate > 0:
163
+ dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out)
164
+
165
+ if unembed_at_every_layer:
166
+ collect(layer_outputs=dense_out)
167
+ else:
168
+ collect(layer_outputs=read_from_residual(dense_out))
169
+
170
+ if unembed_at_every_layer:
171
+ dense_out = write_to_residual(dense_out)
172
+ residual = residual + dense_out
173
+
174
+ collect(residuals=residual)
175
+
176
+ output = read_from_residual(residual)
177
+ output = layer_norm(output)
178
+
179
+ return model.TransformerOutput(
180
+ layer_outputs=collected["layer_outputs"],
181
+ residuals=collected["residuals"],
182
+ attn_logits=collected["attn_logits"],
183
+ output=output,
184
+ input_embeddings=embeddings,
185
+ )
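A minimal sketch of running CompressedTransformer with a compressed residual stream. The config values mirror the tests below, and the embedding width of 6 and input width of 10 are arbitrary illustrations (unset TransformerConfig fields are assumed to have defaults in transformer/model.py). With `embedding_size` set, reads and writes to the residual stream go through the learned `w_emb` matrix, and the output is projected back to the input width.

```python
import haiku as hk
import jax.numpy as jnp

from tracr.transformer import compressed_model
from tracr.transformer import model


@hk.transform
def forward(emb, mask):
  transformer = compressed_model.CompressedTransformer(
      model.TransformerConfig(
          num_heads=2,
          num_layers=2,
          key_size=5,
          mlp_hidden_size=64,
          dropout_rate=0.,
          layer_norm=True,
      ))
  # Compress the residual stream from 10 to 6 dimensions via w_emb.
  return transformer(emb, mask, embedding_size=6).output


emb = jnp.ones((1, 4, 10))  # [B, T, D]
mask = jnp.ones((1, 4))     # [B, T]
rng = hk.PRNGSequence(1)
params = forward.init(next(rng), emb, mask)
out = forward.apply(params, next(rng), emb, mask)
print(out.shape)  # (1, 4, 10): read back out through w_emb before the output.
```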
transformer/compressed_model_test.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for transformer.model."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import haiku as hk
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from tracr.transformer import compressed_model
24
+ from tracr.transformer import model
25
+
26
+
27
+ class CompressedTransformerTest(parameterized.TestCase):
28
+
29
+ def _check_layer_naming(self, params):
30
+ # Modules should be named, for example:
31
+ # For MLPs: "compressed_transformer/layer_{i}/mlp/linear_1"
32
+ # For Attention: "compressed_transformer/layer_{i}/attn/key"
33
+ # For Layer Norm: "compressed_transformer/layer_{i}/layer_norm"
34
+ for key in params.keys():
35
+ levels = key.split("/")
36
+ self.assertEqual(levels[0], "compressed_transformer")
37
+ if len(levels) == 1:
38
+ self.assertEqual(list(params[key].keys()), ["w_emb"])
39
+ continue
40
+ if levels[1].startswith("layer_norm"):
41
+ continue # output layer norm
42
+ self.assertStartsWith(levels[1], "layer")
43
+ if levels[2] == "mlp":
44
+ self.assertIn(levels[3], {"linear_1", "linear_2"})
45
+ elif levels[2] == "attn":
46
+ self.assertIn(levels[3], {"key", "query", "value", "linear"})
47
+ else:
48
+ self.assertStartsWith(levels[2], "layer_norm")
49
+
50
+ def _zero_mlps(self, params):
51
+ for module in params:
52
+ if "mlp" in module:
53
+ for param in params[module]:
54
+ params[module][param] = jnp.zeros_like(params[module][param])
55
+ return params
56
+
57
+ @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False))
58
+ def test_layer_norm(self, layer_norm):
59
+ # input = [1, 1, 1, 1]
60
+ # If layer norm is used, this should give all-0 output for a freshly
61
+ # initialized model because LN will subtract the mean after each layer.
62
+ # Else we expect non-zero outputs.
63
+
64
+ @hk.transform
65
+ def forward(emb, mask):
66
+ transformer = compressed_model.CompressedTransformer(
67
+ model.TransformerConfig(
68
+ num_heads=2,
69
+ num_layers=2,
70
+ key_size=5,
71
+ mlp_hidden_size=64,
72
+ dropout_rate=0.,
73
+ layer_norm=layer_norm))
74
+ return transformer(emb, mask).output
75
+
76
+ seq_len = 4
77
+ emb = jnp.ones((1, seq_len, 1))
78
+ mask = jnp.ones((1, seq_len))
79
+ rng = hk.PRNGSequence(1)
80
+ params = forward.init(next(rng), emb, mask)
81
+ out = forward.apply(params, next(rng), emb, mask)
82
+
83
+ self._check_layer_naming(params)
84
+ if layer_norm:
85
+ np.testing.assert_allclose(out, 0)
86
+ else:
87
+ self.assertFalse(np.allclose(out, 0))
88
+
89
+ @parameterized.parameters(dict(causal=True), dict(causal=False))
90
+ def test_causal_attention(self, causal):
91
+ # input = [0, random, random, random]
92
+ # mask = [1, 0, 1, 1]
93
+ # For causal attention the second token can only attend to the first one, so
94
+ # it should be the same. For non-causal attention all tokens should change.
95
+
96
+ @hk.transform
97
+ def forward(emb, mask):
98
+ transformer = compressed_model.CompressedTransformer(
99
+ model.TransformerConfig(
100
+ num_heads=2,
101
+ num_layers=2,
102
+ key_size=5,
103
+ mlp_hidden_size=64,
104
+ dropout_rate=0.,
105
+ layer_norm=False,
106
+ causal=causal))
107
+ return transformer(emb, mask).output
108
+
109
+ seq_len = 4
110
+ emb = np.random.random((1, seq_len, 1))
111
+ emb[:, 0, :] = 0
112
+ mask = np.array([[1, 0, 1, 1]])
113
+ emb, mask = jnp.array(emb), jnp.array(mask)
114
+
115
+ rng = hk.PRNGSequence(1)
116
+ params = forward.init(next(rng), emb, mask)
117
+ params = self._zero_mlps(params)
118
+ out = forward.apply(params, next(rng), emb, mask)
119
+
120
+ self._check_layer_naming(params)
121
+ if causal:
122
+ self.assertEqual(0, out[0, 0, 0])
123
+ self.assertEqual(emb[0, 1, 0], out[0, 1, 0])
124
+ else:
125
+ self.assertNotEqual(0, out[0, 0, 0])
126
+ self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0])
127
+ self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0])
128
+ self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0])
129
+
130
+ def test_setting_activation_function_to_zero(self):
131
+ # An activation function that always returns zeros should result in the
132
+ # same model output as setting all MLP weights to zero.
133
+
134
+ @hk.transform
135
+ def forward_zero(emb, mask):
136
+ transformer = compressed_model.CompressedTransformer(
137
+ model.TransformerConfig(
138
+ num_heads=2,
139
+ num_layers=2,
140
+ key_size=5,
141
+ mlp_hidden_size=64,
142
+ dropout_rate=0.,
143
+ causal=False,
144
+ layer_norm=False,
145
+ activation_function=jnp.zeros_like))
146
+ return transformer(emb, mask).output
147
+
148
+ @hk.transform
149
+ def forward(emb, mask):
150
+ transformer = compressed_model.CompressedTransformer(
151
+ model.TransformerConfig(
152
+ num_heads=2,
153
+ num_layers=2,
154
+ key_size=5,
155
+ mlp_hidden_size=64,
156
+ dropout_rate=0.,
157
+ causal=False,
158
+ layer_norm=False,
159
+ activation_function=jax.nn.gelu))
160
+ return transformer(emb, mask).output
161
+
162
+ seq_len = 4
163
+ emb = np.random.random((1, seq_len, 1))
164
+ mask = np.ones((1, seq_len))
165
+ emb, mask = jnp.array(emb), jnp.array(mask)
166
+
167
+ rng = hk.PRNGSequence(1)
168
+ params = forward.init(next(rng), emb, mask)
169
+ params_no_mlps = self._zero_mlps(params)
170
+
171
+ out_zero_activation = forward_zero.apply(params, next(rng), emb, mask)
172
+ out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask)
173
+
174
+ self._check_layer_naming(params)
175
+ np.testing.assert_allclose(out_zero_activation, out_no_mlps)
176
+ self.assertFalse(np.allclose(out_zero_activation, 0))
177
+
178
+ def test_not_setting_embedding_size_produces_same_output_as_default_model(
179
+ self):
180
+ config = model.TransformerConfig(
181
+ num_heads=2,
182
+ num_layers=2,
183
+ key_size=5,
184
+ mlp_hidden_size=64,
185
+ dropout_rate=0.,
186
+ causal=False,
187
+ layer_norm=False)
188
+
189
+ @hk.without_apply_rng
190
+ @hk.transform
191
+ def forward_model(emb, mask):
192
+ return model.Transformer(config)(emb, mask).output
193
+
194
+ @hk.without_apply_rng
195
+ @hk.transform
196
+ def forward_superposition(emb, mask):
197
+ return compressed_model.CompressedTransformer(config)(emb, mask).output
198
+
199
+ seq_len = 4
200
+ emb = np.random.random((1, seq_len, 1))
201
+ mask = np.ones((1, seq_len))
202
+ emb, mask = jnp.array(emb), jnp.array(mask)
203
+
204
+ rng = hk.PRNGSequence(1)
205
+ params = forward_model.init(next(rng), emb, mask)
206
+ params_superposition = {
207
+ k.replace("transformer", "compressed_transformer"): v
208
+ for k, v in params.items()
209
+ }
210
+
211
+ out_model = forward_model.apply(params, emb, mask)
212
+ out_superposition = forward_superposition.apply(params_superposition, emb,
213
+ mask)
214
+
215
+ self._check_layer_naming(params_superposition)
216
+ np.testing.assert_allclose(out_model, out_superposition)
217
+
218
+ @parameterized.parameters(
219
+ dict(embedding_size=2, unembed_at_every_layer=True),
220
+ dict(embedding_size=2, unembed_at_every_layer=False),
221
+ dict(embedding_size=6, unembed_at_every_layer=True),
222
+ dict(embedding_size=6, unembed_at_every_layer=False))
223
+ def test_embedding_size_produces_correct_shape_of_residuals_and_layer_outputs(
224
+ self, embedding_size, unembed_at_every_layer):
225
+
226
+ @hk.transform
227
+ def forward(emb, mask):
228
+ transformer = compressed_model.CompressedTransformer(
229
+ model.TransformerConfig(
230
+ num_heads=2,
231
+ num_layers=2,
232
+ key_size=5,
233
+ mlp_hidden_size=64,
234
+ dropout_rate=0.,
235
+ causal=False,
236
+ layer_norm=False))
237
+ return transformer(
238
+ emb,
239
+ mask,
240
+ embedding_size=embedding_size,
241
+ unembed_at_every_layer=unembed_at_every_layer,
242
+ )
243
+
244
+ seq_len = 4
245
+ model_size = 16
246
+
247
+ emb = np.random.random((1, seq_len, model_size))
248
+ mask = np.ones((1, seq_len))
249
+ emb, mask = jnp.array(emb), jnp.array(mask)
250
+
251
+ rng = hk.PRNGSequence(1)
252
+ params = forward.init(next(rng), emb, mask)
253
+ activations = forward.apply(params, next(rng), emb, mask)
254
+
255
+ self._check_layer_naming(params)
256
+
257
+ for residual in activations.residuals:
258
+ self.assertEqual(residual.shape, (1, seq_len, embedding_size))
259
+
260
+ for layer_output in activations.layer_outputs:
261
+ self.assertEqual(layer_output.shape, (1, seq_len, model_size))
262
+
263
+ @parameterized.parameters(
264
+ dict(model_size=2, unembed_at_every_layer=True),
265
+ dict(model_size=2, unembed_at_every_layer=False),
266
+ dict(model_size=6, unembed_at_every_layer=True),
267
+ dict(model_size=6, unembed_at_every_layer=False))
268
+ def test_identity_embedding_produces_same_output_as_standard_model(
269
+ self, model_size, unembed_at_every_layer):
270
+
271
+ config = model.TransformerConfig(
272
+ num_heads=2,
273
+ num_layers=2,
274
+ key_size=5,
275
+ mlp_hidden_size=64,
276
+ dropout_rate=0.,
277
+ causal=False,
278
+ layer_norm=False)
279
+
280
+ @hk.without_apply_rng
281
+ @hk.transform
282
+ def forward_model(emb, mask):
283
+ return model.Transformer(config)(emb, mask).output
284
+
285
+ @hk.without_apply_rng
286
+ @hk.transform
287
+ def forward_superposition(emb, mask):
288
+ return compressed_model.CompressedTransformer(config)(
289
+ emb,
290
+ mask,
291
+ embedding_size=model_size,
292
+ unembed_at_every_layer=unembed_at_every_layer).output
293
+
294
+ seq_len = 4
295
+ emb = np.random.random((1, seq_len, model_size))
296
+ mask = np.ones((1, seq_len))
297
+ emb, mask = jnp.array(emb), jnp.array(mask)
298
+
299
+ rng = hk.PRNGSequence(1)
300
+ params = forward_model.init(next(rng), emb, mask)
301
+ params_superposition = {
302
+ k.replace("transformer", "compressed_transformer"): v
303
+ for k, v in params.items()
304
+ }
305
+ params_superposition["compressed_transformer"] = {
306
+ "w_emb": jnp.identity(model_size)
307
+ }
308
+
309
+ out_model = forward_model.apply(params, emb, mask)
310
+ out_superposition = forward_superposition.apply(params_superposition, emb,
311
+ mask)
312
+
313
+ self._check_layer_naming(params_superposition)
314
+ np.testing.assert_allclose(out_model, out_superposition)
315
+
316
+
317
+ if __name__ == "__main__":
318
+ absltest.main()
transformer/encoder.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Basic encoder for inputs with a fixed vocabulary."""
16
+
17
+ import abc
18
+ from typing import Any, Sequence, Optional
19
+
20
+ from tracr.craft import bases
21
+
22
+
23
+ class Encoder(abc.ABC):
24
+ """Encodes a list of tokens into a list of inputs for a transformer model.
25
+
26
+ The abstract class makes no assumptions about the input and output types,
27
+ and we have different encoders for different input types.
28
+ """
29
+
30
+ @abc.abstractmethod
31
+ def encode(self, inputs: list[Any]) -> list[Any]:
32
+ return list()
33
+
34
+ @abc.abstractmethod
35
+ def decode(self, encodings: list[Any]) -> list[Any]:
36
+ return list()
37
+
38
+ @property
39
+ def pad_token(self) -> Optional[str]:
40
+ return None
41
+
42
+ @property
43
+ def bos_token(self) -> Optional[str]:
44
+ return None
45
+
46
+ @property
47
+ def pad_encoding(self) -> Optional[int]:
48
+ return None
49
+
50
+ @property
51
+ def bos_encoding(self) -> Optional[int]:
52
+ return None
53
+
54
+
55
+ class NumericalEncoder(Encoder):
56
+ """Encodes numerical variables (simply using the identity mapping)."""
57
+
58
+ def encode(self, inputs: list[float]) -> list[float]:
59
+ return inputs
60
+
61
+ def decode(self, encodings: list[float]) -> list[float]:
62
+ return encodings
63
+
64
+
65
+ class CategoricalEncoder(Encoder):
66
+ """Encodes categorical variables with a fixed vocabulary."""
67
+
68
+ def __init__(
69
+ self,
70
+ basis: Sequence[bases.BasisDirection],
71
+ enforce_bos: bool = False,
72
+ bos_token: Optional[str] = None,
73
+ pad_token: Optional[str] = None,
74
+ max_seq_len: Optional[int] = None,
75
+ ):
76
+ """Initialises. If enforce_bos is set, ensures inputs start with it."""
77
+ if enforce_bos and not bos_token:
78
+ raise ValueError("BOS token must be specified if enforcing BOS.")
79
+
80
+ self.encoding_map = {}
81
+ for i, direction in enumerate(basis):
82
+ val = direction.value
83
+ self.encoding_map[val] = i
84
+
85
+ if bos_token and bos_token not in self.encoding_map:
86
+ raise ValueError("BOS token missing in encoding.")
87
+
88
+ if pad_token and pad_token not in self.encoding_map:
89
+ raise ValueError("PAD token missing in encoding.")
90
+
91
+ self.enforce_bos = enforce_bos
92
+ self._bos_token = bos_token
93
+ self._pad_token = pad_token
94
+ self._max_seq_len = max_seq_len
95
+
96
+ def encode(self, inputs: list[bases.Value]) -> list[int]:
97
+ if self.enforce_bos and inputs[0] != self.bos_token:
98
+ raise ValueError("First input token must be BOS token. "
99
+ f"Should be '{self.bos_token}', but was '{inputs[0]}'.")
100
+ if missing := set(inputs) - set(self.encoding_map.keys()):
101
+ raise ValueError(f"Inputs {missing} not found in encoding ",
102
+ self.encoding_map.keys())
103
+ if self._max_seq_len is not None and len(inputs) > self._max_seq_len:
104
+ raise ValueError(f"{inputs=} are longer than the maximum "
105
+ f"sequence length {self._max_seq_len}")
106
+
107
+ return [self.encoding_map[x] for x in inputs]
108
+
109
+ def decode(self, encodings: list[int]) -> list[bases.Value]:
110
+ """Recovers the tokens that correspond to `encodings`. Inverse of `encode`."""
111
+ decoding_map = {val: key for key, val in self.encoding_map.items()}
112
+ if missing := set(encodings) - set(decoding_map.keys()):
113
+ raise ValueError(f"Inputs {missing} not found in decoding map ",
114
+ decoding_map.keys())
115
+ return [decoding_map[x] for x in encodings]
116
+
117
+ @property
118
+ def vocab_size(self) -> int:
119
+ return len(self.encoding_map)
120
+
121
+ @property
122
+ def bos_token(self) -> Optional[str]:
123
+ return self._bos_token
124
+
125
+ @property
126
+ def pad_token(self) -> Optional[str]:
127
+ return self._pad_token
128
+
129
+ @property
130
+ def bos_encoding(self) -> Optional[int]:
131
+ return None if self.bos_token is None else self.encoding_map[self.bos_token]
132
+
133
+ @property
134
+ def pad_encoding(self) -> Optional[int]:
135
+ return None if self.pad_token is None else self.encoding_map[self.pad_token]
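A short usage sketch of `CategoricalEncoder` (illustrative, not part of the diff). It mirrors the lexicographic-ordering and round-trip cases in `encoder_test.py` below, with a hypothetical `"bos"` token.

```python
from tracr.craft import bases
from tracr.transformer import encoder

# Token ids follow the lexicographic order of the vocabulary values.
vocab = {"a", "b", "c", "bos"}
vs = bases.VectorSpaceWithBasis.from_values("tokens", vocab)
enc = encoder.CategoricalEncoder(vs.basis, enforce_bos=True, bos_token="bos")

ids = enc.encode(["bos", "b", "b", "c"])          # [2, 1, 1, 3]
assert enc.decode(ids) == ["bos", "b", "b", "c"]  # decode inverts encode
assert enc.vocab_size == 4
```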
transformer/encoder_test.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for transformer.encoder."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ from tracr.craft import bases
20
+ from tracr.transformer import encoder
21
+
22
+ _BOS_TOKEN = "bos_encoder_test"
23
+ _PAD_TOKEN = "pad_encoder_test"
24
+
25
+
26
+ class CategoricalEncoderTest(parameterized.TestCase):
27
+
28
+ def test_encode_raises_value_error_if_input_doesnt_start_with_bos(self):
29
+ vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN})
30
+ basic_encoder = encoder.CategoricalEncoder(
31
+ vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
32
+ with self.assertRaisesRegex(ValueError,
33
+ r"^.*First input token must be BOS token.*$"):
34
+ basic_encoder.encode([1, 1, 1])
35
+
36
+ def test_encode_raises_value_error_if_input_not_in_vocab(self):
37
+ vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN})
38
+ basic_encoder = encoder.CategoricalEncoder(
39
+ vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
40
+ with self.assertRaisesRegex(ValueError,
41
+ r"^.*Inputs .* not found in encoding.*$"):
42
+ basic_encoder.encode([_BOS_TOKEN, 1, 2, 3, 4])
43
+
44
+ def test_decode_raises_value_error_if_id_outside_of_vocab_size(self):
45
+ vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, _BOS_TOKEN})
46
+ basic_encoder = encoder.CategoricalEncoder(
47
+ vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
48
+ with self.assertRaisesRegex(ValueError,
49
+ r"^.*Inputs .* not found in decoding map.*$"):
50
+ basic_encoder.decode([0, 1, 2, 3])
51
+
52
+ def test_encoder_raises_value_error_if_bos_not_in_basis(self):
53
+ vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3})
54
+ with self.assertRaisesRegex(ValueError,
55
+ r"^.*BOS token missing in encoding.*$"):
56
+ unused_basic_encoder = encoder.CategoricalEncoder(
57
+ vs.basis, bos_token=_BOS_TOKEN)
58
+
59
+ def test_encoder_raises_value_error_if_pad_not_in_basis(self):
60
+ vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3})
61
+ with self.assertRaisesRegex(ValueError,
62
+ r"^.*PAD token missing in encoding.*$"):
63
+ unused_basic_encoder = encoder.CategoricalEncoder(
64
+ vs.basis, pad_token=_PAD_TOKEN)
65
+
66
+ def test_encoder_encodes_bos_and_pad_tokens_as_expected(self):
67
+ vs = bases.VectorSpaceWithBasis.from_values(
68
+ "input", {1, 2, 3, _BOS_TOKEN, _PAD_TOKEN})
69
+ basic_encoder = encoder.CategoricalEncoder(
70
+ vs.basis, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN)
71
+ self.assertEqual(
72
+ basic_encoder.encode([_BOS_TOKEN, _PAD_TOKEN]),
73
+ [basic_encoder.bos_encoding, basic_encoder.pad_encoding])
74
+
75
+ @parameterized.parameters([
76
+ dict(
77
+ vocab={1, 2, 3, _BOS_TOKEN}, # lexicographic order
78
+ inputs=[_BOS_TOKEN, 3, 2, 1],
79
+ expected=[3, 2, 1, 0]),
80
+ dict(
81
+ vocab={"a", "b", _BOS_TOKEN, "c"}, # lexicographic order
82
+ inputs=[_BOS_TOKEN, "b", "b", "c"],
83
+ expected=[2, 1, 1, 3]),
84
+ ])
85
+ def test_tokens_are_encoded_in_lexicographic_order(self, vocab, inputs,
86
+ expected):
87
+ # Expect encodings to be assigned to ids according to a lexicographic
88
+ # ordering of the vocab
89
+ vs = bases.VectorSpaceWithBasis.from_values("input", vocab)
90
+ basic_encoder = encoder.CategoricalEncoder(
91
+ vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
92
+ encodings = basic_encoder.encode(inputs)
93
+ self.assertEqual(encodings, expected)
94
+
95
+ @parameterized.parameters([
96
+ dict(vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, expected=5),
97
+ dict(vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b"}, expected=4),
98
+ ])
99
+ def test_vocab_size_has_expected_value(self, vocab, expected):
100
+ vs = bases.VectorSpaceWithBasis.from_values("input", vocab)
101
+ basic_encoder = encoder.CategoricalEncoder(
102
+ vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN)
103
+ self.assertEqual(basic_encoder.vocab_size, expected)
104
+
105
+ @parameterized.parameters([
106
+ dict(
107
+ vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, inputs=[_BOS_TOKEN, 3, 2,
108
+ 1]),
109
+ dict(
110
+ vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b", "c"},
111
+ inputs=[_BOS_TOKEN, "b", "b", "c"]),
112
+ ])
113
+ def test_decode_inverts_encode(self, vocab, inputs):
114
+ vs = bases.VectorSpaceWithBasis.from_values("input", vocab)
115
+ basic_encoder = encoder.CategoricalEncoder(
116
+ vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN)
117
+ encodings = basic_encoder.encode(inputs)
118
+ recovered = basic_encoder.decode(encodings)
119
+ self.assertEqual(recovered, inputs)
120
+
121
+
122
+ if __name__ == "__main__":
123
+ absltest.main()
transformer/model.py ADDED
@@ -0,0 +1,199 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Didactic example of an autoregressive Transformer-based language model.
16
+
17
+ Glossary of shapes:
18
+ - B: Batch size.
19
+ - T: Sequence length.
20
+ - D: Model embedding size.
21
+ - H: Number of attention heads.
22
+ - V: Vocabulary size.
23
+
24
+ Forked from: haiku.examples.transformer.model
25
+ """
26
+
27
+ import collections
28
+ import dataclasses
29
+ from typing import Callable, Optional
30
+
31
+ import chex
32
+ import haiku as hk
33
+ import jax
34
+ import jax.numpy as jnp
35
+ import numpy as np
36
+ from tracr.transformer import attention
37
+
38
+ # hk.Modules are not always callable: github.com/deepmind/dm-haiku/issues/52
39
+ # Ideally, we'd want a type:
40
+ # CallableHaikuModule = Intersection[Callable[..., jax.Array], hk.Module]
41
+ # But Intersection does not exist (yet): github.com/python/typing/issues/213
42
+ CallableHaikuModule = Callable[..., jax.Array]
43
+
44
+
45
+ @chex.dataclass
46
+ class TransformerOutput:
47
+ layer_outputs: list[jax.Array] # [B, T, D]
48
+ residuals: list[jax.Array] # [B, T, D]
49
+ attn_logits: list[jax.Array] # [B, H, T, T]
50
+ output: jax.Array # [B, T, D]
51
+ input_embeddings: jax.Array # [B, T, D]
52
+
53
+
54
+ @dataclasses.dataclass
55
+ class TransformerConfig:
56
+ num_heads: int
57
+ num_layers: int
58
+ key_size: int
59
+ mlp_hidden_size: int
60
+ dropout_rate: float
61
+ activation_function: Callable[[jax.Array], jax.Array] = jax.nn.gelu
62
+ layer_norm: bool = True
63
+ causal: bool = False
64
+
65
+
66
+ @dataclasses.dataclass
67
+ class Transformer(hk.Module):
68
+ """A transformer stack."""
69
+
70
+ config: TransformerConfig
71
+ name: Optional[str] = None
72
+
73
+ def __call__(
74
+ self,
75
+ embeddings: jax.Array, # [B, T, D]
76
+ mask: jax.Array, # [B, T]
77
+ *,
78
+ use_dropout: bool = True,
79
+ ) -> TransformerOutput:
80
+ """Transforms input embedding sequences to output embedding sequences."""
81
+
82
+ def layer_norm(x: jax.Array) -> jax.Array:
83
+ """Applies a unique LayerNorm to x with default settings."""
84
+ if self.config.layer_norm:
85
+ return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
86
+ return x
87
+
88
+ initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers)
89
+ dropout_rate = self.config.dropout_rate if use_dropout else 0.
90
+ _, seq_len, model_size = embeddings.shape
91
+
92
+ # Compute causal mask for autoregressive sequence modelling.
93
+ mask = mask[:, None, None, :] # [B, H=1, T'=1, T]
94
+ mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T]
95
+
96
+ if self.config.causal:
97
+ causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T]
98
+ causal_mask = np.tril(causal_mask)
99
+ mask = mask * causal_mask # [B, H=1, T, T]
100
+
101
+ # Set up activation collection.
102
+ collected = collections.defaultdict(list)
103
+
104
+ def collect(**kwargs):
105
+ for k, v in kwargs.items():
106
+ collected[k].append(v)
107
+
108
+ residual = embeddings
109
+ for layer in range(self.config.num_layers):
110
+ with hk.experimental.name_scope(f"layer_{layer}"):
111
+ # First the attention block.
112
+ attn_block = attention.MultiHeadAttention(
113
+ num_heads=self.config.num_heads,
114
+ key_size=self.config.key_size,
115
+ model_size=model_size,
116
+ w_init=initializer,
117
+ name="attn")
118
+ attn_in = layer_norm(residual)
119
+ attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask)
120
+ attn_out, attn_logits = attn_out.out, attn_out.logits
121
+ if dropout_rate > 0:
122
+ attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out)
123
+ residual = residual + attn_out
124
+
125
+ collect(
126
+ residuals=residual, layer_outputs=attn_out, attn_logits=attn_logits)
127
+
128
+ # Then the dense block.
129
+ with hk.experimental.name_scope("mlp"):
130
+ dense_block = hk.Sequential([
131
+ hk.Linear(
132
+ self.config.mlp_hidden_size,
133
+ w_init=initializer,
134
+ name="linear_1"),
135
+ self.config.activation_function,
136
+ hk.Linear(model_size, w_init=initializer, name="linear_2"),
137
+ ])
138
+ dense_in = layer_norm(residual)
139
+ dense_out = dense_block(dense_in)
140
+ if dropout_rate > 0:
141
+ dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out)
142
+ residual = residual + dense_out
143
+
144
+ collect(residuals=residual, layer_outputs=dense_out)
145
+
146
+ return TransformerOutput(
147
+ residuals=collected["residuals"],
148
+ layer_outputs=collected["layer_outputs"],
149
+ attn_logits=collected["attn_logits"],
150
+ output=layer_norm(residual),
151
+ input_embeddings=embeddings,
152
+ )
153
+
154
+
155
+ @chex.dataclass
156
+ class CompiledTransformerModelOutput:
157
+ transformer_output: TransformerOutput
158
+ unembedded_output: jax.Array # [B, T]
159
+
160
+
161
+ @dataclasses.dataclass
162
+ class CompiledTransformerModel(hk.Module):
163
+ """A transformer model with one-hot embeddings."""
164
+ transformer: Transformer
165
+ token_embed: CallableHaikuModule
166
+ position_embed: CallableHaikuModule
167
+ unembed: CallableHaikuModule
168
+ use_unembed_argmax: bool
169
+ pad_token: Optional[int] = None
170
+
171
+ def embed(self, tokens: jax.Array) -> jax.Array:
172
+ token_embeddings = self.token_embed(tokens)
173
+ positional_embeddings = self.position_embed(jnp.indices(tokens.shape)[-1])
174
+ return token_embeddings + positional_embeddings # [B, T, D]
175
+
176
+ def __call__(
177
+ self,
178
+ tokens: jax.Array,
179
+ use_dropout: bool = True,
180
+ ) -> CompiledTransformerModelOutput:
181
+ """Embed tokens, pass through model, and unembed output."""
182
+ if self.pad_token is None:
183
+ input_mask = jnp.ones_like(tokens)
184
+ else:
185
+ input_mask = (tokens != self.pad_token)
186
+ input_embeddings = self.embed(tokens)
187
+
188
+ transformer_output = self.transformer(
189
+ input_embeddings,
190
+ input_mask,
191
+ use_dropout=use_dropout,
192
+ )
193
+ return CompiledTransformerModelOutput(
194
+ transformer_output=transformer_output,
195
+ unembedded_output=self.unembed(
196
+ transformer_output.output,
197
+ use_unembed_argmax=self.use_unembed_argmax,
198
+ ),
199
+ )
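A minimal forward-pass sketch for the `Transformer` module above (illustrative, not part of the diff); the hyperparameters are the same small values used in `model_test.py` below. `CompiledTransformerModel` wraps the same stack with token/position embedding and an unembedding, as the tests also exercise.

```python
import haiku as hk
import jax.numpy as jnp

from tracr.transformer import model

config = model.TransformerConfig(
    num_heads=2, num_layers=2, key_size=5, mlp_hidden_size=64,
    dropout_rate=0., causal=True)


@hk.transform
def forward(emb, mask):
  return model.Transformer(config)(emb, mask).output


emb = jnp.ones((1, 4, 8))  # [B=1, T=4, D=8]
mask = jnp.ones((1, 4))    # [B=1, T=4]
rng = hk.PRNGSequence(0)
params = forward.init(next(rng), emb, mask)
out = forward.apply(params, next(rng), emb, mask)  # [B, T, D]
assert out.shape == (1, 4, 8)
```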
transformer/model_test.py ADDED
@@ -0,0 +1,275 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for transformer.model."""
16
+
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import haiku as hk
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from tracr.transformer import model
24
+
25
+
26
+ class TransformerTest(parameterized.TestCase):
27
+
28
+ def _check_layer_naming(self, params):
29
+ # Modules should be named, for example:
30
+ # For MLPs: "transformer/layer_{i}/mlp/linear_1"
31
+ # For Attention: "transformer/layer_{i}/attn/key"
32
+ # For Layer Norm: "transformer/layer_{i}/layer_norm"
33
+ for key in params.keys():
34
+ levels = key.split("/")
35
+ self.assertEqual(levels[0], "transformer")
36
+ if levels[1].startswith("layer_norm"):
37
+ continue # output layer norm
38
+ self.assertStartsWith(levels[1], "layer")
39
+ if levels[2] == "mlp":
40
+ self.assertIn(levels[3], {"linear_1", "linear_2"})
41
+ elif levels[2] == "attn":
42
+ self.assertIn(levels[3], {"key", "query", "value", "linear"})
43
+ else:
44
+ self.assertStartsWith(levels[2], "layer_norm")
45
+
46
+ def _zero_mlps(self, params):
47
+ for module in params:
48
+ if "mlp" in module:
49
+ for param in params[module]:
50
+ params[module][param] = jnp.zeros_like(params[module][param])
51
+ return params
52
+
53
+ @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False))
54
+ def test_layer_norm(self, layer_norm):
55
+ # input = [1, 1, 1, 1]
56
+ # If layer norm is used, this should give all-0 output for a freshly
57
+ # initialized model because LN will subtract the mean after each layer.
58
+ # Else we expect non-zero outputs.
59
+
60
+ @hk.transform
61
+ def forward(emb, mask):
62
+ transformer = model.Transformer(
63
+ model.TransformerConfig(
64
+ num_heads=2,
65
+ num_layers=2,
66
+ key_size=5,
67
+ mlp_hidden_size=64,
68
+ dropout_rate=0.,
69
+ layer_norm=layer_norm))
70
+ return transformer(emb, mask).output
71
+
72
+ seq_len = 4
73
+ emb = jnp.ones((1, seq_len, 1))
74
+ mask = jnp.ones((1, seq_len))
75
+ rng = hk.PRNGSequence(1)
76
+ params = forward.init(next(rng), emb, mask)
77
+ out = forward.apply(params, next(rng), emb, mask)
78
+
79
+ self._check_layer_naming(params)
80
+ if layer_norm:
81
+ np.testing.assert_allclose(out, 0)
82
+ else:
83
+ self.assertFalse(np.allclose(out, 0))
84
+
85
+ @parameterized.parameters(dict(causal=True), dict(causal=False))
86
+ def test_causal_attention(self, causal):
87
+ # input = [0, random, random, random]
88
+ # mask = [1, 0, 1, 1]
89
+ # For causal attention the second token can only attend to the first one, so
90
+ # it should be the same. For non-causal attention all tokens should change.
91
+
92
+ @hk.transform
93
+ def forward(emb, mask):
94
+ transformer = model.Transformer(
95
+ model.TransformerConfig(
96
+ num_heads=2,
97
+ num_layers=2,
98
+ key_size=5,
99
+ mlp_hidden_size=64,
100
+ dropout_rate=0.,
101
+ layer_norm=False,
102
+ causal=causal))
103
+ return transformer(emb, mask).output
104
+
105
+ seq_len = 4
106
+ emb = np.random.random((1, seq_len, 1))
107
+ emb[:, 0, :] = 0
108
+ mask = np.array([[1, 0, 1, 1]])
109
+ emb, mask = jnp.array(emb), jnp.array(mask)
110
+
111
+ rng = hk.PRNGSequence(1)
112
+ params = forward.init(next(rng), emb, mask)
113
+ params = self._zero_mlps(params)
114
+ out = forward.apply(params, next(rng), emb, mask)
115
+
116
+ self._check_layer_naming(params)
117
+ if causal:
118
+ self.assertEqual(0, out[0, 0, 0])
119
+ self.assertEqual(emb[0, 1, 0], out[0, 1, 0])
120
+ else:
121
+ self.assertNotEqual(0, out[0, 0, 0])
122
+ self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0])
123
+ self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0])
124
+ self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0])
125
+
126
+ def test_setting_activation_function_to_zero(self):
127
+ # An activation function that always returns zeros should result in the
128
+ # same model output as setting all MLP weights to zero.
129
+
130
+ @hk.transform
131
+ def forward_zero(emb, mask):
132
+ transformer = model.Transformer(
133
+ model.TransformerConfig(
134
+ num_heads=2,
135
+ num_layers=2,
136
+ key_size=5,
137
+ mlp_hidden_size=64,
138
+ dropout_rate=0.,
139
+ causal=False,
140
+ layer_norm=False,
141
+ activation_function=jnp.zeros_like))
142
+ return transformer(emb, mask).output
143
+
144
+ @hk.transform
145
+ def forward(emb, mask):
146
+ transformer = model.Transformer(
147
+ model.TransformerConfig(
148
+ num_heads=2,
149
+ num_layers=2,
150
+ key_size=5,
151
+ mlp_hidden_size=64,
152
+ dropout_rate=0.,
153
+ causal=False,
154
+ layer_norm=False,
155
+ activation_function=jax.nn.gelu))
156
+ return transformer(emb, mask).output
157
+
158
+ seq_len = 4
159
+ emb = np.random.random((1, seq_len, 1))
160
+ mask = np.ones((1, seq_len))
161
+ emb, mask = jnp.array(emb), jnp.array(mask)
162
+
163
+ rng = hk.PRNGSequence(1)
164
+ params = forward.init(next(rng), emb, mask)
165
+ params_no_mlps = self._zero_mlps(params)
166
+
167
+ out_zero_activation = forward_zero.apply(params, next(rng), emb, mask)
168
+ out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask)
169
+
170
+ self._check_layer_naming(params)
171
+ np.testing.assert_allclose(out_zero_activation, out_no_mlps)
172
+ self.assertFalse(np.allclose(out_zero_activation, 0))
173
+
174
+
175
+ class CompiledTransformerModelTest(parameterized.TestCase):
176
+
177
+ def _get_one_hot_embed_unembed(self, vocab_size, max_seq_len):
178
+ # Embeds tokens as one-hot into the first `vocab_size` dimensions
179
+ token_embed = hk.Embed(
180
+ embedding_matrix=jnp.block(
181
+ [jnp.eye(vocab_size),
182
+ jnp.zeros((vocab_size, max_seq_len))]))
183
+
184
+ # Embeds positions as one-hot into the last `max_seq_len` dimensions
185
+ position_embed = hk.Embed(
186
+ embedding_matrix=jnp.block(
187
+ [jnp.zeros((max_seq_len, vocab_size)),
188
+ jnp.eye(max_seq_len)]))
189
+
190
+ class Unembed(hk.Module):
191
+
192
+ def __call__(self, embeddings):
193
+ return jnp.argmax(embeddings[:, :, :vocab_size], axis=-1)
194
+
195
+ return token_embed, position_embed, Unembed()
196
+
197
+ def test_embedding_gives_desired_result(self):
198
+ tokens = jnp.array([[1, 2, 3]])
199
+ vocab_size, max_seq_len, pad_token = 5, 5, 0
200
+
201
+ expected_embeddings = jnp.array([[[0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
202
+ [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
203
+ [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]]])
204
+
205
+ @hk.transform
206
+ def embed(tokens):
207
+ transformer = model.Transformer(
208
+ model.TransformerConfig(
209
+ num_heads=2,
210
+ num_layers=2,
211
+ key_size=5,
212
+ mlp_hidden_size=64,
213
+ dropout_rate=0.,
214
+ causal=False,
215
+ layer_norm=False,
216
+ activation_function=jax.nn.gelu))
217
+ token_embed, position_embed, unembed = self._get_one_hot_embed_unembed(
218
+ vocab_size, max_seq_len)
219
+ compiled_model = model.CompiledTransformerModel(
220
+ transformer=transformer,
221
+ token_embed=token_embed,
222
+ position_embed=position_embed,
223
+ unembed=unembed,
224
+ use_unembed_argmax=True,
225
+ pad_token=pad_token)
226
+ return compiled_model.embed(tokens)
227
+
228
+ rng = hk.PRNGSequence(1)
229
+ params = embed.init(next(rng), tokens)
230
+ embeddings = embed.apply(params, next(rng), tokens)
231
+
232
+ np.testing.assert_allclose(embeddings, expected_embeddings)
233
+
234
+ def test_embedding_then_unembedding_gives_same_tokens(self):
235
+ tokens = jnp.array([[1, 2, 3], [4, 5, 6], [3, 2, 4]])
236
+ vocab_size, max_seq_len, pad_token = 10, 5, 0
237
+
238
+ @hk.transform
239
+ def embed_unembed(tokens):
240
+ transformer = model.Transformer(
241
+ model.TransformerConfig(
242
+ num_heads=2,
243
+ num_layers=2,
244
+ key_size=5,
245
+ mlp_hidden_size=64,
246
+ dropout_rate=0.,
247
+ causal=False,
248
+ layer_norm=False,
249
+ activation_function=jax.nn.gelu))
250
+ token_embed, position_embed, unembed = self._get_one_hot_embed_unembed(
251
+ vocab_size, max_seq_len)
252
+ compiled_model = model.CompiledTransformerModel(
253
+ transformer=transformer,
254
+ token_embed=token_embed,
255
+ position_embed=position_embed,
256
+ unembed=unembed,
257
+ use_unembed_argmax=True,
258
+ pad_token=pad_token)
259
+ embeddings = compiled_model.embed(tokens)
260
+ unembeddings = compiled_model.unembed(embeddings)
261
+ return embeddings, unembeddings
262
+
263
+ rng = hk.PRNGSequence(1)
264
+ params = embed_unembed.init(next(rng), tokens)
265
+ embeddings, unembeddings = embed_unembed.apply(params, next(rng), tokens)
266
+
267
+ self.assertEqual(
268
+ embeddings.shape,
269
+ (tokens.shape[0], tokens.shape[1], vocab_size + max_seq_len))
270
+
271
+ np.testing.assert_allclose(unembeddings, tokens)
272
+
273
+
274
+ if __name__ == "__main__":
275
+ absltest.main()
utils/debugging.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Useful helpers for model debugging."""
16
+
17
+
18
+ def print_arrays(arrays, labels=None, colwidth=12):
19
+ """Pretty-prints a list of [1, T, D] arrays."""
20
+ if labels is not None:
21
+ print(" |".join(labels))
22
+ widths = [len(l) for l in labels]
23
+ else:
24
+ widths = [colwidth] * arrays[0].shape[-1]
25
+ for layer in arrays:
26
+ print("=" * (colwidth + 1) * layer.shape[1])
27
+ for row in layer[0]:
28
+ print(" |".join([f"{x:<{width}.2f}" for x, width in zip(row, widths)]))
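A quick usage sketch for `print_arrays` (illustrative, not part of the diff), assuming the module is importable as `tracr.utils.debugging`:

```python
import numpy as np

from tracr.utils import debugging

# Two fake residual-stream snapshots of shape [1, T=2, D=3].
arrays = [np.arange(6, dtype=float).reshape((1, 2, 3)),
          np.ones((1, 2, 3))]
debugging.print_arrays(arrays, labels=["tokens", "position", "value"])
```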