diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..dc3f8f1f9d192a5d084fbc1f626fbdb4a068fe2d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# How to Contribute + +We welcome your contributions to this project. Please read the guidance below +first. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows [Google's Open Source Community +Guidelines](https://opensource.google/conduct/). diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index b34ef22a7ecc59245a95c9c5a28a820ce3d49b59..25728ddc50ac4b35ac5b2b9ca9cd9c125cb883ee 100644 --- a/README.md +++ b/README.md @@ -1 +1,208 @@ -# tracr +# Tracr: TRAnsformer Compiler for RASP. + +Tracr is a compiler for converting RASP programs +([Weiss et al. 2021](https://arxiv.org/abs/2106.06981)) +into transformer weights. + +Directory structure: + +* `rasp` contains an implementation of RASP embedded in Python. +* `compiler` contains the compiler itself. +* `transformer` contains the implementation of the transformer. +* `craft` contains the intermediate representation used by the compiler: + essentially a small linear algebra-based library with named dimensions. + +This is not an officially supported Google product. + + +## Installation + +Installation is currently a bit manual. First, install dependencies: + +``` +pip3 install chex einops dm-haiku networkx +``` + +Second, clone the repo: + +``` +git clone https://github.com/deepmind/tracr +``` + +Third, put the resulting folder somewhere in your `PYTHONPATH` +(eg by placing the `tracr` checkout in the root of your project folder). + +This will be made easier in the future. + + +## Usage example: RASP `reverse` program + +Consider the RASP `reverse` program: + +``` +opp_index = length - indices - 1; +flip = select(indices, opp_index, ==); +reverse = aggregate(flip, tokens); +``` + +To compile this with Tracr, we would first implement the program using Tracr's +RASP library: + +```python +from tracr.rasp import rasp + +length = make_length() # `length` is not a primitive in our implementation. +opp_index = length - rasp.indices - 1 +flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ) +reverse = rasp.Aggregate(flip, rasp.tokens) +``` + +Where: + +```python +def make_length(): + all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE) + return rasp.SelectorWidth(all_true_selector) +``` + +We can then compile the RASP program to a transformer with: + +```python +from tracr.compiler import compiling + +bos = "BOS" +model = compiling.compile_rasp_to_model( + reverse, + vocab={1, 2, 3}, + max_seq_len=5, + compiler_bos=bos, +) +``` + +This yields a transformer as a [Haiku](https://github.com/deepmind/dm-haiku) model. +This model isn't intended to provide _everything_ you might need, but rather serves +as a kind of "documentation-in-code" for the semantics of the generated parameters. +The expectation is that the user can then write or contribute an adapter that converts +parameters from this reference model to another transformer implementation. + +Using this model we can perform a forward pass: + +```python +>>> out = model.apply([bos, 1, 2, 3]) +>>> out.decoded +["BOS", 3, 2, 1] +``` + +Success! We have a transformer that reverses its input tokens. + +Note: compiled models always expect a BOS token in order to support +selectors which don't attend to any of the input tokens. This is necessary to +preserve intuitive RASP semantics; the alternative would have been to treat +all-False selector rows as equivalent to all-True (which is what softmax in an +attention layer would naturally do). For more details, see our paper. + +You can also inspect some of the intermediate activations of the model, using +`out.residuals`, `out.layer_outputs`, and `out.attn_logits`. + +For more examples of RASP programs we can compile, check out +[compiler/lib.py](compiler/lib.py). + +For an interactive example of compiling a model and visualizing its computation, +check out the notebook at +[examples/Visualize\_Tracr\_Models.ipynb](examples/Visualize_Tracr_Models.ipynb). + + +## Developer README + +If you'd like to extend Tracr to fit your purposes, here's some information on +how Tracr works under the hood. + + +### How Tracr works conceptually + +To compile a program, Tracr does the following. + +1. **Trace RASP program into a graph representation.** This involves creating + a graph node for each RASP expression and inferring dependencies between + these graph nodes. + +2. **Infer bases.** Tracr is designed to have each node output to a separate + subspace of the residual stream. To do this, we first infer the set of all + possible token values that each node can take, then using that information, + decide on a subspace for each node, and augment each node in the graph + with the basis vectors for that node's subspace. + +3. **Convert nodes to Craft components.** Craft is the name of our internal + intermediate representation that does linear algebra on named subspaces. In + this stage, each expression node is converted to a Craft component that + actually performs the linear algebra operations necessary to implement the + expression. This includes converting _sequence operators_ to MLP weights, + and _selectors_ to weights of attention heads. (We compute the appropriate + weights directly using the theory of universal approximation for MLPs - no + gradient descent required!) + +4. **Convert Craft graph to Craft model.** In this stage, we convert from + a graph representation to a layout that looks more like an actual + transformer. At this stage, we essentially have a working model, but + with the linear algebra done using Craft rather than JAX + Haiku. + +5. **Convert Craft model to Haiku model.** Finally, we convert our + intermediate representation of the model to a full Haiku model. + +Two details worth expanding on here are subspaces and corresponding bases. +Each node writes to a separate subspace of the residual stream, +where each subspace is simply a unique chunk of the residual stream vector. +For example, the first node might write to the first 5 components of +the residual stream; the second node the next 5; and so on. In terms of what +the embeddings actually associated with each node, Tracr employs two +different kinds of bases: + +* **Categorical representation** - in which each unique token value is + represented as a unique one-hot vector in that node's subspace. This + is the representation used by default. +* **Numerical representation** - in which each unique token value is + mapped to a unique scalar value. This is necessary for some uses + of the `aggregate` operation - essentially, ones which involve taking + a mean - and some other operations are represented more efficiently + with this representation. + +A final detail is BOS tokens. The compiler relies on beginning-of-sequence +tokens to in order to implement a number of operations. This is why token +sequences fed into the final model _must_ start with a BOS token. + + +### How Tracr works in practice + +The flow of compilation execution begins in +[`compiler/compiling.py`](compiler/compiling.py), in the +`compile_rasp_to_model` function. This function is fairly short and maps +directly to the stages outlined above, so don't be afraid to read the source! + + +## Running tests + +We use [`absltest`](https://abseil.io/docs/python/guides/testing), which is +`unittest`-compatible, and is therefore in turn `pytest`-compatible. + +First, install test dependencies: + +``` +pip3 install absl-py pytest +``` + +``` +# We use `python3 -m pytest` instead of just `pytest` so that the working directory is +# added to PYTHONPATH. +# -ra: Report names of tests that failed, were skipped, etc. +python3 -m pytest -ra +``` + +This should take about 60 seconds. If you install `pytest-xdist`, you can run them in +parallel with: + +``` +python3 -m pytest -ra -n auto +``` + +However, currently this only shaves off about 10 seconds, since it's bottlenecked by a +single long-running test. diff --git a/compiler/__init__.py b/compiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7157bdbb8716be11e104a08b445fce23fc7953c --- /dev/null +++ b/compiler/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Provides the main compiler function as a public import.""" + +from tracr.compiler.compiling import compile_rasp_to_model + +__all__ = ["compile_rasp_to_model"] diff --git a/compiler/assemble.py b/compiler/assemble.py new file mode 100644 index 0000000000000000000000000000000000000000..c37ac52fa4bbb9f4ae65dd217331719b2795da4f --- /dev/null +++ b/compiler/assemble.py @@ -0,0 +1,335 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Assemble weights of a transformer model from a craft residual stack.""" + +import dataclasses +from typing import Any, Callable, Optional, Protocol + +import chex +import einops +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from tracr.craft import bases +from tracr.craft import transformers +from tracr.craft import vectorspace_fns +from tracr.transformer import encoder +from tracr.transformer import model + + +@chex.dataclass +class AssembledTransformerModelOutput: + decoded: list[Any] # length T. + unembedded: jax.Array # [B, T] B = 1 always. + layer_outputs: list[jax.Array] # [B, T, D] + residuals: list[jax.Array] # [B, T, D] + attn_logits: list[jax.Array] # [B, T, T, H] + transformer_output: jax.Array # [B, T, D] + input_embeddings: jax.Array + + +class ModelForward(Protocol): + + def __call__( + self, + params: hk.Params, + emb: jax.Array, + ) -> model.CompiledTransformerModelOutput: + """A hk-transformed forward pass through the compiled model.""" + + +@dataclasses.dataclass +class AssembledTransformerModel: + """Model architecture and parameters from assembling a model.""" + forward: ModelForward + get_compiled_model: Callable[[], model.CompiledTransformerModel] + params: hk.Params + model_config: model.TransformerConfig + residual_labels: list[str] + input_encoder: Optional[encoder.Encoder] = None + output_encoder: Optional[encoder.Encoder] = None + + def apply(self, tokens: list[bases.Value]) -> AssembledTransformerModelOutput: + """Returns output from running the model on a set of input tokens.""" + if self.input_encoder: + tokens = self.input_encoder.encode(tokens) + tokens = jnp.array([tokens]) + output = self.forward(self.params, tokens) + decoded = output.unembedded_output[0].tolist() + if self.output_encoder: + decoded = self.output_encoder.decode(decoded) + + if self.input_encoder.bos_token: + # Special case for decoding the bos token position, for which the output + # decoder might have unspecified behavior. + decoded = [self.input_encoder.bos_token] + decoded[1:] + + return AssembledTransformerModelOutput( + decoded=decoded, + unembedded=output.unembedded_output, + layer_outputs=output.transformer_output.layer_outputs, + residuals=output.transformer_output.residuals, + attn_logits=output.transformer_output.attn_logits, + transformer_output=output.transformer_output.output, + input_embeddings=output.transformer_output.input_embeddings) + + +@dataclasses.dataclass +class EmbeddingModules: + """Modules for embedding and tokens and positions and unembedding results.""" + token_embed: model.CallableHaikuModule + pos_embed: model.CallableHaikuModule + unembed: model.CallableHaikuModule + + +def _get_model_config_and_module_names( + craft_model: transformers.SeriesWithResiduals +) -> tuple[model.TransformerConfig, list[str]]: + """Returns model config and locations (in params) for halflayers.""" + + multi_attn_heads: list[list[transformers.AttentionHead]] = [] + mlps: list[transformers.MLP] = [] + module_names: list[str] = [] + + candidate_module_names = [] + for layer in range(len(craft_model.blocks)): + candidate_module_names.append(f"transformer/layer_{layer}/attn") + candidate_module_names.append(f"transformer/layer_{layer}/mlp") + candidate_module_names = iter(candidate_module_names) + + for module in craft_model.blocks: + if isinstance(module, transformers.MLP): + mlps.append(module) + layer_type = "mlp" + else: + multi_attn_heads.append(list(module.as_multi().heads())) + layer_type = "attn" + # Find next layer with the necessary type. Modules in-between, that are not + # added to module_names will be disabled later by setting all weights to 0. + module_name = next(candidate_module_names) + while layer_type not in module_name: + module_name = next(candidate_module_names) + module_names.append(module_name) + + num_layers = int(module_names[-1].split("_")[1].split("/")[0]) + 1 + heads = sum(multi_attn_heads, []) + + if multi_attn_heads: + num_heads = max(len(heads) for heads in multi_attn_heads) + key_size = max(max(head.w_qk.matrix.shape) for head in heads) + else: + num_heads, key_size = 1, 1 + + if mlps: + mlp_hidden_size = max(mlp.fst.output_space.num_dims for mlp in mlps) + else: + mlp_hidden_size = 1 + + model_config = model.TransformerConfig( + num_heads=num_heads, + num_layers=num_layers, + key_size=key_size, + mlp_hidden_size=mlp_hidden_size, + dropout_rate=0., + activation_function=jax.nn.relu, + layer_norm=False, + causal=False, + ) + + return model_config, module_names + + +def _make_embedding_modules( + residual_space: bases.VectorSpaceWithBasis, + tokens_space: bases.VectorSpaceWithBasis, + indices_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis) -> EmbeddingModules: + """Creates embedding and unembedding modules from vector spaces. + + Args: + residual_space: Full residual space of the model. + tokens_space: Subspace to embed tokens to. + indices_space: Subspace to embed indices/position embeddings to. + output_space: Subspace to unembed outputs from. + + Returns: + EmbeddingModules containing modules for token embeddings, position + embeddings and unembeddings. + """ + tokens_to_res = vectorspace_fns.project(tokens_space, residual_space) + + # If we use the 'one' direction, make sure all inputs have a 1 here + one_dir = bases.BasisDirection("one") + if one_dir in residual_space: + one_to_res = vectorspace_fns.Linear.from_action( + tokens_space, residual_space, + lambda x: residual_space.vector_from_basis_direction(one_dir)) + tokens_to_res = vectorspace_fns.Linear.combine_in_parallel( + [tokens_to_res, one_to_res]) + + # Token embeddings. + res_to_out = vectorspace_fns.project(residual_space, output_space) + token_embed = hk.Embed( + embedding_matrix=tokens_to_res.matrix, name="token_embed") + + # Positional embeddings. + index_to_res = vectorspace_fns.project(indices_space, residual_space) + # The zeroth position should not have any positional embeddings, + # so we add one line of padding at the zeroth position. + pos_matrix = np.concatenate( + [np.zeros((1, residual_space.num_dims)), index_to_res.matrix], axis=0) + pos_embed = hk.Embed(embedding_matrix=pos_matrix, name="pos_embed") + + def unembed(x, use_unembed_argmax): + out = x @ res_to_out.matrix + if use_unembed_argmax: + return jnp.argmax(out, axis=-1) + elif out.shape[-1] == 1: + return out.squeeze(-1) + return out + + unembed_mod = hk.to_module(unembed)() + return EmbeddingModules( + token_embed=token_embed, pos_embed=pos_embed, unembed=unembed_mod) + + +def assemble_craft_model( + craft_model: transformers.SeriesWithResiduals, + tokens_space: bases.VectorSpaceWithBasis, + indices_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + categorical_output: bool, + causal: bool = False, +) -> AssembledTransformerModel: + """Assembles the given components into a Haiku model with parameters. + + Args: + craft_model: Model to assemble weights for. + tokens_space: Vectorspace to embed the input tokens to. + indices_space: Vectorspace to embed the indices to (position encodings). + output_space: Vectorspace that the model will write outputs to that should + be unembedded. + categorical_output: Whether the output is categorical. If True, we take an + argmax when unembedding. + causal: Whether to output a causally-masked model. + + Returns: + An AssembledTransformerModel that contains the model and parameters of the + assembled transformer. + """ + # TODO(b/255936413): Make embeddings only retain the tokens and indices that + # are actually used. + # TODO(b/255936496): Think about enabling layer norm and reversing it somehow + + model_config, module_names = _get_model_config_and_module_names(craft_model) + model_config.causal = causal + + residual_space = bases.join_vector_spaces(craft_model.residual_space, + tokens_space, indices_space, + output_space) + residual_labels = [str(basis_dir) for basis_dir in residual_space.basis] + + # Build model with embedding and unembedding layers + def get_compiled_model(): + transformer = model.Transformer(model_config) + embed_modules = _make_embedding_modules( + residual_space=residual_space, + tokens_space=tokens_space, + indices_space=indices_space, + output_space=output_space) + return model.CompiledTransformerModel( + transformer=transformer, + token_embed=embed_modules.token_embed, + position_embed=embed_modules.pos_embed, + unembed=embed_modules.unembed, + use_unembed_argmax=categorical_output) + + @hk.without_apply_rng + @hk.transform + def forward(emb): + compiled_model = get_compiled_model() + return compiled_model(emb, use_dropout=False) + + params = forward.init(jax.random.PRNGKey(0), jnp.array([[1, 2, 3]])) + + for key in params: + if "transformer" in key: + for par in params[key]: + params[key][par] = np.zeros_like(params[key][par]) + + # Assemble attention and MLP weights. + project = lambda space: vectorspace_fns.project(residual_space, space).matrix + + for module_name, module in zip(module_names, craft_model.blocks): + if isinstance(module, transformers.MLP): + hidden_size = module.fst.output_space.num_dims + residual_to_fst_input = project(module.fst.input_space) + snd_output_to_residual = project(module.snd.output_space).T + params[f"{module_name}/linear_1"]["w"][:, :hidden_size] = ( + residual_to_fst_input @ module.fst.matrix) + params[f"{module_name}/linear_2"]["w"][:hidden_size, :] = ( + module.snd.matrix @ snd_output_to_residual) + else: # Attention module + query, key, value, linear = [], [], [], [] + for head in module.as_multi().heads(): + key_size = head.w_qk.matrix.shape[1] + query_mat = np.zeros((residual_space.num_dims, model_config.key_size)) + residual_to_query = project(head.w_qk.left_space) + query_mat[:, :key_size] = residual_to_query @ head.w_qk.matrix + query.append(query_mat) + + key_mat = np.zeros((residual_space.num_dims, model_config.key_size)) + key_mat[:, :key_size] = project(head.w_qk.right_space) + key.append(key_mat) + + value_size = head.w_ov.matrix.shape[1] + value_mat = np.zeros((residual_space.num_dims, model_config.key_size)) + residual_to_ov_input = project(head.w_ov.input_space) + value_mat[:, :value_size] = residual_to_ov_input @ head.w_ov.matrix + value.append(value_mat) + + linear_mat = np.zeros((model_config.key_size, residual_space.num_dims)) + linear_mat[:value_size, :] = project(head.w_ov.output_space).T + linear.append(linear_mat) + + # Fill up heads that are not used with zero weights + for _ in range(model_config.num_heads - module.as_multi().num_heads): + query.append(np.zeros_like(query[0])) + key.append(np.zeros_like(key[0])) + value.append(np.zeros_like(value[0])) + linear.append(np.zeros_like(linear[0])) + + query = einops.rearrange(query, + "heads input output -> input (heads output)") + key = einops.rearrange(key, "heads input output -> input (heads output)") + value = einops.rearrange(value, + "heads input output -> input (heads output)") + linear = einops.rearrange(linear, + "heads input output -> (heads input) output") + + params[f"{module_name}/query"]["w"][:, :] = query + params[f"{module_name}/key"]["w"][:, :] = key + params[f"{module_name}/value"]["w"][:, :] = value + params[f"{module_name}/linear"]["w"][:, :] = linear + + params = jax.tree_util.tree_map(jnp.array, params) + return AssembledTransformerModel( + forward=forward.apply, + get_compiled_model=get_compiled_model, + params=params, + model_config=model_config, + residual_labels=residual_labels, + ) diff --git a/compiler/assemble_test.py b/compiler/assemble_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9dce153038a02f6b0496dbdb936825f006dacd73 --- /dev/null +++ b/compiler/assemble_test.py @@ -0,0 +1,120 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for transformer.assemble.""" + +from absl.testing import absltest +from absl.testing import parameterized +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from tracr.compiler import assemble +from tracr.craft import bases + + +class AssembleTest(parameterized.TestCase): + + def test_token_embedding_produces_correct_embedding(self): + # Token embeddings should be one-hot embeddings of the input integers + # into the token subspace of residual_space + input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2)) + indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3)) + output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2)) + residual_space = bases.join_vector_spaces(input_space, indices_space, + output_space) + + @hk.without_apply_rng + @hk.transform + def token_pos_embed(tokens): + embed_modules = assemble._make_embedding_modules( + residual_space=residual_space, + tokens_space=input_space, + indices_space=indices_space, + output_space=output_space) + return embed_modules.token_embed(tokens) + + tokens = jnp.array([0, 0, 1]) + expected_token_embeddings = jnp.array([[1, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0]]) + + params = token_pos_embed.init(jax.random.PRNGKey(0), tokens) + embeddings = token_pos_embed.apply(params, tokens) + np.testing.assert_allclose(embeddings, expected_token_embeddings) + + def test_position_embedding_produces_correct_embedding(self): + # Position embeddings should be one-hot embeddings of the input integers + # (representing indices) into the indices subspace of residual_space + input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2)) + indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3)) + output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2)) + residual_space = bases.join_vector_spaces(input_space, indices_space, + output_space) + + @hk.without_apply_rng + @hk.transform + def token_pos_embed(tokens): + embed_modules = assemble._make_embedding_modules( + residual_space=residual_space, + tokens_space=input_space, + indices_space=indices_space, + output_space=output_space) + return embed_modules.pos_embed(jnp.indices(tokens.shape)[-1]) + + tokens = jnp.array([3, 0, 0, 1]) + expected_pos_embeddings = jnp.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0]]) + + params = token_pos_embed.init(jax.random.PRNGKey(0), tokens) + embeddings = token_pos_embed.apply(params, tokens) + np.testing.assert_allclose(embeddings, expected_pos_embeddings) + + def test_unembedding(self): + # Prepend numbers to preserve basis order [input, index, output] + input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2)) + indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3)) + output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2)) + residual_space = bases.join_vector_spaces(input_space, indices_space, + output_space) + + @hk.without_apply_rng + @hk.transform + def unembed(embeddings): + embed_modules = assemble._make_embedding_modules( + residual_space=residual_space, + tokens_space=input_space, + indices_space=indices_space, + output_space=output_space) + return embed_modules.unembed(embeddings, use_unembed_argmax=True) + + embeddings = jnp.array([ + # pylint: disable=g-no-space-after-comment + #inp| indices| out | < spaces + #0 1 0 1 2 0 1 < values in spaces + [0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1] + ]) + expected_tokens = jnp.array([1, 0, 1]) + + params = unembed.init(jax.random.PRNGKey(0), embeddings) + tokens = unembed.apply(params, embeddings) + np.testing.assert_allclose(tokens, expected_tokens) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/basis_inference.py b/compiler/basis_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e08b99d81d34246cfff6836a77768420d28c7b56 --- /dev/null +++ b/compiler/basis_inference.py @@ -0,0 +1,106 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Inferring the vector spaces taken on by certain operations.""" + +import dataclasses +import itertools + +import networkx as nx +from tracr.compiler import nodes +from tracr.craft import bases +from tracr.rasp import rasp +from tracr.utils import errors + +Node = nodes.Node + + +@dataclasses.dataclass +class InferBasesOutput: + graph: nx.DiGraph + + +def infer_bases( + graph: nx.DiGraph, + sink: Node, + vocab: set[rasp.Value], + max_seq_len: int, +) -> None: + """Infers in-place the possible output values and vector bases of the SOps.""" + + def compute_value_set(sop: rasp.SOp) -> set[rasp.Value]: + """Computes value set using already-computed predecessor value sets.""" + if sop is rasp.tokens: + return vocab + elif sop is rasp.indices: + return set(range(max_seq_len)) + elif isinstance(sop, rasp.SelectorWidth): + return set(range(0, max_seq_len + 1)) + elif isinstance(sop, rasp.Full): + return {sop.fill} + elif isinstance(sop, rasp.Map): + inner_value_set = graph.nodes[sop.inner.label][nodes.VALUE_SET] + out = set() + for x in inner_value_set: + res = errors.ignoring_arithmetic_errors(sop.f)(x) + if res is not None: + out.add(res) + return out + elif isinstance(sop, rasp.SequenceMap): + f_ignore_error = errors.ignoring_arithmetic_errors(sop.f) + fst_value_set = graph.nodes[sop.fst.label][nodes.VALUE_SET] + snd_value_set = graph.nodes[sop.snd.label][nodes.VALUE_SET] + out = set() + for l, r in itertools.product(fst_value_set, snd_value_set): + res = f_ignore_error(l, r) + if res is not None: + out.add(res) + return out + elif isinstance(sop, rasp.Aggregate): + if rasp.is_categorical(sop): + # Simply pass on the value set of the underlying S-Op. + return graph.nodes[sop.sop.label][nodes.VALUE_SET] + elif rasp.is_numerical(sop): + # TODO(b/255936408): This doesn't work if we average arbitrary values. + # But most examples only average binary variables. + sop_value_set = graph.nodes[sop.sop.label][nodes.VALUE_SET] + if {int(x) for x in sop_value_set} != {0, 1}: + raise NotImplementedError( + "Attention patterns can currently only " + "average binary variables. Not:", sop_value_set) + + value_set = set() + for value in sop_value_set: + for length in range(1, max_seq_len + 1): + value_set.add(value / length) + return value_set + raise ValueError(f"Unsupported S-Op: {sop}") + + for node_id in nx.dfs_postorder_nodes(graph.reverse(), sink[nodes.ID]): + expr = graph.nodes[node_id][nodes.EXPR] + + if not isinstance(expr, rasp.SOp): + # Only S-Ops have output vector spaces. + continue + + value_set = compute_value_set(expr) + graph.nodes[node_id][nodes.VALUE_SET] = value_set + + if rasp.is_categorical(expr): + out_space = bases.VectorSpaceWithBasis.from_values(expr.label, value_set) + elif rasp.is_numerical(expr): + out_space = bases.VectorSpaceWithBasis.from_names([expr.label]) + else: + raise ValueError(f"Unsupported S-Op type: {expr.type}") + graph.nodes[node_id][nodes.OUTPUT_BASIS] = out_space.basis diff --git a/compiler/basis_inference_test.py b/compiler/basis_inference_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1725c84fd187c684fb16bc1adfd413400a3d8d5b --- /dev/null +++ b/compiler/basis_inference_test.py @@ -0,0 +1,140 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for compiler.basis_inference.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tracr.compiler import basis_inference +from tracr.compiler import nodes +from tracr.compiler import rasp_to_graph +from tracr.rasp import rasp + + +class InferBasesTest(parameterized.TestCase): + + def test_arithmetic_error_logs_warning(self): + program = rasp.numerical(rasp.Map(lambda x: 1 / x, rasp.tokens)) + extracted = rasp_to_graph.extract_rasp_graph(program) + vocab = {0, 1, 2} + with self.assertLogs(level="WARNING"): + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + vocab, + max_seq_len=1, + ) + + @parameterized.parameters(({1, 2, 3}, {2, 3, 4}), ({0, 5}, {1, 6})) + def test_one_edge(self, vocab, expected_value_set): + program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) + extracted = rasp_to_graph.extract_rasp_graph(program) + + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + vocab, + max_seq_len=1, + ) + + self.assertSetEqual( + extracted.graph.nodes[program.label][nodes.VALUE_SET], + expected_value_set, + ) + + def test_primitive_close_to_tip(self): + intermediate = rasp.categorical(rasp.tokens + 1) + intermediate = rasp.categorical(intermediate + intermediate) + program = rasp.categorical(intermediate + rasp.indices) + extracted = rasp_to_graph.extract_rasp_graph(program) + + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + {0, 1}, + max_seq_len=2, + ) + + self.assertSetEqual( + extracted.graph.nodes[program.label][nodes.VALUE_SET], + {2, 3, 4, 5}, + ) + self.assertSetEqual( + extracted.graph.nodes[intermediate.label][nodes.VALUE_SET], + {2, 3, 4}, + ) + + def test_categorical_aggregate(self): + program = rasp.categorical( + rasp.Aggregate( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), + rasp.indices, + )) + + extracted = rasp_to_graph.extract_rasp_graph(program) + + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + {0, 1}, + max_seq_len=3, + ) + + self.assertSetEqual( + extracted.graph.nodes[program.label][nodes.VALUE_SET], + {0, 1, 2}, + ) + + def test_numerical_aggregate(self): + program = rasp.numerical( + rasp.Aggregate( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), + rasp.indices, + )) + + extracted = rasp_to_graph.extract_rasp_graph(program) + + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + {0, 1}, + max_seq_len=2, + ) + + self.assertSetEqual( + extracted.graph.nodes[program.label][nodes.VALUE_SET], + {0, 1, 1 / 2}, + ) + + def test_selector_width(self): + program = rasp.SelectorWidth( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)) + + extracted = rasp_to_graph.extract_rasp_graph(program) + + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + {0, 1}, + max_seq_len=2, + ) + + self.assertSetEqual( + extracted.graph.nodes[program.label][nodes.VALUE_SET], + {0, 1, 2}, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/compiling.py b/compiler/compiling.py new file mode 100644 index 0000000000000000000000000000000000000000..31ba4c1a0ff30030935c49578631a99d8e0cceba --- /dev/null +++ b/compiler/compiling.py @@ -0,0 +1,92 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Combines all steps of compiling a RASP program.""" + +from tracr.compiler import assemble +from tracr.compiler import basis_inference +from tracr.compiler import craft_graph_to_model +from tracr.compiler import craft_model_to_transformer +from tracr.compiler import expr_to_craft_graph +from tracr.compiler import rasp_to_graph +from tracr.craft import bases +from tracr.rasp import rasp + +COMPILER_BOS = "compiler_bos" +COMPILER_PAD = "compiler_pad" + + +def compile_rasp_to_model( + program: rasp.SOp, + vocab: set[rasp.Value], + max_seq_len: int, + causal: bool = False, + compiler_bos: str = COMPILER_BOS, + compiler_pad: str = COMPILER_PAD, + mlp_exactness: int = 100) -> assemble.AssembledTransformerModel: + """Compile a RASP program to transformer weights. + + Args: + program: the RASP program to compile. + vocab: the set of vocab tokens expected by RASP. + max_seq_len: the maximum sequence length for the compiled model. + causal: if True, outputs a model with causal masking. + compiler_bos: the name of the special BOS token that will be added by the + compiler. Must not be present in the vocab. + compiler_pad: the name of the special PAD token that will be added by the + compiler. Must not be present in the vocab. + mlp_exactness: Controls the approximation of the MLP layers. In theory, + larger values yield a better approximation. But too large values can cause + numerical issues due to large parameter norms. Reasonable values are + between 1 and 100. + + Returns: + The compiled model. + """ + + if compiler_bos in vocab: + raise ValueError("Compiler BOS token must not be present in the vocab. " + f"Found '{compiler_bos}' in {vocab}") + + if compiler_pad in vocab: + raise ValueError("Compiler PAD token must not be present in the vocab. " + f"Found '{compiler_pad}' in {vocab}") + + extracted = rasp_to_graph.extract_rasp_graph(program) + graph, sources, sink = extracted.graph, extracted.sources, extracted.sink + + basis_inference.infer_bases( + graph, + sink, + vocab, + max_seq_len, + ) + + expr_to_craft_graph.add_craft_components_to_rasp_graph( + graph, + bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos), + mlp_exactness=mlp_exactness, + ) + + craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources) + + return craft_model_to_transformer.craft_model_to_transformer( + craft_model=craft_model, + graph=graph, + sink=sink, + max_seq_len=max_seq_len, + causal=causal, + compiler_bos=compiler_bos, + compiler_pad=compiler_pad, + ) diff --git a/compiler/craft_graph_to_model.py b/compiler/craft_graph_to_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f198368c5136a07de01a86afcbbe6b4a3940b5d8 --- /dev/null +++ b/compiler/craft_graph_to_model.py @@ -0,0 +1,238 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Create a craft model from a computational graph.""" + +import collections +from typing import Sequence + +import networkx as nx +from tracr.compiler import nodes +from tracr.craft import bases +from tracr.craft import transformers +from tracr.rasp import rasp + +Node = nodes.Node +NodeID = nodes.NodeID + + +def _get_longest_path_length_to_node(graph: nx.DiGraph, sources: Sequence[Node], + node: Node) -> int: + """Returns the lengths of the longest path from sources to node. + + Only SOps count towards the length of a path. + + Args: + graph: DAG to compute longest path in. + sources: List of starting nodes, longest path will be a maximum over all. + node: Target node. + + Returns: + Number of steps needed for the longest path from the source to the node, or + -1 if there is no path from any of the sources to the target node. + """ + if node in sources: + return 0 + + def num_sops(path: Sequence[NodeID]) -> int: + num = 0 + for node_id in path: + if isinstance(graph.nodes[node_id][nodes.EXPR], rasp.SOp): + num += 1 + return num + + result = -1 + for source in sources: + all_paths = nx.all_simple_paths(graph, source[nodes.ID], node[nodes.ID]) + longest_path_len = max(map(num_sops, all_paths), default=-1) - 1 + if longest_path_len > result: + result = longest_path_len + return result + + +def _node_is_attn(node: Node) -> bool: + """Returns True if node is an attention layer.""" + return nodes.MODEL_BLOCK in node and isinstance( + node[nodes.MODEL_BLOCK], + (transformers.AttentionHead, transformers.MultiAttentionHead)) + + +def _node_is_mlp(node: Node) -> bool: + """Returns True if node is an MLP layer.""" + return nodes.MODEL_BLOCK in node and isinstance(node[nodes.MODEL_BLOCK], + transformers.MLP) + + +def _node_is_residual_block(node: Node) -> bool: + """Returns True if node is a valid residual block (Attn followed by MLP).""" + block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None + if block and isinstance(block, transformers.SeriesWithResiduals): + if len(block.blocks) == 2: + attn, mlp = block.blocks + if (isinstance( + attn, + (transformers.AttentionHead, transformers.MultiAttentionHead)) and + isinstance(mlp, transformers.MLP)): + return True + return False + + +def _all_attn_nodes(node_list: Sequence[Node]) -> bool: + """Returns True iff all nodes are attention layers (or nodes is empty).""" + for node in node_list: + if not _node_is_attn(node): + return False + return True + + +def _all_mlp_nodes(node_list: Sequence[Node]) -> bool: + """Returns True iff all nodes are MLP layers (or nodes is empty).""" + for node in node_list: + if not _node_is_mlp(node): + return False + return True + + +def _allocate_modules_to_layers(graph: nx.DiGraph, + sources: Sequence[Node]) -> dict[int, int]: + """Allocate all nodes in compute graph to layers. + + First, computes the longest path from the input to each node that is a model + component (not input and output nodes). The longest path to a model component + (its "depth") determines a layer in which we can place it while ensuring that + all necessary previous computations have already happened. + + This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...] + + In the special case where there are only Attention layers at one depth level + and only MLP layers in the next depth layer, they are treated as if there + are at the same depth because attention layers always come before MLP layers + for the same depth. + + Args: + graph: RASP graph with craft blocks. + sources: List of input nodes + + Returns: + A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ... + are in the order attention, mlp, attention, mlp, ... + """ + layer_allocation: dict[int, int] = collections.defaultdict(lambda: -1) + depth_by_node_id: dict[int, int] = dict() + nodes_by_depth: dict[int, list[Node]] = collections.defaultdict(list) + + # Compute depth of all model components (longest path from source to node) + for node_id, node in graph.nodes.items(): + if (_node_is_mlp(node) or _node_is_attn(node) + or _node_is_residual_block(node)): + # Node is a model component + longest_path_len = _get_longest_path_length_to_node(graph, sources, node) + depth_by_node_id[node_id] = longest_path_len + nodes_by_depth[longest_path_len].append(node) + + # If at level `depth` there are only attention heads and at level `depths + 1` + # there are only MLPs, we can condense them into one level + # TODO(b/255936816): Think about improving this heuristic. The heuristic is + # not optimal, and only catches very basic opportunities for optimization. It + # is easy to come up with opportunities for optimization that it does not + # catch. + min_depth, max_depth = min(nodes_by_depth.keys()), max(nodes_by_depth.keys()) + depth = min_depth + while depth < max_depth: + if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes( + nodes_by_depth[depth + 1]): + # Condense by decrementing the depth of all nodes starting from depth+1 + for update_depth in range(depth + 1, max_depth + 1): + for node in nodes_by_depth[update_depth]: + node_id = node[nodes.ID] + depth_by_node_id[node_id] = update_depth - 1 + nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth]) + nodes_by_depth[update_depth] = [] + max_depth -= 1 + depth += 1 + + # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ... + current_layer = 0 + current_depth = 1 + for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]): + while depth > current_depth: + current_depth += 1 + current_layer += 2 + if depth == current_depth: + if _node_is_residual_block(graph.nodes[node_id]): + layer_allocation[node_id] = current_layer + else: + is_mlp = _node_is_mlp(graph.nodes[node_id]) + layer_allocation[node_id] = current_layer + int(is_mlp) + + return layer_allocation + + +def craft_graph_to_model( + graph: nx.DiGraph, + sources: Sequence[Node]) -> transformers.SeriesWithResiduals: + """Translates a RASP graph with craft blocks into a full craft model. + + 1. Allocate modules to layers, assuming layers in the order + 2. Creates subspaces for all inputs and outputs, and builds residual stream. + 3. Assembles everything into a craft model and returns it. + + Args: + graph: RASP graph with craft blocks. + sources: List of input nodes + + Returns: + A craft model that can be compiled to model weights. + + Raises: + ValueError: On invalid input (if the craft_graph does not have craft blocks + already specified) + """ + layer_allocation = _allocate_modules_to_layers(graph, sources) + blocks_by_layer = collections.defaultdict(list) + model_blocks = [] + + residual_space = bases.VectorSpaceWithBasis([]) + + for node_id, layer_no in layer_allocation.items(): + node = graph.nodes[node_id] + block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None + + if _node_is_residual_block(node): + assert isinstance(block, transformers.SeriesWithResiduals) + assert len(block.blocks) == 2 + residual_space = bases.join_vector_spaces(residual_space, + block.blocks[0].residual_space, + block.blocks[1].residual_space) + blocks_by_layer[layer_no].append(block.blocks[0]) + blocks_by_layer[layer_no + 1].append(block.blocks[1]) + elif block: + residual_space = bases.join_vector_spaces( + residual_space, node[nodes.MODEL_BLOCK].residual_space) + blocks_by_layer[layer_no].append(block) + + for layer_no, layer_blocks in sorted( + blocks_by_layer.items(), key=lambda x: x[0]): + for block in layer_blocks: + block.residual_space = residual_space + + if layer_blocks: + if layer_no % 2 == 0: # Attention Layer + multi_head_attn = transformers.MultiAttentionHead(layer_blocks) + model_blocks.append(multi_head_attn) + else: # MLP Layer + parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks) + model_blocks.append(parallel_mlp) + + return transformers.SeriesWithResiduals(model_blocks) diff --git a/compiler/craft_graph_to_model_test.py b/compiler/craft_graph_to_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a0ba61e6b7279a42f3fb37f59244e0fccef5ee --- /dev/null +++ b/compiler/craft_graph_to_model_test.py @@ -0,0 +1,194 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for compiler.craft_graph_to_model.""" + +from absl.testing import absltest +from absl.testing import parameterized +import networkx as nx +from tracr.compiler import craft_graph_to_model +from tracr.compiler import nodes +from tracr.compiler import rasp_to_graph +from tracr.craft import bases +from tracr.craft.chamber import categorical_attn +from tracr.craft.chamber import categorical_mlp +from tracr.rasp import rasp + + +class CraftAllocateModulesToLayersTest(parameterized.TestCase): + + def _get_dummy_block(self, block_type): + if block_type == "ATTN": + return categorical_attn.categorical_attn( + query_space=bases.VectorSpaceWithBasis.from_names(["query"]), + key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]), + value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]), + output_space=bases.VectorSpaceWithBasis.from_names(["output"]), + bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]), + one_space=bases.VectorSpaceWithBasis.from_names(["one"]), + attn_fn=lambda x, y: True, + ) + elif block_type == "MLP": + return categorical_mlp.map_categorical_mlp( + input_space=bases.VectorSpaceWithBasis.from_names(["input"]), + output_space=bases.VectorSpaceWithBasis.from_names(["output"]), + operation=lambda x: x, + ) + else: + return None + + def test_get_longest_path_length_to_node_returns_expected_result(self): + """Creates a graph and checks the longest path for each node.""" + + # Node IDs: + # 0 -- 1 -- 2 -- 3 ------------ 4 + # / / + # 5 -- 6 ---------- 7 -- 8 -- 9 + # + # 10 + # Expected return values: + # 0 -- 1 -- 2 -- 3 ------------ 5 + # / / + # 0 -- 1 ---------- 2 -- 3 -- 4 + # + # -1 + + graph = nx.DiGraph() + node_ids = list(range(11)) + expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1] + for node_id, res in zip(node_ids, expected_results): + graph.add_node( + node_id, **{ + nodes.ID: node_id, + nodes.EXPR: rasp.ConstantSOp(1), + "expected_result": res + }) + graph.add_edge(0, 1) + graph.add_edge(1, 2) + graph.add_edge(2, 3) + graph.add_edge(3, 4) + graph.add_edge(5, 6) + graph.add_edge(6, 7) + graph.add_edge(7, 8) + graph.add_edge(8, 9) + graph.add_edge(6, 3) + graph.add_edge(9, 4) + sources = [graph.nodes[0], graph.nodes[5]] + + for node_id, node in graph.nodes.items(): + result = craft_graph_to_model._get_longest_path_length_to_node( + graph, sources, node) + self.assertEqual(result, node["expected_result"]) + + def test_allocate_modules_to_layers_returns_expected_result(self): + """Creates a graph and checks if the correct layer assignment is returned.""" + + # Computation Graph: + # INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT + # / / / + # INPUT -- MLP --- MLP ATTN + # \ / + # ATTN + # Node IDs: + # 0 -- 1 -- 2 -- 3 -- 4 -- 5 + # / / / + # 6 -- 7 ---- 8 9 + # \ / + # 10 + # Expected layer allocation: + # -1 -- 0 -- 3 -- 4 -- 7 -- -1 + # / / / + # -1 -- 1 --- 3 6 + # \ / + # 4 + + graph = nx.DiGraph() + node_ids = list(range(11)) + types = [ + "INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP", "MLP", + "ATTN", "ATTN" + ] + expected_results = [-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4] + for node_id, node_type, res in zip(node_ids, types, expected_results): + graph.add_node( + node_id, **{ + nodes.ID: node_id, + nodes.EXPR: rasp.ConstantSOp(1), + nodes.MODEL_BLOCK: self._get_dummy_block(node_type), + "expected_result": res + }) + + graph.add_edge(0, 1) + graph.add_edge(1, 2) + graph.add_edge(2, 3) + graph.add_edge(3, 4) + graph.add_edge(4, 5) + graph.add_edge(6, 7) + graph.add_edge(7, 2) + graph.add_edge(7, 8) + graph.add_edge(8, 3) + graph.add_edge(8, 10) + graph.add_edge(9, 4) + graph.add_edge(10, 9) + + craft_graph = rasp_to_graph.ExtractRaspGraphOutput( + graph=graph, + sink=graph.nodes[10], + sources=[graph.nodes[0], graph.nodes[6]]) + + layer_allocation = craft_graph_to_model._allocate_modules_to_layers( + craft_graph.graph, craft_graph.sources) + for node_id, node in graph.nodes.items(): + self.assertEqual(layer_allocation[node_id], node["expected_result"]) + + def test_allocate_modules_to_layers_returns_expected_result_for_chain(self): + """Tests a chain of alternating attention layers and MLPs.""" + + # Computation Graph: + # INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT + # Node IDs: + # 0 -- 1 -- 2 -- 3 -- 4 -- 5 + # Expected layer allocation: + # -1 -- 0 -- 1 -- 2 -- 3 -- -1 + + graph = nx.DiGraph() + node_ids = list(range(11)) + types = ["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"] + expected_results = [-1, 0, 1, 2, 3, -1] + for node_id, node_type, res in zip(node_ids, types, expected_results): + graph.add_node( + node_id, **{ + nodes.ID: node_id, + nodes.EXPR: rasp.ConstantSOp(1), + nodes.MODEL_BLOCK: self._get_dummy_block(node_type), + "expected_result": res + }) + + graph.add_edge(0, 1) + graph.add_edge(1, 2) + graph.add_edge(2, 3) + graph.add_edge(3, 4) + graph.add_edge(4, 5) + + craft_graph = rasp_to_graph.ExtractRaspGraphOutput( + graph=graph, sink=graph.nodes[5], sources=[graph.nodes[0]]) + + layer_allocation = craft_graph_to_model._allocate_modules_to_layers( + craft_graph.graph, craft_graph.sources) + for node_id, node in graph.nodes.items(): + self.assertEqual(layer_allocation[node_id], node["expected_result"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/craft_model_to_transformer.py b/compiler/craft_model_to_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..551e9502102502e5ae29790d81a3275315cf0225 --- /dev/null +++ b/compiler/craft_model_to_transformer.py @@ -0,0 +1,76 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Convert craft model into transformer with the correct input/output spaces.""" + +import networkx as nx +from tracr.compiler import assemble +from tracr.compiler import nodes +from tracr.craft import bases +from tracr.craft import transformers +from tracr.rasp import rasp +from tracr.transformer import encoder + + +def craft_model_to_transformer( + craft_model: transformers.SeriesWithResiduals, + graph: nx.DiGraph, + sink: nodes.Node, + max_seq_len: int, + compiler_bos: str, + compiler_pad: str, + causal: bool = False, +) -> assemble.AssembledTransformerModel: + """Turn a craft model into a transformer model.""" + + # Add the compiler BOS token. + tokens_value_set = ( + graph.nodes[rasp.tokens.label][nodes.VALUE_SET].union( + {compiler_bos, compiler_pad})) + tokens_space = bases.VectorSpaceWithBasis.from_values(rasp.tokens.label, + tokens_value_set) + + indices_space = bases.VectorSpaceWithBasis.from_values( + rasp.indices.label, range(max_seq_len)) + + categorical_output = rasp.is_categorical(sink[nodes.EXPR]) + output_space = bases.VectorSpaceWithBasis(sink[nodes.OUTPUT_BASIS]) + + assembled_model = assemble.assemble_craft_model( + craft_model=craft_model, + tokens_space=tokens_space, + indices_space=indices_space, + output_space=output_space, + categorical_output=categorical_output, + causal=causal, + ) + + assembled_model.input_encoder = encoder.CategoricalEncoder( + basis=tokens_space.basis, + enforce_bos=compiler_bos is not None, + bos_token=compiler_bos, + pad_token=compiler_pad, + max_seq_len=max_seq_len + 1 if compiler_bos is not None else max_seq_len, + ) + + if categorical_output: + assembled_model.output_encoder = encoder.CategoricalEncoder( + basis=output_space.basis, + enforce_bos=False, + bos_token=None, + pad_token=None) + else: + assembled_model.output_encoder = encoder.NumericalEncoder() + + return assembled_model diff --git a/compiler/expr_to_craft_graph.py b/compiler/expr_to_craft_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..c70abd804f083cffd80e2f694051d75010cba6da --- /dev/null +++ b/compiler/expr_to_craft_graph.py @@ -0,0 +1,277 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Add craft model blocks to graph of RASPExpr.""" + +from typing import Any, Callable, Optional + +import networkx as nx +from tracr.compiler import nodes +from tracr.craft import bases +from tracr.craft.chamber import categorical_attn +from tracr.craft.chamber import categorical_mlp +from tracr.craft.chamber import numerical_mlp +from tracr.craft.chamber import selector_width +from tracr.rasp import rasp + + +def _transform_fun_to_basis_fun( + fun: Callable[..., Any], + output_direction_name: Optional[str] = None) -> Callable[..., Any]: + """Transforms a function acting on values into one acting on directions.""" + + def bases_fun(*args): + values = [d.value for d in args] + result = fun(*values) + if output_direction_name: + return bases.BasisDirection(output_direction_name, result) + return result + + return bases_fun + + +def _check_selector_expression(expr, graph): + """Check graph structure and encodings for an aggregate or selector width.""" + sel_expr = expr.selector + + # Check graph structure + assert sel_expr.label in graph.predecessors(expr.label) + assert sel_expr.keys.label in graph.predecessors(sel_expr.label) + assert sel_expr.queries.label in graph.predecessors(sel_expr.label) + + if (not rasp.is_categorical(sel_expr.queries) or + not rasp.is_categorical(sel_expr.keys)): + raise ValueError("Selector keys and queries must be categorical.") + + +def add_craft_components_to_rasp_graph( + graph: nx.DiGraph, + bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"), + one_dir: bases.BasisDirection = bases.BasisDirection("one"), + causal: bool = False, + mlp_exactness: float = 100, +) -> None: + """Translates expressions to craft blocks and attaches them to the graph. + + Sets the `MODEL_BLOCK` attribute for all nodes in `graph`. + + Args: + graph: RASP graph with `VALUE_SET` but not `MODEL_BLOCK` attributes. + bos_dir: Basis direction representing beginning of sequence (bos) token. + one_dir: Auxiliary basis direction that must contain 1. + causal: If True, marks attention blocks as causal. + mlp_exactness: Controls the approximation of the MLP layers. + + Raises: + ValueError: On invalid input (if `MODEL_BLOCK` is set already, or + `VALUE_SET` is not set already) + NotImplementedError: If the graph contains an unsupported expression. + """ + one_space = bases.VectorSpaceWithBasis([one_dir]) + + for node_id, node in graph.nodes.items(): + expr = node[nodes.EXPR] + + if not isinstance(expr, rasp.SOp): + continue + + if nodes.MODEL_BLOCK in node and node[nodes.MODEL_BLOCK]: + raise ValueError("Input graph cannot have model blocks set already.") + if nodes.VALUE_SET not in node: + raise ValueError( + "Craft components can only be added after basis inference.") + + if expr is rasp.tokens or expr is rasp.indices: + block = None + elif isinstance(expr, rasp.Map): + inner_expr, inner_node = expr.inner, graph.nodes[expr.inner.label] + assert inner_expr.label in graph.predecessors(node_id) + input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS]) + output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) + + if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr): + basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label) + block = categorical_mlp.map_categorical_mlp( + input_space=input_space, + output_space=output_space, + operation=basis_fun) + elif rasp.is_categorical(inner_expr) and rasp.is_numerical(expr): + block = categorical_mlp.map_categorical_to_numerical_mlp( + input_space=input_space, + output_space=output_space, + operation=expr.f, + ) + elif rasp.is_numerical(inner_expr) and rasp.is_categorical(expr): + block = numerical_mlp.map_numerical_to_categorical_mlp( + f=expr.f, + input_space=input_space, + output_space=output_space, + input_value_set=inner_node[nodes.VALUE_SET], + one_space=one_space, + hidden_name=f"_hidden_{expr.label}_", + large_number=mlp_exactness) + elif rasp.is_numerical(inner_expr) and rasp.is_numerical(expr): + block = numerical_mlp.map_numerical_mlp( + f=expr.f, + input_space=input_space, + output_space=output_space, + input_value_set=inner_node[nodes.VALUE_SET], + one_space=one_space, + hidden_name=f"_hidden_{expr.label}_", + large_number=mlp_exactness) + else: + raise NotImplementedError("Map does no support " + f"in_type '{inner_expr.type}' and" + f" out_type '{expr.type}'!") + + elif isinstance(expr, rasp.SequenceMap): + fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label] + snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label] + + # Check graph structure + assert fst_expr.label in graph.predecessors(node_id) + assert snd_expr.label in graph.predecessors(node_id) + + fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS]) + snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS]) + out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) + + if (isinstance(expr, rasp.LinearSequenceMap) and + not all(rasp.is_numerical(x) for x in (fst_expr, snd_expr, expr))): + raise NotImplementedError("Linear SequenceMap only supports numerical " + "inputs/outputs.") + elif ( + not isinstance(expr, rasp.LinearSequenceMap) and + not all(rasp.is_categorical(x) for x in (fst_expr, snd_expr, expr))): + raise NotImplementedError("(Non-linear) SequenceMap only supports " + "categorical inputs/outputs.") + + if isinstance(expr, rasp.LinearSequenceMap): + assert len(fst_space.basis) == 1 + assert len(snd_space.basis) == 1 + assert len(out_space.basis) == 1 + block = numerical_mlp.linear_sequence_map_numerical_mlp( + input1_basis_direction=fst_space.basis[0], + input2_basis_direction=snd_space.basis[0], + output_basis_direction=out_space.basis[0], + input1_factor=expr.fst_fac, + input2_factor=expr.snd_fac, + hidden_name=f"_hidden_{expr.label}_") + elif fst_space == snd_space: + # It's okay to use the local variable expr.f because it is + # only used within the same loop iteration to create the MLP. + # pylint: disable=cell-var-from-loop + basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x), + expr.label) + block = categorical_mlp.map_categorical_mlp( + input_space=fst_space, output_space=out_space, operation=basis_fun) + else: + basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label) + block = categorical_mlp.sequence_map_categorical_mlp( + input1_space=fst_space, + input2_space=snd_space, + output_space=out_space, + operation=basis_fun, + one_space=one_space, + hidden_name=f"_hidden_{expr.label}_") + elif isinstance(expr, rasp.Aggregate): + sel_expr: rasp.Select = expr.selector + agg_expr: rasp.Aggregate = expr + + if not isinstance(sel_expr, rasp.Select): + raise TypeError("Compiling composite Selectors is not supported. " + f"Got a {sel_expr}.") + + queries = graph.nodes[sel_expr.queries.label] + keys = graph.nodes[sel_expr.keys.label] + sop = graph.nodes[agg_expr.sop.label] + + _check_selector_expression(expr, graph) + assert agg_expr.sop.label in graph.predecessors(node_id) + if rasp.get_encoding(agg_expr.sop) != rasp.get_encoding(agg_expr): + raise ValueError( + "sop encoding must match output encoding of the aggregate.") + if rasp.is_categorical(agg_expr) and agg_expr.default is not None: + raise ValueError("Default for a categorical aggregate must be None. " + f"Got {agg_expr.default}") + if rasp.is_numerical(agg_expr) and agg_expr.default != 0: + raise ValueError("Default for a numerical aggregate must be 0. " + f"Got {agg_expr.default}") + + bos_space = bases.VectorSpaceWithBasis([bos_dir]) + one_space = bases.VectorSpaceWithBasis([one_dir]) + query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]) + key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]) + value_space = bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS]) + output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) + + # Argument order is different in craft / transformers than RASP selectors + def attn_basis_fn(query: bases.BasisDirection, + key: bases.BasisDirection) -> bool: + # It's okay to use the local variable sel_expr because this function is + # only used within the same loop iteration to create an attention head. + # pylint: disable=cell-var-from-loop + selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate) + return selector_basis_fn(key, query) + + block = categorical_attn.categorical_attn( + query_space=query_space, + key_space=key_space, + value_space=value_space, + output_space=output_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=attn_basis_fn, + default_output=output_space.null_vector(), + causal=causal, + always_attend_to_bos=False, + use_bos_for_default_output=True, + softmax_coldness=100) + elif isinstance(expr, rasp.SelectorWidth): + sel_expr = expr.selector + queries = graph.nodes[sel_expr.queries.label] + keys = graph.nodes[sel_expr.keys.label] + _check_selector_expression(expr, graph) + + bos_space = bases.VectorSpaceWithBasis([bos_dir]) + query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]) + key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]) + output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) + + # Argument order is different in craft / transformers than RASP selectors + def attn_basis_fn(query: bases.BasisDirection, + key: bases.BasisDirection) -> bool: + # It's okay to use the local variable sel_expr because this function is + # only used within the same loop iteration to create an attention head. + selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate) # pylint: disable=cell-var-from-loop + return selector_basis_fn(key, query) + + block = selector_width.selector_width( + query_space=query_space, + key_space=key_space, + output_space=output_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=attn_basis_fn, + out_value_set=node[nodes.VALUE_SET], + categorical_output=rasp.is_categorical(expr), + causal=False, + softmax_coldness=100, + mlp_large_number=mlp_exactness, + label=expr.label) + else: + raise NotImplementedError(f"Expression {expr} cannot be translated to " + "a model component.") + + graph.nodes[node_id][nodes.MODEL_BLOCK] = block diff --git a/compiler/expr_to_craft_graph_test.py b/compiler/expr_to_craft_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..02bbe986cf05e003b93be8a96cac7f31766fecd6 --- /dev/null +++ b/compiler/expr_to_craft_graph_test.py @@ -0,0 +1,121 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for compiler.expr_to_craft_graph.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tracr.compiler import basis_inference +from tracr.compiler import expr_to_craft_graph +from tracr.compiler import lib +from tracr.compiler import nodes +from tracr.compiler import rasp_to_graph +from tracr.craft import bases +from tracr.craft import transformers +from tracr.rasp import rasp + + +class ExprToCraftGraphTest(parameterized.TestCase): + + def _check_block_types_are_correct(self, graph): + for _, node in graph.nodes.items(): + expr = node[nodes.EXPR] + if isinstance(expr, rasp.SOp): + block = node[nodes.MODEL_BLOCK] + if isinstance(expr, (rasp.Map, rasp.SequenceMap)): + self.assertIsInstance(block, transformers.MLP) + elif isinstance(expr, rasp.Aggregate): + self.assertIsInstance(block, transformers.AttentionHead) + + def _get_input_space_from_node(self, node): + block = node[nodes.MODEL_BLOCK] + if isinstance(block, transformers.MLP): + return block.fst.input_space + elif isinstance(block, transformers.AttentionHead): + return bases.join_vector_spaces(block.w_qk.left_space, + block.w_qk.right_space, + block.w_ov.input_space) + else: + return None + + def _check_spaces_are_consistent(self, graph): + """Check that for each edge the output is a subspace of the input.""" + for u, v in graph.edges: + u_node, v_node = graph.nodes[u], graph.nodes[v] + if isinstance(u_node[nodes.EXPR], rasp.SOp) and isinstance( + v_node[nodes.EXPR], rasp.SOp): + u_out_basis = u_node[nodes.OUTPUT_BASIS] + u_out_space = bases.VectorSpaceWithBasis(u_out_basis) + v_in_space = self._get_input_space_from_node(v_node) + self.assertTrue(u_out_space.issubspace(v_in_space)) + + @parameterized.named_parameters( + dict( + testcase_name="single_map", + program=rasp.Map(lambda x: x + 1, rasp.tokens)), + dict( + testcase_name="single_sequence_map", + program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, + rasp.indices)), + dict( + testcase_name="single_select_aggregate", + program=rasp.Aggregate( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), + rasp.tokens, + )), + dict(testcase_name="reverse", program=lib.make_reverse(rasp.tokens)), + dict(testcase_name="length", program=lib.make_length())) + def test_compiling_rasp_programs(self, program): + vocab = {0, 1, 2} + extracted = rasp_to_graph.extract_rasp_graph(program) + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + vocab, + max_seq_len=3, + ) + expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) + self._check_block_types_are_correct(extracted.graph) + self._check_spaces_are_consistent(extracted.graph) + + def test_add_craft_components_raises_value_error_if_called_before_basis_inference( + self): + program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) + extracted = rasp_to_graph.extract_rasp_graph(program) + + with self.assertRaisesRegex( + ValueError, + r"^.*Craft components can only be added after basis inference.*$"): + expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) + + def test_add_craft_components_raises_value_error_if_called_twice(self): + vocab = {0, 1, 2} + program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) + extracted = rasp_to_graph.extract_rasp_graph(program) + + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + vocab, + max_seq_len=1, + ) + + expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) + with self.assertRaisesRegex( + ValueError, r"^.*Input graph cannot have model blocks set already.*$"): + expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/lib.py b/compiler/lib.py new file mode 100644 index 0000000000000000000000000000000000000000..2b6c0d2c2e6bfc13237b630db4676e60e1e81024 --- /dev/null +++ b/compiler/lib.py @@ -0,0 +1,371 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RASP programs only using the subset of RASP supported by the compiler.""" + +from typing import Sequence + +from tracr.rasp import rasp + +### Programs that work only under non-causal evaluation. + + +def make_length() -> rasp.SOp: + """Creates the `length` SOp using selector width primitive. + + Example usage: + length = make_length() + length("abcdefg") + >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0] + + Returns: + length: SOp mapping an input to a sequence, where every element + is the length of that sequence. + """ + all_true_selector = rasp.Select( + rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector") + return rasp.SelectorWidth(all_true_selector).named("length") + + +length = make_length() + + +def make_reverse(sop: rasp.SOp) -> rasp.SOp: + """Create an SOp that reverses a sequence, using length primitive. + + Example usage: + reverse = make_reverse(rasp.tokens) + reverse("Hello") + >> ['o', 'l', 'l', 'e', 'H'] + + Args: + sop: an SOp + + Returns: + reverse : SOp that reverses the input sequence. + """ + opp_idx = (length - rasp.indices).named("opp_idx") + opp_idx = (opp_idx - 1).named("opp_idx-1") + reverse_selector = rasp.Select(rasp.indices, opp_idx, + rasp.Comparison.EQ).named("reverse_selector") + return rasp.Aggregate(reverse_selector, sop).named("reverse") + + +def make_pair_balance(sop: rasp.SOp, open_token: str, + close_token: str) -> rasp.SOp: + """Return fraction of previous open tokens minus the fraction of close tokens. + + (As implemented in the RASP paper.) + + If the outputs are always non-negative and end in 0, that implies the input + has balanced parentheses. + + Example usage: + num_l = make_pair_balance(rasp.tokens, "(", ")") + num_l("a()b(c))") + >> [0, 1/2, 0, 0, 1/5, 1/6, 0, -1/8] + + Args: + sop: Input SOp. + open_token: Token that counts positive. + close_token: Token that counts negative. + + Returns: + pair_balance: SOp mapping an input to a sequence, where every element + is the fraction of previous open tokens minus previous close tokens. + """ + bools_open = rasp.numerical(sop == open_token).named("bools_open") + opens = rasp.numerical(make_frac_prevs(bools_open)).named("opens") + + bools_close = rasp.numerical(sop == close_token).named("bools_close") + closes = rasp.numerical(make_frac_prevs(bools_close)).named("closes") + + pair_balance = rasp.numerical(rasp.LinearSequenceMap(opens, closes, 1, -1)) + return pair_balance.named("pair_balance") + + +def make_shuffle_dyck(pairs: list[str]) -> rasp.SOp: + """Returns 1 if a set of parentheses are balanced, 0 else. + + (As implemented in the RASP paper.) + + Example usage: + shuffle_dyck2 = make_shuffle_dyck(pairs=["()", "{}"]) + shuffle_dyck2("({)}") + >> [1, 1, 1, 1] + shuffle_dyck2("(){)}") + >> [0, 0, 0, 0, 0] + + Args: + pairs: List of pairs of open and close tokens that each should be balanced. + """ + assert len(pairs) >= 1 + + # Compute running balance of each type of parenthesis + balances = [] + for pair in pairs: + assert len(pair) == 2 + open_token, close_token = pair + balance = make_pair_balance( + rasp.tokens, open_token=open_token, + close_token=close_token).named(f"balance_{pair}") + balances.append(balance) + + # Check if balances where negative anywhere -> parentheses not balanced + any_negative = balances[0] < 0 + for balance in balances[1:]: + any_negative = any_negative | (balance < 0) + + # Convert to numerical SOp + any_negative = rasp.numerical(rasp.Map(lambda x: x, + any_negative)).named("any_negative") + + select_all = rasp.Select(rasp.indices, rasp.indices, + rasp.Comparison.TRUE).named("select_all") + has_neg = rasp.numerical(rasp.Aggregate(select_all, any_negative, + default=0)).named("has_neg") + + # Check if all balances are 0 at the end -> closed all parentheses + all_zero = balances[0] == 0 + for balance in balances[1:]: + all_zero = all_zero & (balance == 0) + + select_last = rasp.Select(rasp.indices, length - 1, + rasp.Comparison.EQ).named("select_last") + last_zero = rasp.Aggregate(select_last, all_zero).named("last_zero") + + not_has_neg = (~has_neg).named("not_has_neg") + return (last_zero & not_has_neg).named("shuffle_dyck") + + +def make_shuffle_dyck2() -> rasp.SOp: + return make_shuffle_dyck(pairs=["()", "{}"]).named("shuffle_dyck2") + + +def make_hist() -> rasp.SOp: + """Returns the number of times each token occurs in the input. + + (As implemented in the RASP paper.) + + Example usage: + hist = make_hist() + hist("abac") + >> [2, 1, 2, 1] + """ + same_tok = rasp.Select(rasp.tokens, rasp.tokens, + rasp.Comparison.EQ).named("same_tok") + return rasp.SelectorWidth(same_tok).named("hist") + + +def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp: + """Returns vals sorted by < relation on keys. + + Only supports unique keys. + + Example usage: + sort = make_sort(rasp.tokens, rasp.tokens) + sort([2, 4, 3, 1]) + >> [1, 2, 3, 4] + + Args: + vals: Values to sort. + keys: Keys for sorting. + """ + smaller = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller") + target_pos = rasp.SelectorWidth(smaller).named("target_pos") + sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ) + return rasp.Aggregate(sel_new, vals).named("sort") + + +def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int, + min_key: float) -> rasp.SOp: + """Returns vals sorted by < relation on keys, which don't need to be unique. + + The implementation differs from the RASP paper, as it avoids using + compositions of selectors to break ties. Instead, it uses the arguments + max_seq_len and min_key to ensure the keys are unique. + + Note that this approach only works for numerical keys. + + Example usage: + sort = make_sort(rasp.tokens, rasp.tokens, 5, 1) + sort([2, 4, 3, 1]) + >> [1, 2, 3, 4] + sort([2, 4, 1, 2]) + >> [1, 2, 2, 4] + + Args: + vals: Values to sort. + keys: Keys for sorting. + max_seq_len: Maximum sequence length (used to ensure keys are unique) + min_key: Minimum key value (used to ensure keys are unique) + + Returns: + Output SOp of sort program. + """ + keys = rasp.SequenceMap(lambda x, i: x + min_key * i / max_seq_len, keys, + rasp.indices) + return make_sort_unique(vals, keys) + + +def make_sort_freq(max_seq_len: int) -> rasp.SOp: + """Returns tokens sorted by the frequency they appear in the input. + + Tokens the appear the same amount of times are output in the same order as in + the input. + + Example usage: + sort = make_sort_freq(rasp.tokens, rasp.tokens, 5) + sort([2, 4, 2, 1]) + >> [2, 2, 4, 1] + + Args: + max_seq_len: Maximum sequence length (used to ensure keys are unique) + """ + hist = -1 * make_hist().named("hist") + return make_sort( + rasp.tokens, hist, max_seq_len=max_seq_len, min_key=1).named("sort_freq") + + +### Programs that work under both causal and regular evaluation. + + +def make_frac_prevs(bools: rasp.SOp) -> rasp.SOp: + """Count the fraction of previous tokens where a specific condition was True. + + (As implemented in the RASP paper.) + + Example usage: + num_l = make_frac_prevs(rasp.tokens=="l") + num_l("hello") + >> [0, 0, 1/3, 1/2, 2/5] + + Args: + bools: SOp mapping a sequence to a sequence of booleans. + + Returns: + frac_prevs: SOp mapping an input to a sequence, where every element + is the fraction of previous "True" tokens. + """ + bools = rasp.numerical(bools) + prevs = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ) + return rasp.numerical(rasp.Aggregate(prevs, bools, + default=0)).named("frac_prevs") + + +def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp: + """Returns the sop, shifted by `offset`, None-padded.""" + select_off_by_offset = rasp.Select(rasp.indices, rasp.indices, + lambda k, q: q == k + offset) + out = rasp.Aggregate(select_off_by_offset, sop, default=None) + return out.named(f"shift_by({offset})") + + +def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp: + """Returns an SOp which is True at the final element of the pattern. + + The first len(pattern) - 1 elements of the output SOp are None-padded. + + detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T] + + Args: + sop: the SOp in which to look for patterns. + pattern: a sequence of values to look for. + + Returns: + a sop which detects the pattern. + """ + + if len(pattern) < 1: + raise ValueError(f"Length of `pattern` must be at least 1. Got {pattern}") + + # detectors[i] will be a boolean-valued SOp which is true at position j iff + # the i'th (from the end) element of the pattern was detected at position j-i. + detectors = [] + for i, element in enumerate(reversed(pattern)): + detector = sop == element + if i != 0: + detector = shift_by(i, detector) + detectors.append(detector) + + # All that's left is to take the AND over all detectors. + pattern_detected = detectors.pop() + while detectors: + pattern_detected = pattern_detected & detectors.pop() + + return pattern_detected.named(f"detect_pattern({pattern})") + + +def make_count_less_freq(n: int) -> rasp.SOp: + """Returns how many tokens appear fewer than n times in the input. + + The output sequence contains this count in each position. + + Example usage: + count_less_freq = make_count_less_freq(2) + count_less_freq(["a", "a", "a", "b", "b", "c"]) + >> [3, 3, 3, 3, 3, 3] + count_less_freq(["a", "a", "c", "b", "b", "c"]) + >> [6, 6, 6, 6, 6, 6] + + Args: + n: Integer to compare token frequences to. + """ + hist = make_hist().named("hist") + select_less = rasp.Select(hist, hist, + lambda x, y: x <= n).named("select_less") + return rasp.SelectorWidth(select_less).named("count_less_freq") + + +def make_count(sop, token): + """Returns the count of `token` in `sop`. + + The output sequence contains this count in each position. + + Example usage: + count = make_count(tokens, "a") + count(["a", "a", "a", "b", "b", "c"]) + >> [3, 3, 3, 3, 3, 3] + count(["c", "a", "b", "c"]) + >> [1, 1, 1, 1] + + Args: + sop: Sop to count tokens in. + token: Token to count. + """ + return rasp.SelectorWidth(rasp.Select( + sop, sop, lambda k, q: k == token)).named(f"count_{token}") + + +def make_nary_sequencemap(f, *sops): + """Returns an SOp that simulates an n-ary SequenceMap. + + Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n + into a single SOp arguments that takes n-tuples as value. The n-ary sequence + map implementing f is then a Map on this resulting SOp. + + Note that the intermediate variables representing tuples of varying length + will be encoded categorically, and can become very high-dimensional. So, + using this function might lead to very large compiled models. + + Args: + f: Function with n arguments. + *sops: Sequence of SOps, one for each argument of f. + """ + values, *sops = sops + for sop in sops: + # x is a single entry in the first iteration but a tuple in later iterations + values = rasp.SequenceMap( + lambda x, y: (*x, y) if isinstance(x, tuple) else (x, y), values, sop) + return rasp.Map(lambda args: f(*args), values) diff --git a/compiler/lib_test.py b/compiler/lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e2484b6f735cbd33f9ad79cd2e05906b7b73b1f4 --- /dev/null +++ b/compiler/lib_test.py @@ -0,0 +1,40 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for compiler.lib.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tracr.compiler import test_cases +from tracr.rasp import causal_eval +from tracr.rasp import rasp + + +class LibTest(parameterized.TestCase): + + @parameterized.named_parameters(*test_cases.TEST_CASES) + def test_program_produces_expected_output(self, program, test_input, + expected_output, **kwargs): + del kwargs + self.assertEqual(rasp.evaluate(program, test_input), expected_output) + + @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES) + def test_causal_program_produces_expected_output(self, program, test_input, + expected_output, **kwargs): + del kwargs + self.assertEqual(causal_eval.evaluate(program, test_input), expected_output) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/nodes.py b/compiler/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c8fab35feae3222f44e8afcec60fd70c6095fa --- /dev/null +++ b/compiler/nodes.py @@ -0,0 +1,32 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Documents the data stored in nodes after each compiler pass.""" + +from typing import Any + +Node = dict[str, Any] +NodeID = str + +# RASP -> Graph +ID = "ID" # unique ID of the node +EXPR = "EXPR" # the RASPExpr of the node + +# Basis inference +# Note that only S-Op expressions will have these keys set. +VALUE_SET = "VALUE_SET" # possible values taken on by this SOp. +OUTPUT_BASIS = "OUTPUT_BASIS" # the corresponding named basis. + +# RASP Graph -> Craft Graph +MODEL_BLOCK = "MODEL_BLOCK" # craft block representing a RASPExpr diff --git a/compiler/rasp_to_craft_integration_test.py b/compiler/rasp_to_craft_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4168796babd1c6196c5ad64dde23778390687ef --- /dev/null +++ b/compiler/rasp_to_craft_integration_test.py @@ -0,0 +1,254 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration tests for the RASP -> craft stages of the compiler.""" + +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from tracr.compiler import basis_inference +from tracr.compiler import craft_graph_to_model +from tracr.compiler import expr_to_craft_graph +from tracr.compiler import nodes +from tracr.compiler import rasp_to_graph +from tracr.compiler import test_cases +from tracr.craft import bases +from tracr.craft import tests_common +from tracr.rasp import rasp + +_BOS_DIRECTION = "rasp_to_transformer_integration_test_BOS" +_ONE_DIRECTION = "rasp_to_craft_integration_test_ONE" + + +def _make_input_space(vocab, max_seq_len): + tokens_space = bases.VectorSpaceWithBasis.from_values("tokens", vocab) + indices_space = bases.VectorSpaceWithBasis.from_values( + "indices", range(max_seq_len)) + one_space = bases.VectorSpaceWithBasis.from_names([_ONE_DIRECTION]) + bos_space = bases.VectorSpaceWithBasis.from_names([_BOS_DIRECTION]) + input_space = bases.join_vector_spaces(tokens_space, indices_space, one_space, + bos_space) + + return input_space + + +def _embed_input(input_seq, input_space): + bos_vec = input_space.vector_from_basis_direction( + bases.BasisDirection(_BOS_DIRECTION)) + one_vec = input_space.vector_from_basis_direction( + bases.BasisDirection(_ONE_DIRECTION)) + embedded_input = [bos_vec + one_vec] + for i, val in enumerate(input_seq): + i_vec = input_space.vector_from_basis_direction( + bases.BasisDirection("indices", i)) + val_vec = input_space.vector_from_basis_direction( + bases.BasisDirection("tokens", val)) + embedded_input.append(i_vec + val_vec + one_vec) + return bases.VectorInBasis.stack(embedded_input) + + +def _embed_output(output_seq, output_space, categorical_output): + embedded_output = [] + output_label = output_space.basis[0].name + for x in output_seq: + if x is None: + out_vec = output_space.null_vector() + elif categorical_output: + out_vec = output_space.vector_from_basis_direction( + bases.BasisDirection(output_label, x)) + else: + out_vec = x * output_space.vector_from_basis_direction( + output_space.basis[0]) + embedded_output.append(out_vec) + return bases.VectorInBasis.stack(embedded_output) + + +class CompilerIntegrationTest(tests_common.VectorFnTestCase): + + @parameterized.named_parameters( + dict( + testcase_name="map", + program=rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))), + dict( + testcase_name="sequence_map", + program=rasp.categorical( + rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.indices))), + dict( + testcase_name="sequence_map_with_same_input", + program=rasp.categorical( + rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens))), + dict( + testcase_name="select_aggregate", + program=rasp.Aggregate( + rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), + rasp.Map(lambda x: 1, rasp.tokens)))) + def test_rasp_program_and_craft_model_produce_same_output(self, program): + vocab = {0, 1, 2} + max_seq_len = 3 + + extracted = rasp_to_graph.extract_rasp_graph(program) + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + vocab, + max_seq_len=max_seq_len, + ) + expr_to_craft_graph.add_craft_components_to_rasp_graph( + extracted.graph, + bos_dir=bases.BasisDirection(_BOS_DIRECTION), + one_dir=bases.BasisDirection(_ONE_DIRECTION), + ) + model = craft_graph_to_model.craft_graph_to_model(extracted.graph, + extracted.sources) + input_space = _make_input_space(vocab, max_seq_len) + output_space = bases.VectorSpaceWithBasis( + extracted.sink[nodes.OUTPUT_BASIS]) + + for val in vocab: + test_input = _embed_input([val], input_space) + rasp_output = program([val]) + expected_output = _embed_output( + output_seq=rasp_output, + output_space=output_space, + categorical_output=True) + test_output = model.apply(test_input).project(output_space) + self.assertVectorAllClose( + tests_common.strip_bos_token(test_output), expected_output) + + @parameterized.named_parameters(*test_cases.TEST_CASES) + def test_compiled_models_produce_expected_output(self, program, vocab, + test_input, expected_output, + max_seq_len, **kwargs): + del kwargs + categorical_output = rasp.is_categorical(program) + + extracted = rasp_to_graph.extract_rasp_graph(program) + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + vocab, + max_seq_len=max_seq_len, + ) + expr_to_craft_graph.add_craft_components_to_rasp_graph( + extracted.graph, + bos_dir=bases.BasisDirection(_BOS_DIRECTION), + one_dir=bases.BasisDirection(_ONE_DIRECTION), + ) + model = craft_graph_to_model.craft_graph_to_model(extracted.graph, + extracted.sources) + input_space = _make_input_space(vocab, max_seq_len) + output_space = bases.VectorSpaceWithBasis( + extracted.sink[nodes.OUTPUT_BASIS]) + if not categorical_output: + self.assertLen(output_space.basis, 1) + + test_input_vector = _embed_input(test_input, input_space) + expected_output_vector = _embed_output( + output_seq=expected_output, + output_space=output_space, + categorical_output=categorical_output) + test_output = model.apply(test_input_vector).project(output_space) + self.assertVectorAllClose( + tests_common.strip_bos_token(test_output), expected_output_vector) + + @unittest.expectedFailure + def test_setting_default_values_can_lead_to_wrong_outputs_in_compiled_model( + self, program): + # This is an example program in which setting a default value for aggregate + # writes a value to the bos token position, which interfers with a later + # aggregate operation causing the compiled model to have the wrong output. + + vocab = {"a", "b"} + test_input = ["a"] + max_seq_len = 2 + + # RASP: [False, True] + # compiled: [False, False, True] + not_a = rasp.Map(lambda x: x != "a", rasp.tokens) + + # RASP: + # [[True, False], + # [False, False]] + # compiled: + # [[False,True, False], + # [True, False, False]] + sel1 = rasp.Select(rasp.tokens, rasp.tokens, + lambda k, q: k == "a" and q == "a") + + # RASP: [False, True] + # compiled: [True, False, True] + agg1 = rasp.Aggregate(sel1, not_a, default=True) + + # RASP: + # [[False, True] + # [True, True]] + # compiled: + # [[True, False, False] + # [True, False, False]] + # because pre-softmax we get + # [[1.5, 1, 1] + # [1.5, 1, 1]] + # instead of + # [[0.5, 1, 1] + # [0.5, 1, 1]] + # Because agg1 = True is stored on the BOS token position + sel2 = rasp.Select(agg1, agg1, lambda k, q: k or q) + + # RASP: [1, 0.5] + # compiled + # [1, 1, 1] + program = rasp.numerical( + rasp.Aggregate(sel2, rasp.numerical(not_a), default=1)) + expected_output = [1, 0.5] + + # RASP program gives the correct output + program_output = program(test_input) + np.testing.assert_allclose(program_output, expected_output) + + extracted = rasp_to_graph.extract_rasp_graph(program) + basis_inference.infer_bases( + extracted.graph, + extracted.sink, + vocab, + max_seq_len=max_seq_len, + ) + expr_to_craft_graph.add_craft_components_to_rasp_graph( + extracted.graph, + bos_dir=bases.BasisDirection(_BOS_DIRECTION), + one_dir=bases.BasisDirection(_ONE_DIRECTION), + ) + model = craft_graph_to_model.craft_graph_to_model(extracted.graph, + extracted.sources) + + input_space = _make_input_space(vocab, max_seq_len) + output_space = bases.VectorSpaceWithBasis( + extracted.sink[nodes.OUTPUT_BASIS]) + + test_input_vector = _embed_input(test_input, input_space) + expected_output_vector = _embed_output( + output_seq=expected_output, + output_space=output_space, + categorical_output=True) + compiled_model_output = model.apply(test_input_vector).project(output_space) + + # Compiled craft model gives correct output + self.assertVectorAllClose( + tests_common.strip_bos_token(compiled_model_output), + expected_output_vector) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/rasp_to_graph.py b/compiler/rasp_to_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..487d467483e63553cecb40d7c06847db798cd740 --- /dev/null +++ b/compiler/rasp_to_graph.py @@ -0,0 +1,67 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converting a RaspExpr to a graph.""" + +import dataclasses +import queue + +import networkx as nx +from tracr.compiler import nodes +from tracr.rasp import rasp + +Node = nodes.Node +NodeID = nodes.NodeID + + +@dataclasses.dataclass +class ExtractRaspGraphOutput: + graph: nx.DiGraph + sink: Node # the program's output. + sources: list[Node] # the primitive S-Ops. + + +def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput: + """Converts a RASP program into a graph representation.""" + expr_queue = queue.Queue() + graph = nx.DiGraph() + sources: list[NodeID] = [] + + def ensure_node(expr: rasp.RASPExpr) -> NodeID: + """Finds or creates a graph node corresponding to expr; returns its ID.""" + node_id = expr.label + if node_id not in graph: + graph.add_node(node_id, **{nodes.ID: node_id, nodes.EXPR: expr}) + + return node_id + + # Breadth-first search over the RASP expression graph. + + def visit_raspexpr(expr: rasp.RASPExpr): + parent_id = ensure_node(expr) + + for child_expr in expr.children: + expr_queue.put(child_expr) + child_id = ensure_node(child_expr) + graph.add_edge(child_id, parent_id) + + if not expr.children: + sources.append(graph.nodes[parent_id]) + + expr_queue.put(tip) + sink = graph.nodes[ensure_node(tip)] + while not expr_queue.empty(): + visit_raspexpr(expr_queue.get()) + + return ExtractRaspGraphOutput(graph=graph, sink=sink, sources=sources) diff --git a/compiler/rasp_to_graph_test.py b/compiler/rasp_to_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ef68bb23d18d56ec8de54daa7348c09acbee475b --- /dev/null +++ b/compiler/rasp_to_graph_test.py @@ -0,0 +1,71 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for compiler.rasp_to_graph.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tracr.compiler import nodes +from tracr.compiler import rasp_to_graph +from tracr.rasp import rasp + + +class ExtractRaspGraphTest(parameterized.TestCase): + + def test_primitives_have_no_edges(self): + tokens_graph = rasp_to_graph.extract_rasp_graph(rasp.tokens).graph + self.assertEmpty(tokens_graph.edges) + + indices_graph = rasp_to_graph.extract_rasp_graph(rasp.indices).graph + self.assertEmpty(indices_graph.edges) + + full_graph = rasp_to_graph.extract_rasp_graph(rasp.Full(1)).graph + self.assertEmpty(full_graph.edges) + + def test_one_edge(self): + program = rasp.Map(lambda x: x + 1, rasp.tokens) + + graph = rasp_to_graph.extract_rasp_graph(program).graph + + self.assertLen(graph.edges, 1) + (u, v), = graph.edges + self.assertEqual(graph.nodes[u][nodes.EXPR], rasp.tokens) + self.assertEqual(graph.nodes[v][nodes.EXPR], program) + + def test_aggregate(self): + program = rasp.Aggregate( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), + rasp.indices, + ) + + extracted = rasp_to_graph.extract_rasp_graph(program) + + # Expected graph: + # + # indices \ -------- + # \ \ + # select -- program + # tokens / + + self.assertLen(extracted.graph.edges, 4) + self.assertEqual(extracted.sink[nodes.EXPR], program) + for source in extracted.sources: + self.assertIn( + source[nodes.EXPR], + [rasp.tokens, rasp.indices], + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/rasp_to_transformer_integration_test.py b/compiler/rasp_to_transformer_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..39287b95469336646e28c410853c2c46e02ab318 --- /dev/null +++ b/compiler/rasp_to_transformer_integration_test.py @@ -0,0 +1,214 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration tests for the full RASP -> transformer compilation.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import numpy as np + +from tracr.compiler import compiling +from tracr.compiler import lib +from tracr.compiler import test_cases +from tracr.craft import tests_common +from tracr.rasp import rasp + +_COMPILER_BOS = "rasp_to_transformer_integration_test_BOS" +_COMPILER_PAD = "rasp_to_transformer_integration_test_PAD" + +# Force float32 precision on TPU, which otherwise defaults to float16. +jax.config.update("jax_default_matmul_precision", "float32") + + +class CompilerIntegrationTest(tests_common.VectorFnTestCase): + + def assertSequenceEqualWhenExpectedIsNotNone(self, actual_seq, expected_seq): + for actual, expected in zip(actual_seq, expected_seq): + if expected is not None and actual != expected: + self.fail(f"{actual_seq} does not match (ignoring Nones) " + f"{expected_seq=}") + + @parameterized.named_parameters( + dict( + testcase_name="map", + program=rasp.Map(lambda x: x + 1, rasp.tokens)), + dict( + testcase_name="sequence_map", + program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, + rasp.indices)), + dict( + testcase_name="sequence_map_with_same_input", + program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, + rasp.indices)), + dict( + testcase_name="select_aggregate", + program=rasp.Aggregate( + rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), + rasp.Map(lambda x: 1, rasp.tokens)))) + def test_rasp_program_and_transformer_produce_same_output(self, program): + vocab = {0, 1, 2} + max_seq_len = 3 + assembled_model = compiling.compile_rasp_to_model( + program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS) + + test_outputs = {} + rasp_outputs = {} + for val in vocab: + test_outputs[val] = assembled_model.apply([_COMPILER_BOS, val]).decoded[1] + rasp_outputs[val] = program([val])[0] + + with self.subTest(val=0): + self.assertEqual(test_outputs[0], rasp_outputs[0]) + with self.subTest(val=1): + self.assertEqual(test_outputs[1], rasp_outputs[1]) + with self.subTest(val=2): + self.assertEqual(test_outputs[2], rasp_outputs[2]) + + @parameterized.named_parameters(*test_cases.TEST_CASES) + def test_compiled_models_produce_expected_output(self, program, vocab, + test_input, expected_output, + max_seq_len, **kwargs): + del kwargs + assembled_model = compiling.compile_rasp_to_model( + program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS) + test_output = assembled_model.apply([_COMPILER_BOS] + test_input) + + if isinstance(expected_output[0], (int, float)): + np.testing.assert_allclose( + test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005) + else: + self.assertSequenceEqualWhenExpectedIsNotNone(test_output.decoded[1:], + expected_output) + + @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES) + def test_compiled_causal_models_produce_expected_output( + self, program, vocab, test_input, expected_output, max_seq_len, **kwargs): + del kwargs + assembled_model = compiling.compile_rasp_to_model( + program, + vocab, + max_seq_len, + causal=True, + compiler_bos=_COMPILER_BOS, + compiler_pad=_COMPILER_PAD) + test_output = assembled_model.apply([_COMPILER_BOS] + test_input) + + if isinstance(expected_output[0], (int, float)): + np.testing.assert_allclose( + test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005) + else: + self.assertSequenceEqualWhenExpectedIsNotNone(test_output.decoded[1:], + expected_output) + + @parameterized.named_parameters( + dict( + testcase_name="reverse_1", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("abcd"), + expected_output=list("dcba"), + max_seq_len=5), + dict( + testcase_name="reverse_2", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("abc"), + expected_output=list("cba"), + max_seq_len=5), + dict( + testcase_name="reverse_3", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("ad"), + expected_output=list("da"), + max_seq_len=5), + dict( + testcase_name="reverse_4", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=["c"], + expected_output=["c"], + max_seq_len=5), + dict( + testcase_name="length_categorical_1", + program=rasp.categorical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("abc"), + expected_output=[3, 3, 3], + max_seq_len=5), + dict( + testcase_name="length_categorical_2", + program=rasp.categorical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("ad"), + expected_output=[2, 2], + max_seq_len=5), + dict( + testcase_name="length_categorical_3", + program=rasp.categorical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=["c"], + expected_output=[1], + max_seq_len=5), + dict( + testcase_name="length_numerical_1", + program=rasp.numerical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("abc"), + expected_output=[3, 3, 3], + max_seq_len=5), + dict( + testcase_name="length_numerical_2", + program=rasp.numerical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("ad"), + expected_output=[2, 2], + max_seq_len=5), + dict( + testcase_name="length_numerical_3", + program=rasp.numerical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=["c"], + expected_output=[1], + max_seq_len=5), + ) + def test_compiled_models_produce_expected_output_with_padding( + self, program, vocab, test_input, expected_output, max_seq_len, **kwargs): + del kwargs + assembled_model = compiling.compile_rasp_to_model( + program, + vocab, + max_seq_len, + compiler_bos=_COMPILER_BOS, + compiler_pad=_COMPILER_PAD) + + pad_len = (max_seq_len - len(test_input)) + test_input = test_input + [_COMPILER_PAD] * pad_len + test_input = [_COMPILER_BOS] + test_input + test_output = assembled_model.apply(test_input) + output = test_output.decoded + output_len = len(output) + output_stripped = test_output.decoded[1:output_len - pad_len] + + self.assertEqual(output[0], _COMPILER_BOS) + if isinstance(expected_output[0], (int, float)): + np.testing.assert_allclose( + output_stripped, expected_output, atol=1e-7, rtol=0.005) + else: + self.assertEqual(output_stripped, expected_output) + + +if __name__ == "__main__": + absltest.main() diff --git a/compiler/test_cases.py b/compiler/test_cases.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3ac28ec6fa7c1c112f3193155d324edfd1cd8b --- /dev/null +++ b/compiler/test_cases.py @@ -0,0 +1,357 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A set of RASP programs and input/output pairs used in integration tests.""" + +from tracr.compiler import lib +from tracr.rasp import rasp + +UNIVERSAL_TEST_CASES = [ + dict( + testcase_name="frac_prevs_1", + program=lib.make_frac_prevs(rasp.tokens == "l"), + vocab={"h", "e", "l", "o"}, + test_input=list("hello"), + expected_output=[0.0, 0.0, 1 / 3, 1 / 2, 2 / 5], + max_seq_len=5), + dict( + testcase_name="frac_prevs_2", + program=lib.make_frac_prevs(rasp.tokens == "("), + vocab={"a", "b", "c", "(", ")"}, + test_input=list("a()b(c))"), + expected_output=[0.0, 1 / 2, 1 / 3, 1 / 4, 2 / 5, 2 / 6, 2 / 7, 2 / 8], + max_seq_len=10), + dict( + testcase_name="frac_prevs_3", + program=lib.make_frac_prevs(rasp.tokens == ")"), + vocab={"a", "b", "c", "(", ")"}, + test_input=list("a()b(c))"), + expected_output=[0.0, 0.0, 1 / 3, 1 / 4, 1 / 5, 1 / 6, 2 / 7, 3 / 8], + max_seq_len=10, + ), + dict( + testcase_name="shift_by_one", + program=lib.shift_by(1, rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("abcd"), + expected_output=[None, "a", "b", "c"], + max_seq_len=5, + ), + dict( + testcase_name="shift_by_two", + program=lib.shift_by(2, rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("abcd"), + expected_output=[None, None, "a", "b"], + max_seq_len=5, + ), + dict( + testcase_name="detect_pattern_a", + program=lib.detect_pattern(rasp.tokens, "a"), + vocab={"a", "b", "c", "d"}, + test_input=list("bacd"), + expected_output=[False, True, False, False], + max_seq_len=5, + ), + dict( + testcase_name="detect_pattern_ab", + program=lib.detect_pattern(rasp.tokens, "ab"), + vocab={"a", "b"}, + test_input=list("aaba"), + expected_output=[None, False, True, False], + max_seq_len=5, + ), + dict( + testcase_name="detect_pattern_ab_2", + program=lib.detect_pattern(rasp.tokens, "ab"), + vocab={"a", "b"}, + test_input=list("abaa"), + expected_output=[None, True, False, False], + max_seq_len=5, + ), + dict( + testcase_name="detect_pattern_ab_3", + program=lib.detect_pattern(rasp.tokens, "ab"), + vocab={"a", "b"}, + test_input=list("aaaa"), + expected_output=[None, False, False, False], + max_seq_len=5, + ), + dict( + testcase_name="detect_pattern_abc", + program=lib.detect_pattern(rasp.tokens, "abc"), + vocab={"a", "b", "c"}, + test_input=list("abcabc"), + expected_output=[None, None, True, False, False, True], + max_seq_len=6, + ), +] + +TEST_CASES = UNIVERSAL_TEST_CASES + [ + dict( + testcase_name="reverse_1", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("abcd"), + expected_output=list("dcba"), + max_seq_len=5), + dict( + testcase_name="reverse_2", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("abc"), + expected_output=list("cba"), + max_seq_len=5), + dict( + testcase_name="reverse_3", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=list("ad"), + expected_output=list("da"), + max_seq_len=5), + dict( + testcase_name="reverse_4", + program=lib.make_reverse(rasp.tokens), + vocab={"a", "b", "c", "d"}, + test_input=["c"], + expected_output=["c"], + max_seq_len=5), + dict( + testcase_name="length_categorical_1", + program=rasp.categorical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("abc"), + expected_output=[3, 3, 3], + max_seq_len=3), + dict( + testcase_name="length_categorical_2", + program=rasp.categorical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("ad"), + expected_output=[2, 2], + max_seq_len=3), + dict( + testcase_name="length_categorical_3", + program=rasp.categorical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=["c"], + expected_output=[1], + max_seq_len=3), + dict( + testcase_name="length_numerical_1", + program=rasp.numerical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("abc"), + expected_output=[3, 3, 3], + max_seq_len=3), + dict( + testcase_name="length_numerical_2", + program=rasp.numerical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=list("ad"), + expected_output=[2, 2], + max_seq_len=3), + dict( + testcase_name="length_numerical_3", + program=rasp.numerical(lib.make_length()), + vocab={"a", "b", "c", "d"}, + test_input=["c"], + expected_output=[1], + max_seq_len=3), + dict( + testcase_name="pair_balance_1", + program=lib.make_pair_balance(rasp.tokens, "(", ")"), + vocab={"a", "b", "c", "(", ")"}, + test_input=list("a()b(c))"), + expected_output=[0.0, 1 / 2, 0.0, 0.0, 1 / 5, 1 / 6, 0.0, -1 / 8], + max_seq_len=10), + dict( + testcase_name="shuffle_dyck2_1", + program=lib.make_shuffle_dyck(pairs=["()", "{}"]), + vocab={"(", ")", "{", "}"}, + test_input=list("({)}"), + expected_output=[1, 1, 1, 1], + max_seq_len=5), + dict( + testcase_name="shuffle_dyck2_2", + program=lib.make_shuffle_dyck(pairs=["()", "{}"]), + vocab={"(", ")", "{", "}"}, + test_input=list("(){)}"), + expected_output=[0, 0, 0, 0, 0], + max_seq_len=5), + dict( + testcase_name="shuffle_dyck2_3", + program=lib.make_shuffle_dyck(pairs=["()", "{}"]), + vocab={"(", ")", "{", "}"}, + test_input=list("{}("), + expected_output=[0, 0, 0], + max_seq_len=5), + dict( + testcase_name="shuffle_dyck3_1", + program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]), + vocab={"(", ")", "{", "}", "[", "]"}, + test_input=list("({)[}]"), + expected_output=[1, 1, 1, 1, 1, 1], + max_seq_len=6), + dict( + testcase_name="shuffle_dyck3_2", + program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]), + vocab={"(", ")", "{", "}", "[", "]"}, + test_input=list("(){)}"), + expected_output=[0, 0, 0, 0, 0], + max_seq_len=6), + dict( + testcase_name="shuffle_dyck3_3", + program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]), + vocab={"(", ")", "{", "}", "[", "]"}, + test_input=list("{}[(]"), + expected_output=[0, 0, 0, 0, 0], + max_seq_len=6), + dict( + testcase_name="hist", + program=lib.make_hist(), + vocab={"a", "b", "c", "d"}, + test_input=list("abac"), + expected_output=[2, 1, 2, 1], + max_seq_len=5, + ), + dict( + testcase_name="sort_unique_1", + program=lib.make_sort_unique(vals=rasp.tokens, keys=rasp.tokens), + vocab={1, 2, 3, 4}, + test_input=[2, 4, 3, 1], + expected_output=[1, 2, 3, 4], + max_seq_len=5), + dict( + testcase_name="sort_unique_2", + program=lib.make_sort_unique(vals=rasp.tokens, keys=1 - rasp.indices), + vocab={"a", "b", "c", "d"}, + test_input=list("abcd"), + expected_output=["d", "c", "b", "a"], + max_seq_len=5), + dict( + testcase_name="sort_1", + program=lib.make_sort( + vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1), + vocab={1, 2, 3, 4}, + test_input=[2, 4, 3, 1], + expected_output=[1, 2, 3, 4], + max_seq_len=5), + dict( + testcase_name="sort_2", + program=lib.make_sort( + vals=rasp.tokens, keys=1 - rasp.indices, max_seq_len=5, min_key=1), + vocab={"a", "b", "c", "d"}, + test_input=list("abcd"), + expected_output=["d", "c", "b", "a"], + max_seq_len=5), + dict( + testcase_name="sort_3", + program=lib.make_sort( + vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1), + vocab={1, 2, 3, 4}, + test_input=[2, 4, 1, 2], + expected_output=[1, 2, 2, 4], + max_seq_len=5), + dict( + testcase_name="sort_freq", + program=lib.make_sort_freq(max_seq_len=5), + vocab={1, 2, 3, 4}, + test_input=[2, 4, 2, 1], + expected_output=[2, 2, 4, 1], + max_seq_len=5), + dict( + testcase_name="make_count_less_freq_categorical_1", + program=lib.make_count_less_freq(n=2), + vocab={"a", "b", "c", "d"}, + test_input=["a", "a", "a", "b", "b", "c"], + expected_output=[3, 3, 3, 3, 3, 3], + max_seq_len=6), + dict( + testcase_name="make_count_less_freq_categorical_2", + program=lib.make_count_less_freq(n=2), + vocab={"a", "b", "c", "d"}, + test_input=["a", "a", "c", "b", "b", "c"], + expected_output=[6, 6, 6, 6, 6, 6], + max_seq_len=6), + dict( + testcase_name="make_count_less_freq_numerical_1", + program=rasp.numerical(lib.make_count_less_freq(n=2)), + vocab={"a", "b", "c", "d"}, + test_input=["a", "a", "a", "b", "b", "c"], + expected_output=[3, 3, 3, 3, 3, 3], + max_seq_len=6), + dict( + testcase_name="make_count_less_freq_numerical_2", + program=rasp.numerical(lib.make_count_less_freq(n=2)), + vocab={"a", "b", "c", "d"}, + test_input=["a", "a", "c", "b", "b", "c"], + expected_output=[6, 6, 6, 6, 6, 6], + max_seq_len=6), + dict( + testcase_name="make_count_1", + program=lib.make_count(rasp.tokens, "a"), + vocab={"a", "b", "c"}, + test_input=["a", "a", "a", "b", "b", "c"], + expected_output=[3, 3, 3, 3, 3, 3], + max_seq_len=8, + ), + dict( + testcase_name="make_count_2", + program=lib.make_count(rasp.tokens, "a"), + vocab={"a", "b", "c"}, + test_input=["c", "a", "b", "c"], + expected_output=[1, 1, 1, 1], + max_seq_len=8, + ), + dict( + testcase_name="make_count_3", + program=lib.make_count(rasp.tokens, "a"), + vocab={"a", "b", "c"}, + test_input=["b", "b", "c"], + expected_output=[0, 0, 0], + max_seq_len=8, + ), + dict( + testcase_name="make_nary_sequencemap_1", + program=lib.make_nary_sequencemap( + lambda x, y, z: x + y - z, rasp.tokens, rasp.tokens, rasp.indices), + vocab={1, 2, 3}, + test_input=[1, 2, 3], + expected_output=[2, 3, 4], + max_seq_len=5, + ), + dict( + testcase_name="make_nary_sequencemap_2", + program=lib.make_nary_sequencemap( + lambda x, y, z: x * y / z, rasp.indices, rasp.indices, rasp.tokens), + vocab={1, 2, 3}, + test_input=[1, 2, 3], + expected_output=[0, 1 / 2, 4 / 3], + max_seq_len=3, + ) +] + +# make_nary_sequencemap(f, *sops) + +CAUSAL_TEST_CASES = UNIVERSAL_TEST_CASES + [ + dict( + testcase_name="selector_width", + program=rasp.SelectorWidth( + rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)), + vocab={"a", "b", "c", "d"}, + test_input=list("abcd"), + expected_output=[1, 2, 3, 4], + max_seq_len=5), +] diff --git a/craft/bases.py b/craft/bases.py new file mode 100644 index 0000000000000000000000000000000000000000..215636ac106cb4b75bb638f797f9b4090737eac8 --- /dev/null +++ b/craft/bases.py @@ -0,0 +1,247 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Vectors and bases.""" + +import dataclasses +from typing import Sequence, Union, Optional, Iterable + +import numpy as np + +Name = Union[int, str] +Value = Union[int, float, bool, str, tuple] + + +@dataclasses.dataclass(frozen=True) +class BasisDirection: + """Represents a basis direction (no magnitude) in a vector space. + + Attributes: + name: a unique name for this direction. + value: used to hold a value one-hot-encoded by this direction. e.g., + [BasisDirection("vs_1", True), BasisDirection("vs_1", False)] would be + basis directions of a subspace called "vs_1" which one-hot-encodes the + values True and False. If provided, considered part of the name for the + purpose of disambiguating directions. + """ + name: Name + value: Optional[Value] = None + + def __str__(self): + if self.value is None: + return str(self.name) + return f"{self.name}:{self.value}" + + def __lt__(self, other: "BasisDirection") -> bool: + try: + return (self.name, self.value) < (other.name, other.value) + except TypeError: + return str(self) < str(other) + + +@dataclasses.dataclass +class VectorInBasis: + """A vector (or array of vectors) in a given basis. + + When magnitudes are 1-d, this is a vector. + When magnitudes are (n+1)-d, this is an array of vectors, + where the -1th dimension is the basis dimension. + """ + basis_directions: Sequence[BasisDirection] + magnitudes: np.ndarray + + def __post_init__(self): + """Sort basis directions.""" + if len(self.basis_directions) != self.magnitudes.shape[-1]: + raise ValueError( + "Last dimension of magnitudes must be the same as number " + f"of basis directions. Was {len(self.basis_directions)} " + f"and {self.magnitudes.shape[-1]}.") + + sort_idx = np.argsort(self.basis_directions) + self.basis_directions = [self.basis_directions[i] for i in sort_idx] + self.magnitudes = np.take(self.magnitudes, sort_idx, -1) + + def __add__(self, other: "VectorInBasis") -> "VectorInBasis": + if self.basis_directions != other.basis_directions: + raise TypeError(f"Adding incompatible bases: {self} + {other}") + magnitudes = self.magnitudes + other.magnitudes + return VectorInBasis(self.basis_directions, magnitudes) + + def __radd__(self, other: "VectorInBasis") -> "VectorInBasis": + if self.basis_directions != other.basis_directions: + raise TypeError(f"Adding incompatible bases: {other} + {self}") + return self + other + + def __sub__(self, other: "VectorInBasis") -> "VectorInBasis": + if self.basis_directions != other.basis_directions: + raise TypeError(f"Subtracting incompatible bases: {self} - {other}") + magnitudes = self.magnitudes - other.magnitudes + return VectorInBasis(self.basis_directions, magnitudes) + + def __rsub__(self, other: "VectorInBasis") -> "VectorInBasis": + if self.basis_directions != other.basis_directions: + raise TypeError(f"Subtracting incompatible bases: {other} - {self}") + magnitudes = other.magnitudes - self.magnitudes + return VectorInBasis(self.basis_directions, magnitudes) + + def __mul__(self, scalar: float) -> "VectorInBasis": + return VectorInBasis(self.basis_directions, self.magnitudes * scalar) + + def __rmul__(self, scalar: float) -> "VectorInBasis": + return self * scalar + + def __truediv__(self, scalar: float) -> "VectorInBasis": + return VectorInBasis(self.basis_directions, self.magnitudes / scalar) + + def __neg__(self) -> "VectorInBasis": + return (-1) * self + + def __eq__(self, other: "VectorInBasis") -> bool: + return ((self.basis_directions == other.basis_directions) and + (self.magnitudes.shape == other.magnitudes.shape) and + (np.all(self.magnitudes == other.magnitudes))) + + @classmethod + def sum(cls, vectors: Sequence["VectorInBasis"]) -> "VectorInBasis": + return cls(vectors[0].basis_directions, + np.sum([x.magnitudes for x in vectors], axis=0)) + + @classmethod + def stack(cls, + vectors: Sequence["VectorInBasis"], + axis: int = 0) -> "VectorInBasis": + for v in vectors[1:]: + if v.basis_directions != vectors[0].basis_directions: + raise TypeError(f"Stacking incompatible bases: {vectors[0]} + {v}") + return cls(vectors[0].basis_directions, + np.stack([v.magnitudes for v in vectors], axis=axis)) + + def project( + self, basis: Union["VectorSpaceWithBasis", Sequence[BasisDirection]] + ) -> "VectorInBasis": + """Projects to the basis.""" + if isinstance(basis, VectorSpaceWithBasis): + basis = basis.basis + components = [] + for direction in basis: + if direction in self.basis_directions: + components.append( + self.magnitudes[..., self.basis_directions.index(direction)]) + else: + components.append(np.zeros_like(self.magnitudes[..., 0])) + return VectorInBasis(list(basis), np.stack(components, axis=-1)) + + +@dataclasses.dataclass +class VectorSpaceWithBasis: + """A vector subspace in a given basis.""" + basis: Sequence[BasisDirection] + + def __post_init__(self): + """Keep basis directions sorted.""" + self.basis = sorted(self.basis) + + @property + def num_dims(self) -> int: + return len(self.basis) + + def __contains__(self, item: Union[VectorInBasis, BasisDirection]) -> bool: + if isinstance(item, BasisDirection): + return item in self.basis + + return set(self.basis) == set(item.basis_directions) + + def issubspace(self, other: "VectorSpaceWithBasis") -> bool: + return set(self.basis).issubset(set(other.basis)) + + def basis_vectors(self) -> Sequence[VectorInBasis]: + basis_vector_magnitudes = list(np.eye(self.num_dims)) + return [VectorInBasis(self.basis, m) for m in basis_vector_magnitudes] + + def vector_from_basis_direction( + self, basis_direction: BasisDirection) -> VectorInBasis: + i = self.basis.index(basis_direction) + return VectorInBasis(self.basis, np.eye(self.num_dims)[i]) + + def null_vector(self) -> VectorInBasis: + return VectorInBasis(self.basis, np.zeros(self.num_dims)) + + @classmethod + def from_names(cls, names: Sequence[Name]) -> "VectorSpaceWithBasis": + """Creates a VectorSpace from a list of names for its basis directions.""" + return cls([BasisDirection(n) for n in names]) + + @classmethod + def from_values( + cls, + name: Name, + values: Iterable[Value], + ) -> "VectorSpaceWithBasis": + """Creates a VectorSpace from a list of values for its basis directions.""" + return cls([BasisDirection(name, v) for v in values]) + + +def direct_sum(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis: + """Create a direct sum of the vector spaces. + + Assumes the basis elements of all input vector spaces are + orthogonal to each other. Maintains the order of the bases. + + Args: + *vs: the vector spaces to sum. + + Returns: + the combined vector space. + + Raises: + Value error in case of overlapping bases. + """ + # Take the union of all the bases: + total_basis = sum([v.basis for v in vs], []) + + if len(total_basis) != len(set(total_basis)): + raise ValueError("Overlapping bases!") + + return VectorSpaceWithBasis(total_basis) + + +def join_vector_spaces(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis: + """Joins a set of vector spaces allowing them to overlap. + + Assumes the basis elements of all input vector spaces are + orthogonal to each other. Does not maintain the order of the bases but + sorts them. + + Args: + *vs: the vector spaces to sum. + + Returns: + the combined vector space. + """ + # Take the union of all the bases: + total_basis = list(set().union(*[set(v.basis) for v in vs])) + total_basis = sorted(total_basis) + return VectorSpaceWithBasis(total_basis) + + +def ensure_dims( + vs: VectorSpaceWithBasis, + num_dims: int, + name: str = "vector space", +) -> None: + """Raises ValueError if vs has the wrong number of dimensions.""" + if vs.num_dims != num_dims: + raise ValueError(f"{name} must have {num_dims=}, " + f"but got {vs.num_dims}: {vs.basis}") diff --git a/craft/bases_test.py b/craft/bases_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf6521861b98e7ee949cbdcc53ac63f0418f59f --- /dev/null +++ b/craft/bases_test.py @@ -0,0 +1,158 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bases.""" + +from absl.testing import absltest +import numpy as np +from tracr.craft import bases +from tracr.craft import tests_common + + +class VectorInBasisTest(tests_common.VectorFnTestCase): + + def test_shape_mismatch_raises_value_error(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + regex = (r"^.*Last dimension of magnitudes must be the same as number of " + r"basis directions.*$") + with self.assertRaisesRegex(ValueError, regex): + bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) + with self.assertRaisesRegex(ValueError, regex): + bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) + + def test_equal(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) + v1 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) + v2 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) + self.assertEqual(v1, v2) + self.assertEqual(v2, v1) + v3 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) + v4 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) + self.assertEqual(v3, v4) + self.assertEqual(v4, v3) + v5 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) + v6 = bases.VectorInBasis(vs1.basis, np.array([1, 1, 1, 1])) + self.assertNotEqual(v5, v6) + self.assertNotEqual(v6, v5) + v7 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) + v8 = bases.VectorInBasis(vs1.basis, np.array([[1, 2, 3, 4], [1, 1, 1, 1]])) + self.assertNotEqual(v7, v8) + self.assertNotEqual(v8, v7) + vs2 = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"]) + v9 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) + v10 = bases.VectorInBasis(vs2.basis, np.array([1, 2, 3, 4])) + self.assertNotEqual(v9, v10) + self.assertNotEqual(v10, v9) + + def test_dunders(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) + v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2])) + three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3])) + five = bases.VectorInBasis(vs1.basis, np.array([5, 5, 5])) + v_times_5 = bases.VectorInBasis(vs1.basis, np.array([0, 5, 10])) + self.assertEqual(5 * v, v_times_5) + self.assertEqual(v * 5, v_times_5) + self.assertEqual(5.0 * v, v_times_5) + self.assertEqual(v * 5.0, v_times_5) + v_by_2 = bases.VectorInBasis(vs1.basis, np.array([0, 0.5, 1])) + self.assertEqual(v / 2, v_by_2) + self.assertEqual(v / 2.0, v_by_2) + self.assertEqual(1 / 2 * v, v_by_2) + v_plus_3 = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5])) + self.assertEqual(v + three, v_plus_3) + self.assertEqual(three + v, v_plus_3) + v_minus_5 = bases.VectorInBasis(vs1.basis, np.array([-5, -4, -3])) + self.assertEqual(v - five, v_minus_5) + minus_v = bases.VectorInBasis(vs1.basis, np.array([0, -1, -2])) + self.assertEqual(-v, minus_v) + + +class ProjectionTest(tests_common.VectorFnTestCase): + + def test_direct_sum_produces_expected_result(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"]) + vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"]) + self.assertEqual(bases.direct_sum(vs1, vs2), vs3) + + def test_join_vector_spaces_produces_expected_result(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"]) + vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) + self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3) + + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["b", "d", "c"]) + vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) + self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3) + + def test_compare_vectors_with_differently_ordered_basis_vectors(self): + basis1 = ["a", "b", "c", "d"] + basis1 = [bases.BasisDirection(x) for x in basis1] + basis2 = ["b", "d", "a", "c"] + basis2 = [bases.BasisDirection(x) for x in basis2] + vs1 = bases.VectorSpaceWithBasis(basis1) + vs2 = bases.VectorSpaceWithBasis(basis2) + v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4])) + v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3])) + self.assertEqual(v1, v2) + self.assertEqual(v1 - v2, vs1.null_vector()) + self.assertEqual(v1 - v2, vs2.null_vector()) + self.assertEqual(v1 + v2, 2 * v2) + self.assertIn(v1, vs1) + self.assertIn(v1, vs2) + self.assertIn(v2, vs1) + self.assertIn(v2, vs2) + + def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self): + basis1 = ["a", "b", "c", "d"] + basis1 = [bases.BasisDirection(x) for x in basis1] + basis2 = ["b", "d", "a", "c"] + basis2 = [bases.BasisDirection(x) for x in basis2] + vs1 = bases.VectorSpaceWithBasis(basis1) + vs2 = bases.VectorSpaceWithBasis(basis2) + v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]])) + v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]])) + null_vec = bases.VectorInBasis.stack([vs1.null_vector(), vs2.null_vector()]) + self.assertEqual(v1, v2) + self.assertEqual(v1 - v2, null_vec) + self.assertEqual(v1 + v2, 2 * v2) + self.assertIn(v1, vs1) + self.assertIn(v1, vs2) + self.assertIn(v2, vs1) + self.assertIn(v2, vs2) + + def test_projection_to_larger_space(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) + a1, b1 = vs1.basis_vectors() + a2, b2, _, _ = vs2.basis_vectors() + + self.assertEqual(a1.project(vs2), a2) + self.assertEqual(b1.project(vs2), b2) + + def test_projection_to_smaller_space(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + a1, b1, c1, d1 = vs1.basis_vectors() + a2, b2 = vs2.basis_vectors() + + self.assertEqual(a1.project(vs2), a2) + self.assertEqual(b1.project(vs2), b2) + self.assertEqual(c1.project(vs2), vs2.null_vector()) + self.assertEqual(d1.project(vs2), vs2.null_vector()) + + +if __name__ == "__main__": + absltest.main() diff --git a/craft/chamber/categorical_attn.py b/craft/chamber/categorical_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..7c52920c2aa91cb8fc7cc266a9d40ba053543c65 --- /dev/null +++ b/craft/chamber/categorical_attn.py @@ -0,0 +1,167 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Attention head for categorical inputs.""" + +from typing import Optional, Protocol + +from tracr.craft import bases +from tracr.craft import transformers +from tracr.craft import vectorspace_fns + + +class QueryKeyToAttnLogit(Protocol): + + def __call__(self, query: bases.BasisDirection, + key: bases.BasisDirection) -> bool: + pass + + +def categorical_attn( + query_space: bases.VectorSpaceWithBasis, + key_space: bases.VectorSpaceWithBasis, + value_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + bos_space: bases.VectorSpaceWithBasis, + one_space: bases.VectorSpaceWithBasis, + attn_fn: QueryKeyToAttnLogit, + default_output: Optional[bases.VectorInBasis] = None, + causal: bool = False, + always_attend_to_bos: bool = False, + use_bos_for_default_output: bool = True, + softmax_coldness: float = 100., +) -> transformers.AttentionHead: + """Returns an attention head for categorical inputs. + + Assumes the existence of a beginning of sequence token and attends to it + always with strength 0.5*softmax_coldness. This allows to implement an + arbitrary default value for rows in the attention pattern that are all-zero. + + Attends to the BOS token if all other key-query pairs have zero attention. + Hence, the first value in the value sequence will be the default output for + such cases. + + Args: + query_space: Vector space containing (categorical) query input. + key_space: Vector space containing (categorical) key input. + value_space: Vector space containing (numerical) value input. + output_space: Vector space which will contain (numerical) output. + bos_space: 1-d space used to identify the beginning of sequence token. + one_space: 1-d space which contains 1 at every position. + attn_fn: A selector function f(query, key) operating on the query/key basis + directions that defines the attention pattern. + default_output: Output to return if attention pattern is all zero. + causal: If True, use masked attention. + always_attend_to_bos: If True, always attend to the BOS token. If False, + only attend to BOS when attending to nothing else. + use_bos_for_default_output: If True, assume BOS is not in the value space + and output a default value when attending to BOS. If False, assume BOS is + in the value space, and map it to the output space like any other token. + softmax_coldness: The inverse temperature of the softmax. Default value is + high which makes the attention close to a hard maximum. + """ + bases.ensure_dims(bos_space, num_dims=1, name="bos_space") + bases.ensure_dims(one_space, num_dims=1, name="one_space") + bos_direction = bos_space.basis[0] + one_direction = one_space.basis[0] + + # Add bos direction to query, key, and value spaces in case it is missing + query_space = bases.join_vector_spaces(query_space, bos_space, one_space) + key_space = bases.join_vector_spaces(key_space, bos_space) + value_space = bases.join_vector_spaces(value_space, bos_space) + + if always_attend_to_bos: + value_basis = value_space.basis + else: + value_basis = [v for v in value_space.basis if v != bos_direction] + assert len(value_basis) == output_space.num_dims + value_to_output = dict(zip(value_basis, output_space.basis)) + + if default_output is None: + default_output = output_space.null_vector() + assert default_output in output_space + + def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float: + + # We want to enforce the following property on our attention patterns: + # - if nothing else is attended to, attend to the BOS token. + # - otherwise, don't attend to the BOS token. + # + # We assume that the BOS position always only contains the vector bos + one, + # and that any other position has bos coefficient 0. + # + # We do this as follows: + # Let Q and K be subspaces of V containing the query and key vectors, + # both disjoint with the BOS space {bos} or the one space {one}. + # Suppose we have an attn_fn which defines a bilinear W_QK: V x V -> ℝ, + # s.t. W_QK(q, k) = 0 whenever either q or k are bos or one. + # + # Then define W_new: V x V -> ℝ st: + # W_new(one, bos) = 0.5, otherwise 0. + # + # Now set W_QK' = W_QK + W_new. + # + # To evaluate the attention to the BOS position: + # W_QK'(q, bos + one) + # = W_QK'(q, bos) + W_QK'(q, one) + # = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one) + # = 0 + 0 + W_new(q, bos) + W_new(q, one) + # = W_new(q, bos) + W_new(q, one) + # = W_new(q' + one, bos) + W_new(q' + one, one) where q = one + q' + # = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one) + # = 0 + 0.5 + 0 + 0 + # = 0.5 + # + # To evaluate the attention to a non-BOS position: + # W_QK'(0 * bos + q, 0 * bos + k) # s.t. q ∈ Q+{one}, k ∈ K+{one} + # = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k) + # = W_QK'(q, 0*bos + k) + # = 0*W_QK'(q, bos) + W_QK'(q, k) + # = W_QK'(q, k) + # = W_QK(q, k) since W_QK' = W_QK on inputs not containing bos. + # = W_QK(q', k') since W_QK(x, y) = 0 whenever x or y are one. + # + # Since W_QK(q, k) takes values in 0, 1, a sufficiently high softmax + # coldness will give us the desired property. QED + # + # The following implements this idea. + # By replacing 0.5 with 1, we can instead enforce a different property: that + # the BOS token is always attended to in addition to whatever else. + + if key == bos_direction and query == one_direction: + c = 1. if always_attend_to_bos else 0.5 + return c * softmax_coldness + elif {key, query}.intersection({one_direction, bos_direction}): + return 0 + + return softmax_coldness * attn_fn(query, key) + + w_qk = vectorspace_fns.ScalarBilinear.from_action( + query_space, + key_space, + qk_fun, + ) + + def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis: + if use_bos_for_default_output and input_dir == bos_direction: + return default_output + return output_space.vector_from_basis_direction(value_to_output[input_dir]) + + w_ov = vectorspace_fns.Linear.from_action( + value_space, + output_space, + ov_fun, + ) + + return transformers.AttentionHead(w_qk, w_ov, causal=causal) diff --git a/craft/chamber/categorical_attn_test.py b/craft/chamber/categorical_attn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..83320b4e770527b832c10dc14b3a0c9c3ae33b7c --- /dev/null +++ b/craft/chamber/categorical_attn_test.py @@ -0,0 +1,229 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for chamber.categorical_attn.""" + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from tracr.craft import bases +from tracr.craft import tests_common +from tracr.craft.chamber import categorical_attn + + +class CategoricalAttnTest(tests_common.VectorFnTestCase): + + @parameterized.parameters([ + dict(causal=False, input_seq=[1, 2, 3, 4, 5], result_seq=[3, 3, 3, 3, 3]), + dict( + causal=True, + input_seq=[1, 2, 3, 4, 5], + result_seq=[1, 1.5, 2, 2.5, 3]), + dict(causal=False, input_seq=[10], result_seq=[10]), + dict(causal=True, input_seq=[10], result_seq=[10]), + dict(causal=False, input_seq=[-1, 0, 1], result_seq=[0, 0, 0]), + dict(causal=True, input_seq=[-1, 0, 1], result_seq=[-1, -0.5, 0]), + ]) + def test_categorical_attn_can_implement_select_all(self, causal, input_seq, + result_seq): + vocab = range(-20, 20) + input_space = bases.VectorSpaceWithBasis.from_values("input", vocab) + + output_dir = bases.BasisDirection("output") + output_space = bases.VectorSpaceWithBasis([output_dir]) + output_vec = output_space.vector_from_basis_direction(output_dir) + + bos_dir = bases.BasisDirection("bos_dimension") + bos_space = bases.VectorSpaceWithBasis([bos_dir]) + + one_dir = bases.BasisDirection("one") + one_space = bases.VectorSpaceWithBasis([one_dir]) + + value_dir = bases.BasisDirection("value") + value_space = bases.VectorSpaceWithBasis([value_dir]) + + input_space = bases.join_vector_spaces(input_space, bos_space, one_space) + value_space = bases.join_vector_spaces(value_space, bos_space) + residual_space = bases.join_vector_spaces(input_space, value_space, + output_space) + one_vec = residual_space.vector_from_basis_direction(one_dir) + bos_vec = residual_space.vector_from_basis_direction(bos_dir) + value_vec = residual_space.vector_from_basis_direction(value_dir) + + attn = categorical_attn.categorical_attn( + key_space=input_space, + query_space=input_space, + value_space=value_space, + output_space=output_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=lambda x, y: True, + causal=causal) + + test_inputs = [bos_vec + one_vec] + for x in input_seq: + test_inputs.append( + residual_space.vector_from_basis_direction( + bases.BasisDirection("input", x)) + x * value_vec) + test_inputs = bases.VectorInBasis.stack(test_inputs) + + # Expect the average of all (previous) tokens + expected_results = [x * output_vec for x in result_seq] + expected_results = bases.VectorInBasis.stack(expected_results) + + test_outputs = attn.apply(test_inputs).project(output_space) + + self.assertVectorAllClose( + tests_common.strip_bos_token(test_outputs), expected_results) + + @parameterized.parameters([ + dict(causal=False, input_seq=[1, 2, 3, 4, 5], default=0), + dict(causal=True, input_seq=[1, 2, 3, 4, 5], default=1), + dict(causal=False, input_seq=[10], default=2), + dict(causal=True, input_seq=[10], default=-3), + dict(causal=False, input_seq=[-1, 0, 1], default=-2), + dict(causal=True, input_seq=[-1, 0, 1], default=-1), + ]) + def test_categorical_attn_can_implement_select_none(self, causal, input_seq, + default): + vocab = range(-20, 20) + input_space = bases.VectorSpaceWithBasis.from_values("input", vocab) + + output_dir = bases.BasisDirection("output") + output_space = bases.VectorSpaceWithBasis([output_dir]) + default_vec = default * output_space.vector_from_basis_direction(output_dir) + + bos_dir = bases.BasisDirection("bos_dimension") + bos_space = bases.VectorSpaceWithBasis([bos_dir]) + + one_dir = bases.BasisDirection("one") + one_space = bases.VectorSpaceWithBasis([one_dir]) + + value_dir = bases.BasisDirection("value") + value_space = bases.VectorSpaceWithBasis([value_dir]) + + input_space = bases.join_vector_spaces(input_space, bos_space, one_space) + value_space = bases.join_vector_spaces(value_space, bos_space) + residual_space = bases.join_vector_spaces(input_space, value_space, + output_space) + value_vec = residual_space.vector_from_basis_direction(value_dir) + bos_vec = residual_space.vector_from_basis_direction(bos_dir) + one_vec = residual_space.vector_from_basis_direction(one_dir) + + attn = categorical_attn.categorical_attn( + key_space=input_space, + query_space=input_space, + value_space=value_space, + output_space=output_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=lambda x, y: False, + default_output=default_vec, + causal=causal, + always_attend_to_bos=False, + use_bos_for_default_output=True) + + def make_input(x): + return (one_vec + x * value_vec + + residual_space.vector_from_basis_direction( + bases.BasisDirection("input", x))) + + test_inputs = bases.VectorInBasis.stack([bos_vec + one_vec] + + [make_input(x) for x in input_seq]) + + # Expect the default value + expected_results = [default_vec for x in input_seq] + expected_results = bases.VectorInBasis.stack(expected_results) + + test_outputs = attn.apply(test_inputs).project(output_space) + + self.assertVectorAllClose( + tests_common.strip_bos_token(test_outputs), expected_results) + + @parameterized.parameters([ + dict(num_counts=5, input_seq=[1, 4, 3, 2], n=1, result=[4, 3, 2, 1]), + dict(num_counts=10, input_seq=[5, 8, 9, 2], n=3, result=[2, 5, 8, 9]) + ]) + def test_categorical_attn_can_implement_shift_by_n(self, num_counts, + input_seq, n, result): + query_prefix = "prefix1" + key_prefix = "prefix2" + agg_input_prefix = "prefix3" + output_prefix = "prefix4" + + bos_direction = bases.BasisDirection("bos") + one_direction = bases.BasisDirection("one") + query_space = bases.VectorSpaceWithBasis.from_values( + query_prefix, range(num_counts)) + key_space = bases.VectorSpaceWithBasis.from_values(key_prefix, + range(num_counts)) + bos_space = bases.VectorSpaceWithBasis([bos_direction]) + one_space = bases.VectorSpaceWithBasis([one_direction]) + key_space = bases.join_vector_spaces(key_space, bos_space) + + agg_input_space = bases.VectorSpaceWithBasis.from_values( + agg_input_prefix, range(num_counts)) + agg_input_space = bases.join_vector_spaces(agg_input_space, bos_space) + output_space = bases.VectorSpaceWithBasis.from_values( + output_prefix, range(num_counts)) + + attn = categorical_attn.categorical_attn( + query_space=query_space, + key_space=key_space, + value_space=agg_input_space, + output_space=output_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=lambda q, k: q.value == k.value, + default_output=None, + always_attend_to_bos=False, + use_bos_for_default_output=True, + causal=False) + + residual_space = bases.join_vector_spaces(key_space, query_space, + agg_input_space, output_space, + one_space) + + seq_len = len(input_seq) + query_seq = np.arange(n, seq_len + n) % seq_len + key_seq = np.arange(seq_len) + + bos_vec = residual_space.vector_from_basis_direction(bos_direction) + one_vec = residual_space.vector_from_basis_direction(one_direction) + + test_inputs = [bos_vec + one_vec] + expected_results = [] + for i in range(seq_len): + test_inputs.append( + residual_space.vector_from_basis_direction( + bases.BasisDirection(query_prefix, query_seq[i])) + + residual_space.vector_from_basis_direction( + bases.BasisDirection(key_prefix, key_seq[i])) + + residual_space.vector_from_basis_direction( + bases.BasisDirection(agg_input_prefix, input_seq[i]))) + expected_results.append( + residual_space.vector_from_basis_direction( + bases.BasisDirection(output_prefix, result[i]))) + + test_inputs = bases.VectorInBasis.stack(test_inputs) + expected_results = bases.VectorInBasis.stack(expected_results) + + test_outputs = attn.apply(test_inputs) + + self.assertVectorAllClose( + tests_common.strip_bos_token(test_outputs), expected_results) + + +if __name__ == "__main__": + absltest.main() diff --git a/craft/chamber/categorical_mlp.py b/craft/chamber/categorical_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6b15278f2e1e62850ff5bb82dfcd977079d832a9 --- /dev/null +++ b/craft/chamber/categorical_mlp.py @@ -0,0 +1,168 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MLP to compute basic linear functions of one-hot encoded integers.""" + +from typing import Callable + +import numpy as np + +from tracr.craft import bases +from tracr.craft import transformers +from tracr.craft import vectorspace_fns + +_ONE_SPACE = bases.VectorSpaceWithBasis.from_names(["one"]) + + +def map_categorical_mlp( + input_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + operation: Callable[[bases.BasisDirection], bases.BasisDirection], +) -> transformers.MLP: + """Returns an MLP that encodes any categorical function of a single variable f(x). + + The hidden layer is the identity and output combines this with a lookup table + output_k = sum(f(i)*input_i for all i in input space) + + Args: + input_space: space containing the input x. + output_space: space containing possible outputs. + operation: A function operating on basis directions. + """ + + def operation_fn(direction): + if direction in input_space: + output_direction = operation(direction) + if output_direction in output_space: + return output_space.vector_from_basis_direction(output_direction) + return output_space.null_vector() + + first_layer = vectorspace_fns.Linear.from_action(input_space, output_space, + operation_fn) + + second_layer = vectorspace_fns.project(output_space, output_space) + + return transformers.MLP(first_layer, second_layer) + + +def map_categorical_to_numerical_mlp( + input_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + operation: Callable[[bases.Value], float], +) -> transformers.MLP: + """Returns an MLP to compute f(x) from a categorical to a numerical variable. + + The hidden layer is the identity and output combines this with a lookup table + output = sum(f(i)*input_i for all i in input space) + + Args: + input_space: Vector space containing the input x. + output_space: Vector space to write the numerical output to. + operation: A function operating on basis directions. + """ + bases.ensure_dims(output_space, num_dims=1, name="output_space") + out_vec = output_space.vector_from_basis_direction(output_space.basis[0]) + + def operation_fn(direction): + if direction in input_space: + return operation(direction.value) * out_vec + return output_space.null_vector() + + first_layer = vectorspace_fns.Linear.from_action(input_space, output_space, + operation_fn) + + second_layer = vectorspace_fns.project(output_space, output_space) + + return transformers.MLP(first_layer, second_layer) + + +def sequence_map_categorical_mlp( + input1_space: bases.VectorSpaceWithBasis, + input2_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + operation: Callable[[bases.BasisDirection, bases.BasisDirection], + bases.BasisDirection], + one_space: bases.VectorSpaceWithBasis = _ONE_SPACE, + hidden_name: bases.Name = "__hidden__", +) -> transformers.MLP: + """Returns an MLP that encodes a categorical function of two variables f(x, y). + + The hidden layer of the MLP computes the logical and of all input directions + hidden_i_j = ReLU(x_i+x_j-1) + + And the output combines this with a lookup table + output_k = sum(f(i, j)*hidden_i_j for all i,j in input space) + + Args: + input1_space: Vector space containing the input x. + input2_space: Vector space containing the input y. + output_space: Vector space to write outputs to. + operation: A function operating on basis directions. + one_space: a reserved 1-d space that always contains a 1. + hidden_name: Name for hidden dimensions. + """ + bases.ensure_dims(one_space, num_dims=1, name="one_space") + + if not set(input1_space.basis).isdisjoint(input2_space.basis): + raise ValueError("Input spaces to a SequenceMap must be disjoint. " + "If input spaces are the same, use Map instead!") + + input_space = bases.direct_sum(input1_space, input2_space, one_space) + + def to_hidden(x, y): + return bases.BasisDirection(hidden_name, (x.name, x.value, y.name, y.value)) + + def from_hidden(h): + x_name, x_value, y_name, y_value = h.value + x_dir = bases.BasisDirection(x_name, x_value) + y_dir = bases.BasisDirection(y_name, y_value) + return x_dir, y_dir + + hidden_dir = [] + for dir1 in input1_space.basis: + for dir2 in input2_space.basis: + hidden_dir.append(to_hidden(dir1, dir2)) + hidden_space = bases.VectorSpaceWithBasis(hidden_dir) + + def logical_and(direction): + if direction in one_space: + out = bases.VectorInBasis(hidden_space.basis, + -np.ones(hidden_space.num_dims)) + elif direction in input1_space: + dir1 = direction + out = hidden_space.null_vector() + for dir2 in input2_space.basis: + out += hidden_space.vector_from_basis_direction(to_hidden(dir1, dir2)) + else: + dir2 = direction + out = hidden_space.null_vector() + for dir1 in input1_space.basis: + out += hidden_space.vector_from_basis_direction(to_hidden(dir1, dir2)) + return out + + first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space, + logical_and) + + def operation_fn(direction): + dir1, dir2 = from_hidden(direction) + output_direction = operation(dir1, dir2) + if output_direction in output_space: + return output_space.vector_from_basis_direction(output_direction) + else: + return output_space.null_vector() + + second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space, + operation_fn) + + return transformers.MLP(first_layer, second_layer) diff --git a/craft/chamber/categorical_mlp_test.py b/craft/chamber/categorical_mlp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcbee399dcb00e0e935ca31aa7eade303a262d5 --- /dev/null +++ b/craft/chamber/categorical_mlp_test.py @@ -0,0 +1,164 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for chamber.categorical_mlp.""" + +import math +from absl.testing import absltest +from absl.testing import parameterized + +from tracr.craft import bases +from tracr.craft import tests_common +from tracr.craft.chamber import categorical_mlp + + +class CategoricalInputMlpTest(tests_common.VectorFnTestCase): + + @parameterized.parameters([ + dict(num_counts=4, x=1, y=2, fun=lambda x, y: x + y, result=3), + dict(num_counts=4, x=1, y=0, fun=lambda x, y: x + y + 1, result=2), + dict(num_counts=5, x=2, y=1, fun=math.pow, result=2), + dict(num_counts=5, x=2, y=2, fun=math.pow, result=4), + ]) + def test_seq_map_categorical_mlp_produces_expected_outcome( + self, num_counts, x, y, fun, result): + input1_name = "in1" + input2_name = "in2" + output_name = "out" + one_name = "one_dimension" + + in1_space = bases.VectorSpaceWithBasis.from_values(input1_name, + range(num_counts + 1)) + in2_space = bases.VectorSpaceWithBasis.from_values(input2_name, + range(num_counts + 1)) + out_space = bases.VectorSpaceWithBasis.from_values(output_name, + range(num_counts + 1)) + + def operation(in1, in2): + out_val = fun(int(in1.value), int(in2.value)) + return bases.BasisDirection(output_name, out_val) + + mlp = categorical_mlp.sequence_map_categorical_mlp( + input1_space=in1_space, + input2_space=in2_space, + output_space=out_space, + operation=operation, + one_space=bases.VectorSpaceWithBasis.from_names([one_name])) + + test_inputs = ( + mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection(one_name)) + + mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection(input1_name, x)) + + mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection(input2_name, y))) + + expected_results = mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection(output_name, result)) + + test_outputs = mlp.apply(test_inputs) + + self.assertVectorAllClose(test_outputs, expected_results) + + def test_seq_map_categorical_mlp_raises_error_with_overlapping_inputs(self): + input_name = "in" + output_name = "out" + one_name = "one_dimension" + + in1_space = bases.VectorSpaceWithBasis.from_values(input_name, range(5)) + in2_space = bases.VectorSpaceWithBasis.from_values(input_name, range(3, 10)) + out_space = bases.VectorSpaceWithBasis.from_values(output_name, range(5)) + + with self.assertRaisesRegex( + ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"): + categorical_mlp.sequence_map_categorical_mlp( + input1_space=in1_space, + input2_space=in1_space, + output_space=out_space, + operation=lambda x, y: bases.BasisDirection(output_name, 0), + one_space=bases.VectorSpaceWithBasis.from_names([one_name])) + + with self.assertRaisesRegex( + ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"): + categorical_mlp.sequence_map_categorical_mlp( + input1_space=in1_space, + input2_space=in2_space, + output_space=out_space, + operation=lambda x, y: bases.BasisDirection(output_name, 0), + one_space=bases.VectorSpaceWithBasis.from_names([one_name])) + + @parameterized.parameters([ + dict(num_counts=5, x=2, fun=lambda x: x, result=2), + dict(num_counts=5, x=2, fun=lambda x: math.pow(x, int(2)), result=4), + dict(num_counts=5, x=-2, fun=lambda x: math.pow(x, int(2)), result=4), + dict(num_counts=5, x=-1, fun=lambda x: math.pow(x, int(3)), result=-1), + ]) + def test_map_categorical_mlp_produces_expected_outcome_computing_powers( + self, num_counts, x, fun, result): + input_name = "in" + output_name = "out" + + in_space = bases.VectorSpaceWithBasis.from_values( + input_name, range(-num_counts, num_counts + 1)) + out_space = bases.VectorSpaceWithBasis.from_values( + output_name, range(-num_counts, num_counts + 1)) + + def operation(direction): + out_val = fun(int(direction.value)) + return bases.BasisDirection(output_name, out_val) + + mlp = categorical_mlp.map_categorical_mlp( + input_space=in_space, output_space=out_space, operation=operation) + + test_inputs = mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection(input_name, x)) + + expected_results = mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection(output_name, result)) + + test_outputs = mlp.apply(test_inputs) + + self.assertVectorAllClose(test_outputs, expected_results) + + @parameterized.parameters([ + dict(x=2, fun=lambda x: x, result=2), + dict(x=2, fun=lambda x: math.pow(x, int(2)), result=4), + dict(x=1, fun=lambda x: 1 / (x + 1), result=0.5), + dict(x=3, fun=lambda x: 1 / (x + 1), result=0.25), + ]) + def test_map_categorical_to_numerical_mlp_produces_expected_outcome( + self, x, fun, result): + + in_space = bases.VectorSpaceWithBasis.from_values("in", range(6)) + out_space = bases.VectorSpaceWithBasis.from_names(["out"]) + + mlp = categorical_mlp.map_categorical_to_numerical_mlp( + input_space=in_space, + output_space=out_space, + operation=fun, + ) + + test_inputs = mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection("in", x)) + + expected_results = result * mlp.residual_space.vector_from_basis_direction( + bases.BasisDirection("out")) + + test_outputs = mlp.apply(test_inputs) + + self.assertVectorAllClose(test_outputs, expected_results) + + +if __name__ == "__main__": + absltest.main() diff --git a/craft/chamber/numerical_mlp.py b/craft/chamber/numerical_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5a703c7ef9f2558ac8fb8df8fad1b99359b2da23 --- /dev/null +++ b/craft/chamber/numerical_mlp.py @@ -0,0 +1,334 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MLPs to compute arbitrary numerical functions by discretising.""" + +import dataclasses + +from typing import Callable, Iterable + +from tracr.craft import bases +from tracr.craft import transformers +from tracr.craft import vectorspace_fns +from tracr.utils import errors + + +@dataclasses.dataclass +class DiscretisingLayerMaterials: + """Provides components for a hidden layer that discretises the input. + + Attributes: + action: Function acting on basis directions that defines the computation. + hidden_space: Vector space of the hidden representation of the layer. + output_values: Set of output values that correspond to the discretisation. + """ + action: Callable[[bases.BasisDirection], bases.VectorInBasis] + hidden_space: bases.VectorSpaceWithBasis + output_values: list[float] + + +def _get_discretising_layer(input_value_set: Iterable[float], + f: Callable[[float], + float], hidden_name: bases.Name, + one_direction: bases.BasisDirection, + large_number: float) -> DiscretisingLayerMaterials: + """Creates a hidden layer that discretises the input of f(x) into a value set. + + The input is split up into a distinct region around each value in + `input_value_set`: + + elements of value set: v0 | v1 | v2 | v3 | v4 | ... + thresholds: t0 t1 t2 t3 t4 + + The hidden layer has two activations per threshold: + hidden_k_1 = ReLU(L * (x - threshold[k]) + 1) + hidden_k_2 = ReLU(L * (x - threshold[k])) + + Note that hidden_k_1 - hidden_k_2 is: + 1 if x >= threshold[k] + 1/L + 0 if x <= threshold[k] + between 0 and 1 if threshold[k] < x < threshold[k] + 1/L + + So as long as we choose L a big enough number, we have + hidden_k_1 - hidden_k_2 = 1 if x >= threshold[k]. + i.e. we know in which region the input value is. + + Args: + input_value_set: Set of discrete input values. + f: Function to approximate. + hidden_name: Name for hidden dimensions. + one_direction: Auxiliary dimension that must contain 1 in the input. + large_number: Large number L that determines accuracy of the computation. + + Returns: + DiscretisingLayerMaterials containing all components for the layer. + """ + output_values, sorted_values = [], [] + for x in sorted(input_value_set): + res = errors.ignoring_arithmetic_errors(f)(x) + if res is not None: + output_values.append(res) + sorted_values.append(x) + + num_vals = len(sorted_values) + value_thresholds = [ + (sorted_values[i] + sorted_values[i + 1]) / 2 for i in range(num_vals - 1) + ] + + hidden_directions = [bases.BasisDirection(f"{hidden_name}start")] + for k in range(1, num_vals): + dir0 = bases.BasisDirection(hidden_name, (k, 0)) + dir1 = bases.BasisDirection(hidden_name, (k, 1)) + hidden_directions.extend([dir0, dir1]) + hidden_space = bases.VectorSpaceWithBasis(hidden_directions) + + def action(direction: bases.BasisDirection) -> bases.VectorInBasis: + # hidden_k_0 = ReLU(L * (x - threshold[k]) + 1) + # hidden_k_1 = ReLU(L * (x - threshold[k])) + if direction == one_direction: + hidden = hidden_space.vector_from_basis_direction( + bases.BasisDirection(f"{hidden_name}start")) + else: + hidden = hidden_space.null_vector() + for k in range(1, num_vals): + vec0 = hidden_space.vector_from_basis_direction( + bases.BasisDirection(hidden_name, (k, 0))) + vec1 = hidden_space.vector_from_basis_direction( + bases.BasisDirection(hidden_name, (k, 1))) + if direction == one_direction: + hidden += (1 - large_number * value_thresholds[k - 1]) * vec0 + hidden -= large_number * value_thresholds[k - 1] * vec1 + else: + hidden += large_number * vec0 + large_number * vec1 + return hidden + + return DiscretisingLayerMaterials( + action=action, hidden_space=hidden_space, output_values=output_values) + + +def map_numerical_mlp( + f: Callable[[float], float], + input_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + input_value_set: Iterable[float], + one_space: bases.VectorSpaceWithBasis, + large_number: float = 100, + hidden_name: bases.Name = "__hidden__", +) -> transformers.MLP: + """Returns an MLP that encodes any function of a single variable f(x). + + This is implemented by discretising the input according to input_value_set + and defining thresholds that determine which part of the input range will + is allocated to which value in input_value_set. + + elements of value set: v0 | v1 | v2 | v3 | v4 | ... + thresholds: t0 t1 t2 t3 t4 + + The MLP computes two hidden activations per threshold: + hidden_k_0 = ReLU(L * (x - threshold[k]) + 1) + hidden_k_1 = ReLU(L * (x - threshold[k])) + + Note that hidden_k_1 - hidden_k_2 is: + 1 if x >= threshold[k] + 1/L + 0 if x <= threshold[k] + between 0 and 1 if threshold[k] < x < threshold[k] + 1/L + + So as long as we choose L a big enough number, we have + hidden_k_0 - hidden_k_1 = 1 if x >= threshold[k]. + + The MLP then computes the output as: + output = f(input[0]) + + sum((hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1])) + for all k=0,1,...) + + This sum will be (by a telescoping sums argument) + f(input[0]) if x <= threshold[0] + f(input[k]) if threshold[k-1] < x <= threshold[k] for some other k + f(input[-1]) if x > threshold[-1] + which approximates f() up to an accuracy given by input_value_set and L. + + Args: + f: Function to approximate. + input_space: 1-d vector space that encodes the input x. + output_space: 1-d vector space to write the output to. + input_value_set: Set of values the input can take. + one_space: Auxiliary 1-d vector space that must contain 1 in the input. + large_number: Large number L that determines accuracy of the computation. + Note that too large values of L can lead to numerical issues, particularly + during inference on GPU/TPU. + hidden_name: Name for hidden dimensions. + """ + bases.ensure_dims(input_space, num_dims=1, name="input_space") + bases.ensure_dims(output_space, num_dims=1, name="output_space") + bases.ensure_dims(one_space, num_dims=1, name="one_space") + + input_space = bases.join_vector_spaces(input_space, one_space) + out_vec = output_space.vector_from_basis_direction(output_space.basis[0]) + + discretising_layer = _get_discretising_layer( + input_value_set=input_value_set, + f=f, + hidden_name=hidden_name, + one_direction=one_space.basis[0], + large_number=large_number) + first_layer = vectorspace_fns.Linear.from_action( + input_space, discretising_layer.hidden_space, discretising_layer.action) + + def second_layer_action( + direction: bases.BasisDirection) -> bases.VectorInBasis: + # output = sum( + # (hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1])) + # for all k) + if direction.name == f"{hidden_name}start": + return discretising_layer.output_values[0] * out_vec + k, i = direction.value + # add hidden_k_0 and subtract hidden_k_1 + sign = {0: 1, 1: -1}[i] + return sign * (discretising_layer.output_values[k] - + discretising_layer.output_values[k - 1]) * out_vec + + second_layer = vectorspace_fns.Linear.from_action( + discretising_layer.hidden_space, output_space, second_layer_action) + + return transformers.MLP(first_layer, second_layer) + + +def map_numerical_to_categorical_mlp( + f: Callable[[float], float], + input_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + input_value_set: Iterable[float], + one_space: bases.VectorSpaceWithBasis, + large_number: float = 100, + hidden_name: bases.Name = "__hidden__", +) -> transformers.MLP: + """Returns an MLP to compute f(x) from a numerical to a categorical variable. + + Uses a set of possible output values, and rounds f(x) to the closest value + in this set to create a categorical output variable. + + The output is discretised the same way as in `map_numerical_mlp`. + + Args: + f: Function to approximate. + input_space: 1-d vector space that encodes the input x. + output_space: n-d vector space to write categorical output to. The output + directions need to encode the possible output values. + input_value_set: Set of values the input can take. + one_space: Auxiliary 1-d space that must contain 1 in the input. + large_number: Large number L that determines accuracy of the computation. + hidden_name: Name for hidden dimensions. + """ + bases.ensure_dims(input_space, num_dims=1, name="input_space") + bases.ensure_dims(one_space, num_dims=1, name="one_space") + + input_space = bases.join_vector_spaces(input_space, one_space) + + vec_by_out_val = dict() + for d in output_space.basis: + # TODO(b/255937603): Do a similar assert in other places where we expect + # categorical basis directions to encode values. + assert d.value is not None, ("output directions need to encode " + "possible output values") + vec_by_out_val[d.value] = output_space.vector_from_basis_direction(d) + + discretising_layer = _get_discretising_layer( + input_value_set=input_value_set, + f=f, + hidden_name=hidden_name, + one_direction=one_space.basis[0], + large_number=large_number) + + assert set(discretising_layer.output_values).issubset( + set(vec_by_out_val.keys())) + + first_layer = vectorspace_fns.Linear.from_action( + input_space, discretising_layer.hidden_space, discretising_layer.action) + + def second_layer_action( + direction: bases.BasisDirection) -> bases.VectorInBasis: + """Computes output value and returns corresponding output direction.""" + if direction.name == f"{hidden_name}start": + return vec_by_out_val[discretising_layer.output_values[0]] + else: + k, i = direction.value + # add hidden_k_0 and subtract hidden_k_1 + sign = {0: 1, 1: -1}[i] + out_k = discretising_layer.output_values[k] + out_k_m_1 = discretising_layer.output_values[k - 1] + return sign * (vec_by_out_val[out_k] - vec_by_out_val[out_k_m_1]) + + second_layer = vectorspace_fns.Linear.from_action( + discretising_layer.hidden_space, output_space, second_layer_action) + + return transformers.MLP(first_layer, second_layer) + + +def linear_sequence_map_numerical_mlp( + input1_basis_direction: bases.BasisDirection, + input2_basis_direction: bases.BasisDirection, + output_basis_direction: bases.BasisDirection, + input1_factor: float, + input2_factor: float, + hidden_name: bases.Name = "__hidden__", +) -> transformers.MLP: + """Returns an MLP that encodes a linear function f(x, y) = a*x + b*y. + + Args: + input1_basis_direction: Basis direction that encodes the input x. + input2_basis_direction: Basis direction that encodes the input y. + output_basis_direction: Basis direction to write the output to. + input1_factor: Linear factor a for input x. + input2_factor: Linear factor a for input y. + hidden_name: Name for hidden dimensions. + """ + input_space = bases.VectorSpaceWithBasis( + [input1_basis_direction, input2_basis_direction]) + output_space = bases.VectorSpaceWithBasis([output_basis_direction]) + out_vec = output_space.vector_from_basis_direction(output_basis_direction) + + hidden_directions = [ + bases.BasisDirection(f"{hidden_name}x", 1), + bases.BasisDirection(f"{hidden_name}x", -1), + bases.BasisDirection(f"{hidden_name}y", 1), + bases.BasisDirection(f"{hidden_name}y", -1) + ] + hidden_space = bases.VectorSpaceWithBasis(hidden_directions) + x_pos_vec, x_neg_vec, y_pos_vec, y_neg_vec = ( + hidden_space.vector_from_basis_direction(d) for d in hidden_directions) + + def first_layer_action( + direction: bases.BasisDirection) -> bases.VectorInBasis: + output = hidden_space.null_vector() + if direction == input1_basis_direction: + output += x_pos_vec - x_neg_vec + if direction == input2_basis_direction: + output += y_pos_vec - y_neg_vec + return output + + first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space, + first_layer_action) + + def second_layer_action( + direction: bases.BasisDirection) -> bases.VectorInBasis: + if direction.name == f"{hidden_name}x": + return input1_factor * direction.value * out_vec + if direction.name == f"{hidden_name}y": + return input2_factor * direction.value * out_vec + return output_space.null_vector() + + second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space, + second_layer_action) + + return transformers.MLP(first_layer, second_layer) diff --git a/craft/chamber/numerical_mlp_test.py b/craft/chamber/numerical_mlp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2cdf3a20f5b9af68cb90090a7045e41081cb08cf --- /dev/null +++ b/craft/chamber/numerical_mlp_test.py @@ -0,0 +1,233 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for chamber.numerical_mlp.""" + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from tracr.craft import bases +from tracr.craft import tests_common +from tracr.craft.chamber import numerical_mlp +from tracr.utils import errors + + +class NumericalMlpTest(tests_common.VectorFnTestCase): + + @parameterized.parameters([ + dict( + in_value_set={-2, -2, -1, 0, 1, 2, 3}, + x=2, + function=lambda x: x, + result=2), + dict( + in_value_set={-2, -2, -1, 0, 1, 2, 3}, + x=2, + function=lambda x: x**2, + result=4), + dict( + in_value_set={-2, -2, -1, 0, 1, 2, 3}, + x=2, + function=lambda x: x**3, + result=8), + dict( + in_value_set={-2, -2, -1, 0, 1, 2, 3}, + x=-2, + function=lambda x: x, + result=-2), + dict( + in_value_set={-2, -2, -1, 0, 1, 2, 3}, + x=-2, + function=lambda x: x**2, + result=4), + dict( + in_value_set={-2, -2, -1, 0, 1, 2, 3}, + x=-2, + function=lambda x: x**3, + result=-8), + ]) + def test_map_numerical_mlp_produces_expected_outcome(self, in_value_set, x, + function, result): + + input_dir = bases.BasisDirection("input") + output_dir = bases.BasisDirection("output") + one_dir = bases.BasisDirection("one") + input_space = bases.VectorSpaceWithBasis([input_dir]) + output_space = bases.VectorSpaceWithBasis([output_dir]) + one_space = bases.VectorSpaceWithBasis([one_dir]) + + mlp = numerical_mlp.map_numerical_mlp( + f=function, + input_space=input_space, + output_space=output_space, + one_space=one_space, + input_value_set=in_value_set, + ) + + test_inputs = bases.VectorInBasis( + basis_directions=[input_dir, output_dir, one_dir], + magnitudes=np.array([x, 0, 1])) + + expected_results = bases.VectorInBasis( + basis_directions=[input_dir, output_dir, one_dir], + magnitudes=np.array([0, result, 0])) + + test_outputs = mlp.apply(test_inputs) + + self.assertVectorAllClose(test_outputs, expected_results) + + @parameterized.parameters([ + dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1), + dict( + in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5), + dict( + in_value_set={0, 1, 2, 3}, + x=3, + function=lambda x: 1 / x, + result=1 / 3), + ]) + def test_map_numerical_mlp_logs_warning_and_produces_expected_outcome( + self, in_value_set, x, function, result): + + input_dir = bases.BasisDirection("input") + output_dir = bases.BasisDirection("output") + one_dir = bases.BasisDirection("one") + input_space = bases.VectorSpaceWithBasis([input_dir]) + output_space = bases.VectorSpaceWithBasis([output_dir]) + one_space = bases.VectorSpaceWithBasis([one_dir]) + + with self.assertLogs(level="WARNING"): + mlp = numerical_mlp.map_numerical_mlp( + f=function, + input_space=input_space, + output_space=output_space, + one_space=one_space, + input_value_set=in_value_set, + ) + + test_inputs = bases.VectorInBasis( + basis_directions=[input_dir, output_dir, one_dir], + magnitudes=np.array([x, 0, 1])) + + expected_results = bases.VectorInBasis( + basis_directions=[input_dir, output_dir, one_dir], + magnitudes=np.array([0, result, 0])) + + test_outputs = mlp.apply(test_inputs) + + self.assertVectorAllClose(test_outputs, expected_results) + + @parameterized.parameters([ + dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1), + dict( + in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5), + dict( + in_value_set={0, 1, 2, 3}, + x=3, + function=lambda x: 1 / x, + result=1 / 3), + ]) + def test_map_numerical_to_categorical_mlp_logs_warning_and_produces_expected_outcome( + self, in_value_set, x, function, result): + + f_ign = errors.ignoring_arithmetic_errors(function) + out_value_set = {f_ign(x) for x in in_value_set if f_ign(x) is not None} + + in_space = bases.VectorSpaceWithBasis.from_names(["input"]) + out_space = bases.VectorSpaceWithBasis.from_values("output", out_value_set) + one_space = bases.VectorSpaceWithBasis.from_names(["one"]) + + residual_space = bases.join_vector_spaces(in_space, one_space, out_space) + in_vec = residual_space.vector_from_basis_direction(in_space.basis[0]) + one_vec = residual_space.vector_from_basis_direction(one_space.basis[0]) + + with self.assertLogs(level="WARNING"): + mlp = numerical_mlp.map_numerical_to_categorical_mlp( + f=function, + input_space=in_space, + output_space=out_space, + input_value_set=in_value_set, + one_space=one_space, + ) + + test_inputs = x * in_vec + one_vec + expected_results = out_space.vector_from_basis_direction( + bases.BasisDirection("output", result)) + test_outputs = mlp.apply(test_inputs).project(out_space) + self.assertVectorAllClose(test_outputs, expected_results) + + @parameterized.parameters([ + dict(x_factor=1, y_factor=2, x=1, y=1, result=3), + dict(x_factor=1, y_factor=2, x=1, y=-1, result=-1), + dict(x_factor=1, y_factor=-1, x=1, y=1, result=0), + dict(x_factor=1, y_factor=1, x=3, y=5, result=8), + dict(x_factor=-2, y_factor=-0.5, x=4, y=1, result=-8.5), + ]) + def test_linear_sequence_map_produces_expected_result(self, x_factor, + y_factor, x, y, result): + + input1_dir = bases.BasisDirection("input1") + input2_dir = bases.BasisDirection("input2") + output_dir = bases.BasisDirection("output") + + mlp = numerical_mlp.linear_sequence_map_numerical_mlp( + input1_basis_direction=input1_dir, + input2_basis_direction=input2_dir, + output_basis_direction=output_dir, + input1_factor=x_factor, + input2_factor=y_factor) + + test_inputs = bases.VectorInBasis( + basis_directions=[input1_dir, input2_dir, output_dir], + magnitudes=np.array([x, y, 0])) + + expected_results = bases.VectorInBasis( + basis_directions=[input1_dir, input2_dir, output_dir], + magnitudes=np.array([0, 0, result])) + + test_outputs = mlp.apply(test_inputs) + + self.assertVectorAllClose(test_outputs, expected_results) + + @parameterized.parameters([ + dict(x_factor=1, y_factor=2, x=1, result=3), + dict(x_factor=1, y_factor=-1, x=1, result=0), + ]) + def test_linear_sequence_map_produces_expected_result_with_same_inputs( + self, x_factor, y_factor, x, result): + + input_dir = bases.BasisDirection("input") + output_dir = bases.BasisDirection("output") + + mlp = numerical_mlp.linear_sequence_map_numerical_mlp( + input1_basis_direction=input_dir, + input2_basis_direction=input_dir, + output_basis_direction=output_dir, + input1_factor=x_factor, + input2_factor=y_factor) + + test_inputs = bases.VectorInBasis( + basis_directions=[input_dir, output_dir], magnitudes=np.array([x, 0])) + + expected_results = bases.VectorInBasis( + basis_directions=[input_dir, output_dir], + magnitudes=np.array([0, result])) + + test_outputs = mlp.apply(test_inputs) + + self.assertVectorAllClose(test_outputs, expected_results) + + +if __name__ == "__main__": + absltest.main() diff --git a/craft/chamber/selector_width.py b/craft/chamber/selector_width.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c24184a1e4c17393a2869925132de3b368e2d8 --- /dev/null +++ b/craft/chamber/selector_width.py @@ -0,0 +1,144 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SelectorWidth component consisting of an attention head and an MLP.""" + +from typing import Iterable +from tracr.craft import bases +from tracr.craft import transformers +from tracr.craft import vectorspace_fns +from tracr.craft.chamber import categorical_attn +from tracr.craft.chamber import numerical_mlp + + +def selector_width( + query_space: bases.VectorSpaceWithBasis, + key_space: bases.VectorSpaceWithBasis, + output_space: bases.VectorSpaceWithBasis, + bos_space: bases.VectorSpaceWithBasis, + one_space: bases.VectorSpaceWithBasis, + attn_fn: categorical_attn.QueryKeyToAttnLogit, + out_value_set: Iterable[float], + categorical_output: bool, + causal: bool = False, + softmax_coldness: float = 100., + mlp_large_number: float = 100., + label: str = "", +) -> transformers.SeriesWithResiduals: + """Returns a craft block implementing RASP's SelectorWidth primitive. + + The block consists of one attention head and one MLP. + + The attention head implements the attention pattern (attn_fn or key=bos) and + aggregates the bos dimension over this pattern. The output of this will be + 1/(d+1) in every position, where d is the "width" of the attention pattern, + i.e. the number of 1s in a row. + + The MLP then computes d from the previous output in all positions except for + the first BOS position. In the BOS position the MLP removes the output of the + attention head, to ensure it only contains the encoding of the BOS token + which is expected by all other model components. + + Args: + query_space: Vector space containing (categorical) query input. + key_space: Vector space containing (categorical) key input. + output_space: Vector space which will contain (numerical or categorical) + output. + bos_space: 1-d space used to identify the beginning of sequence token. + one_space: Auxiliary 1-d vector space that must contain 1 in the input. + attn_fn: A selector function f(query, key) operating on the query/key basis + directions that defines the attention pattern to compute the width of. + out_value_set: Set of possible output values of this SelectorWidth. + categorical_output: If True, encode the output as a categorical variable. + causal: If True, use masked attention. + softmax_coldness: The inverse temperature of the softmax. Default value is + high which makes the attention close to a hard maximum. + mlp_large_number: A larger number makes the MLP more accurate. + label: A name for this block, used to label auxiliary dimensions. + """ + assert output_space.num_dims == 1 or categorical_output + + attn_out_dir = bases.BasisDirection(f"{label}_selector_width_attn_output") + attn_out_space = bases.VectorSpaceWithBasis([attn_out_dir]) + attn_out_vec = attn_out_space.vector_from_basis_direction(attn_out_dir) + + attn = categorical_attn.categorical_attn( + query_space=query_space, + key_space=key_space, + value_space=bos_space, + output_space=attn_out_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=attn_fn, + default_output=attn_out_space.null_vector(), + causal=causal, + always_attend_to_bos=True, + use_bos_for_default_output=False, + softmax_coldness=softmax_coldness) + + fun = lambda x: (1 / x) - 1 + in_value_set = {1 / (x + 1) for x in out_value_set} + if categorical_output: + mlp = numerical_mlp.map_numerical_to_categorical_mlp( + f=fun, + input_space=attn_out_space, + output_space=output_space, + input_value_set=in_value_set, + one_space=one_space, + hidden_name=f"_hidden_{label}_", + large_number=mlp_large_number) + else: + mlp = numerical_mlp.map_numerical_mlp( + f=fun, + input_space=attn_out_space, + output_space=output_space, + input_value_set=in_value_set, + one_space=one_space, + hidden_name=f"_hidden_{label}_", + large_number=mlp_large_number) + + # This implementation of selector width writes at each position including + # the BOS. To ensure that the BOS token position does not contain + # additional values, we add an mlp to subtract the output of both layers. + clean_bos_out_space = bases.join_vector_spaces(attn_out_space, output_space) + vec_to_subtract_from_bos = attn_out_vec.project(clean_bos_out_space) + + if categorical_output: + # Add the one-hot encoding of the zero value to the vector + # which will get scrubbed from the BOS position. + zero_dir = [d for d in output_space.basis if d.value == 0][0] + zero_vec = clean_bos_out_space.vector_from_basis_direction(zero_dir) + vec_to_subtract_from_bos += zero_vec + + # Construct an MLP that subtracts vec_to_subtract_from_bos * bos + # from the residual stream which is vec_to_subtract_from_bos in the + # bos position and 0 else. vec_to_subtract_from_bos contains what the + # attention head writes to the bos position. + + hidden_dir = bases.BasisDirection("_hidden_clean_bos_") + hidden_space = bases.VectorSpaceWithBasis([hidden_dir]) + hidden_vec = hidden_space.vector_from_basis_direction(hidden_dir) + + # It's okay to use the local variables because they are only used within + # the same loop iteration to create the MLP. + # pylint: disable=cell-var-from-loop + first_layer = vectorspace_fns.Linear.from_action(bos_space, hidden_space, + lambda x: hidden_vec) + second_layer = vectorspace_fns.Linear.from_action( + hidden_space, clean_bos_out_space, lambda x: -vec_to_subtract_from_bos) + # pylint: enable=cell-var-from-loop + clean_bos_mlp = transformers.MLP(first_layer, second_layer) + + mlp = transformers.MLP.combine_in_parallel([mlp, clean_bos_mlp]) + return transformers.SeriesWithResiduals([attn, mlp]) diff --git a/craft/chamber/selector_width_test.py b/craft/chamber/selector_width_test.py new file mode 100644 index 0000000000000000000000000000000000000000..77f449de95c34b625c9b6e24a16c14c15890ff66 --- /dev/null +++ b/craft/chamber/selector_width_test.py @@ -0,0 +1,155 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for selector_width.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tracr.craft import bases +from tracr.craft import tests_common +from tracr.craft.chamber import selector_width + + +class SelectorWidthTest(tests_common.VectorFnTestCase): + + @parameterized.product( + causal=[False, True], + categorical_output=[False, True], + input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]]) + def test_selector_width_of_select_all_is_length(self, causal, + categorical_output, + input_seq): + vocab = range(-20, 20) + input_space = bases.VectorSpaceWithBasis.from_values("input", vocab) + + if categorical_output: + output_space = bases.VectorSpaceWithBasis.from_values("output", range(10)) + else: + output_space = bases.VectorSpaceWithBasis( + [bases.BasisDirection("output")]) + + bos_dir = bases.BasisDirection("bos_dimension") + bos_space = bases.VectorSpaceWithBasis([bos_dir]) + + one_dir = bases.BasisDirection("one_dimension") + one_space = bases.VectorSpaceWithBasis([one_dir]) + + input_space = bases.join_vector_spaces(input_space, bos_space, one_space) + residual_space = bases.join_vector_spaces(input_space, output_space) + bos_vec = residual_space.vector_from_basis_direction(bos_dir) + one_vec = residual_space.vector_from_basis_direction(one_dir) + + block = selector_width.selector_width( + query_space=input_space, + key_space=input_space, + output_space=output_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=lambda x, y: True, + out_value_set=set(range(len(input_seq) + 1)), + categorical_output=categorical_output, + causal=causal, + label="select_all") + + test_inputs = [bos_vec + one_vec] + for x in input_seq: + test_inputs.append( + residual_space.vector_from_basis_direction( + bases.BasisDirection("input", x)) + one_vec) + test_inputs = bases.VectorInBasis.stack(test_inputs) + + # Expect length of the input sequence + if causal: + expected_results = list(range(1, len(input_seq) + 1)) + else: + expected_results = [len(input_seq) for _ in input_seq] + + if categorical_output: + expected_results = [ + output_space.vector_from_basis_direction( + bases.BasisDirection("output", x)) for x in expected_results + ] + else: + output_vec = output_space.vector_from_basis_direction( + bases.BasisDirection("output")) + expected_results = [x * output_vec for x in expected_results] + + expected_results = bases.VectorInBasis.stack(expected_results) + + test_outputs = block.apply(test_inputs).project(output_space) + self.assertVectorAllClose( + tests_common.strip_bos_token(test_outputs), expected_results) + + @parameterized.product( + causal=[False, True], + categorical_output=[False, True], + input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]]) + def test_selector_width_of_select_none_is_zero(self, causal, + categorical_output, input_seq): + vocab = range(-20, 20) + input_space = bases.VectorSpaceWithBasis.from_values("input", vocab) + + if categorical_output: + output_space = bases.VectorSpaceWithBasis.from_values("output", range(10)) + else: + output_space = bases.VectorSpaceWithBasis( + [bases.BasisDirection("output")]) + + bos_dir = bases.BasisDirection("bos_dimension") + bos_space = bases.VectorSpaceWithBasis([bos_dir]) + + one_dir = bases.BasisDirection("one_dimension") + one_space = bases.VectorSpaceWithBasis([one_dir]) + + input_space = bases.join_vector_spaces(input_space, bos_space, one_space) + residual_space = bases.join_vector_spaces(input_space, output_space) + bos_vec = residual_space.vector_from_basis_direction(bos_dir) + one_vec = residual_space.vector_from_basis_direction(one_dir) + + block = selector_width.selector_width( + query_space=input_space, + key_space=input_space, + output_space=output_space, + bos_space=bos_space, + one_space=one_space, + attn_fn=lambda x, y: False, + out_value_set=set(range(len(input_seq) + 1)), + categorical_output=categorical_output, + causal=causal, + label="select_all") + + test_inputs = [bos_vec + one_vec] + for x in input_seq: + test_inputs.append( + residual_space.vector_from_basis_direction( + bases.BasisDirection("input", x)) + one_vec) + test_inputs = bases.VectorInBasis.stack(test_inputs) + + # Expect zero output + if categorical_output: + expected_results = [ + output_space.vector_from_basis_direction( + bases.BasisDirection("output", 0)) for _ in input_seq + ] + else: + expected_results = [output_space.null_vector() for _ in input_seq] + expected_results = bases.VectorInBasis.stack(expected_results) + + test_outputs = block.apply(test_inputs).project(output_space) + self.assertVectorAllClose( + tests_common.strip_bos_token(test_outputs), expected_results) + + +if __name__ == "__main__": + absltest.main() diff --git a/craft/tests_common.py b/craft/tests_common.py new file mode 100644 index 0000000000000000000000000000000000000000..03a543352f7b01d3787f0958e064fa01dfa041c7 --- /dev/null +++ b/craft/tests_common.py @@ -0,0 +1,33 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper functions for tests.""" + +from absl.testing import parameterized +import numpy as np +from tracr.craft import bases + + +def strip_bos_token(vector: bases.VectorInBasis) -> bases.VectorInBasis: + """Removes BOS token of a vector.""" + return bases.VectorInBasis(vector.basis_directions, vector.magnitudes[1:]) + + +class VectorFnTestCase(parameterized.TestCase): + """Asserts for vectors.""" + + def assertVectorAllClose(self, v1: bases.VectorInBasis, + v2: bases.VectorInBasis): + self.assertEqual(v1.basis_directions, v2.basis_directions) + np.testing.assert_allclose(v1.magnitudes, v2.magnitudes, atol=1e-7) diff --git a/craft/transformers.py b/craft/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..87a1b6ea756a770e83eafd7e422cfdb8c3aa640f --- /dev/null +++ b/craft/transformers.py @@ -0,0 +1,197 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pieces for making transformers.""" + +import abc +import dataclasses +from typing import Iterable, Optional, Sequence, Union + +import numpy as np + +from tracr.craft import bases +from tracr.craft import vectorspace_fns + +project = vectorspace_fns.project + + +def _np_softmax(x, axis=-1): + x_max = np.max(x, axis=axis, keepdims=True) + return np.exp(x - x_max) / np.sum(np.exp(x - x_max), axis=axis, keepdims=True) + + +def _np_relu(x): + return np.where(x > 0, x, 0) + + +def relu(x: bases.VectorInBasis) -> bases.VectorInBasis: + return bases.VectorInBasis(x.basis_directions, _np_relu(x.magnitudes)) + + +class Block(abc.ABC): + """Transformer block, acting on a sequence of vector space elements. + + Attributes: + residual_space: Vector space that contains all subspaces the Block interacts + with. This can be either the full residual space of a model or a subspace. + """ + residual_space: bases.VectorSpaceWithBasis + + @abc.abstractmethod + def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: + """Applies self to an input.""" + + +@dataclasses.dataclass +class AttentionHead(Block): + """A transformer attention head.""" + w_qk: vectorspace_fns.ScalarBilinear + w_ov: vectorspace_fns.Linear + residual_space: Optional[bases.VectorSpaceWithBasis] = None + causal: bool = False + + def __post_init__(self): + """Infer residual stream and typecheck subspaces.""" + if self.residual_space is None: + self.residual_space = bases.join_vector_spaces(self.w_qk.left_space, + self.w_qk.right_space, + self.w_ov.input_space, + self.w_ov.output_space) + + assert self.w_qk.left_space.issubspace(self.residual_space) + assert self.w_qk.right_space.issubspace(self.residual_space) + assert self.w_ov.input_space.issubspace(self.residual_space) + assert self.w_ov.output_space.issubspace(self.residual_space) + + def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: + assert x in self.residual_space + # seq_len x query_space + queries = x.project(self.w_qk.left_space) + # seq_len x key_space + keys = x.project(self.w_qk.right_space) + + attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T + + if self.causal: + # The 1 gives us the matrix above the diagonal. + mask = np.triu(np.full_like(attn_matrix, -np.inf), 1) + attn_matrix = attn_matrix + mask + + attn_weights = _np_softmax(attn_matrix) # seq_len_from, seq_len_to + values = self.w_ov_residual(x).magnitudes # seq_len_to, d_model + + magnitudes = attn_weights @ values # seq_len_from, d_model + return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes) + + def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis: + """Wov but acting on the residual space.""" + x = project(self.residual_space, self.w_ov.input_space)(x) + out = self.w_ov(x) + return project(self.w_ov.output_space, self.residual_space)(out) + + @property + def num_heads(self) -> int: + return 1 + + def as_multi(self) -> "MultiAttentionHead": + return MultiAttentionHead([self]) + + +@dataclasses.dataclass +class MultiAttentionHead(Block): + """Applies attention heads in parallel.""" + sub_blocks: list[Union[AttentionHead, "MultiAttentionHead"]] + + def __post_init__(self): + spaces = [block.residual_space for block in self.sub_blocks] + self.residual_space, *others = spaces + assert all(s == self.residual_space for s in others) + + def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: + # each element is seq_len x embedding + outs = [block.apply(x) for block in self.sub_blocks] + return bases.VectorInBasis.sum(outs) # seq_len x embedding + + @property + def num_heads(self) -> int: + return sum(sub_block.num_heads for sub_block in self.sub_blocks) + + def heads(self) -> Iterable[AttentionHead]: + for sub_block in self.sub_blocks: + if isinstance(sub_block, AttentionHead): + yield sub_block + elif isinstance(sub_block, MultiAttentionHead): + yield from sub_block.heads() + else: + raise NotImplementedError() + + def as_multi(self) -> "MultiAttentionHead": + return self + + +@dataclasses.dataclass +class MLP(Block): + """A transformer MLP block.""" + fst: vectorspace_fns.Linear + snd: vectorspace_fns.Linear + residual_space: Optional[bases.VectorSpaceWithBasis] = None + + def __post_init__(self): + """Typecheck subspaces.""" + if self.residual_space is None: + self.residual_space = bases.join_vector_spaces(self.fst.input_space, + self.snd.output_space) + + assert self.fst.output_space == self.snd.input_space + assert self.fst.input_space.issubspace(self.residual_space) + assert self.snd.output_space.issubspace(self.residual_space) + + def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: + assert x in self.residual_space + + x = project(self.residual_space, self.fst.input_space)(x) + hidden = self.fst(x) + hidden = relu(hidden) + out = self.snd(hidden) + return project(self.snd.output_space, self.residual_space)(out) + + @classmethod + def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP": + fst = vectorspace_fns.Linear.combine_in_parallel( + [block.fst for block in mlps]) + snd = vectorspace_fns.Linear.combine_in_parallel( + [block.snd for block in mlps]) + return cls(fst=fst, snd=snd, residual_space=None) + + +# Block that fits into a half-layer, without residual connections. +HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead] + + +@dataclasses.dataclass +class SeriesWithResiduals(Block): + """A series of blocks with residual connections.""" + blocks: list[HalfLayerBlock] + + def __post_init__(self): + spaces = [block.residual_space for block in self.blocks] + self.residual_space = bases.join_vector_spaces(*spaces) + + def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: + x = x.project(self.residual_space) + for block in self.blocks: + x_in = x.project(block.residual_space) + x_out = block.apply(x_in).project(self.residual_space) + x = x + x_out + return x diff --git a/craft/transformers_test.py b/craft/transformers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1a854bb473b565540629131634f413b51a7a570c --- /dev/null +++ b/craft/transformers_test.py @@ -0,0 +1,160 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for transformers.""" + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from tracr.craft import bases +from tracr.craft import tests_common +from tracr.craft import transformers +from tracr.craft import vectorspace_fns as vs_fns + +# This makes it easier to use comments to annotate dimensions in arrays +# pylint: disable=g-no-space-after-comment + + +class AttentionHeadTest(tests_common.VectorFnTestCase): + + @parameterized.parameters([ + dict(with_residual_stream=False), + dict(with_residual_stream=True), + ]) + def test_attention_head(self, with_residual_stream): + i = bases.VectorSpaceWithBasis.from_values("i", [1, 2]) + o = bases.VectorSpaceWithBasis.from_values("o", [1, 2]) + q = bases.VectorSpaceWithBasis.from_values("q", [1, 2]) + k = bases.VectorSpaceWithBasis.from_values("p", [1, 2]) + rs = bases.direct_sum(i, o, q, k) + + seq = bases.VectorInBasis( + rs.basis, + np.array([ + #i1 i2 o1 o2 q1 q2 p1 p2 + [1, 0, 0, 0, 1, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0, 1], + ])) + + head = transformers.AttentionHead( + w_qk=vs_fns.ScalarBilinear(q, k, + np.eye(2) * 100), + w_ov=vs_fns.Linear(i, o, np.eye(2)), + residual_space=rs if with_residual_stream else None, + causal=False, + ) + + self.assertVectorAllClose( + head.apply(seq), + bases.VectorInBasis( + rs.basis, + np.array([ + #i1 i2 o1 o2 q1 q2 p1 p2 + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + ])), + ) + + +class MLPTest(tests_common.VectorFnTestCase): + + @parameterized.parameters([ + dict(with_residual_stream=False, same_in_out=False), + dict(with_residual_stream=False, same_in_out=True), + dict(with_residual_stream=True, same_in_out=False), + dict(with_residual_stream=True, same_in_out=True), + ]) + def test_mlp(self, with_residual_stream, same_in_out): + i = bases.VectorSpaceWithBasis.from_values("i", [1, 2]) + if same_in_out: + o, rs = i, i + expected_result = np.array([ + #o1 o2 + [1, 0], + [0, 1], + ]) + else: + o = bases.VectorSpaceWithBasis.from_values("o", [1, 2]) + rs = bases.direct_sum(i, o) + expected_result = np.array([ + #i1 i2 o1 o2 + [0, 0, 1, 0], + [0, 0, 0, 1], + ]) + h = bases.VectorSpaceWithBasis.from_values("p", [1, 2]) + + seq = bases.VectorInBasis( + i.basis, + np.array([ + #i1 i2 + [1, -1], + [-1, 1], + ])).project(rs) + + mlp = transformers.MLP( + fst=vs_fns.Linear(i, h, np.eye(2)), + snd=vs_fns.Linear(h, o, np.eye(2)), + residual_space=rs if with_residual_stream else None, + ) + + self.assertEqual( + mlp.apply(seq), + bases.VectorInBasis(rs.basis, expected_result), + ) + + def test_combining_mlps(self): + in12 = bases.VectorSpaceWithBasis.from_values("in", [1, 2]) + in34 = bases.VectorSpaceWithBasis.from_values("in", [3, 4]) + out12 = bases.VectorSpaceWithBasis.from_values("out", [1, 2]) + residual_space = bases.join_vector_spaces(in12, in34, out12) + + h1 = bases.VectorSpaceWithBasis.from_values("h", [1]) + h2 = bases.VectorSpaceWithBasis.from_values("h", [2]) + + # MLP1 maps in2 -> h1 -> out1 + mlp1 = transformers.MLP( + fst=vs_fns.Linear(in12, h1, np.array([[0], [1]])), + snd=vs_fns.Linear(h1, out12, np.array([[1, 0]]))) + + # MLP2 maps in3 -> h2 -> out2 + mlp2 = transformers.MLP( + fst=vs_fns.Linear(in34, h2, np.array([[1], [0]])), + snd=vs_fns.Linear(h2, out12, np.array([[0, 1]]))) + + mlp = transformers.MLP.combine_in_parallel([mlp1, mlp2]) + + seq = bases.VectorInBasis( + bases.direct_sum(in12, in34).basis, + np.array([ + #i1 i2 i3 i4 + [1, 2, 0, 0], + [0, 2, 3, 4], + ])).project(residual_space) + + expected_result = bases.VectorInBasis( + out12.basis, + np.array([ + #o1 o2 + [2, 0], + [2, 3], + ])) + + self.assertEqual( + mlp.apply(seq).project(out12), + expected_result, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/craft/vectorspace_fns.py b/craft/vectorspace_fns.py new file mode 100644 index 0000000000000000000000000000000000000000..c58ac4e474258dd9d938d9b1332d6fddc8e574f4 --- /dev/null +++ b/craft/vectorspace_fns.py @@ -0,0 +1,162 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions on vector spaces.""" + +import abc +import dataclasses +from typing import Callable, Sequence + +import numpy as np + +from tracr.craft import bases + +VectorSpaceWithBasis = bases.VectorSpaceWithBasis +VectorInBasis = bases.VectorInBasis +BasisDirection = bases.BasisDirection + + +class VectorFunction(abc.ABC): + """A function that acts on vectors.""" + + input_space: VectorSpaceWithBasis + output_space: VectorSpaceWithBasis + + @abc.abstractmethod + def __call__(self, x: VectorInBasis) -> VectorInBasis: + """Evaluates the function.""" + + +class Linear(VectorFunction): + """A linear function.""" + + def __init__( + self, + input_space: VectorSpaceWithBasis, + output_space: VectorSpaceWithBasis, + matrix: np.ndarray, + ): + """Initialises. + + Args: + input_space: The input vector space. + output_space: The output vector space. + matrix: a [input, output] matrix acting in a (sorted) basis. + """ + self.input_space = input_space + self.output_space = output_space + self.matrix = matrix + + def __post_init__(self) -> None: + output_size, input_size = self.matrix.shape + assert input_size == self.input_space.num_dims + assert output_size == self.output_space.num_dims + + def __call__(self, x: VectorInBasis) -> VectorInBasis: + if x not in self.input_space: + raise TypeError(f"{x=} not in {self.input_space=}.") + return VectorInBasis( + basis_directions=sorted(self.output_space.basis), + magnitudes=x.magnitudes @ self.matrix, + ) + + @classmethod + def from_action( + cls, + input_space: VectorSpaceWithBasis, + output_space: VectorSpaceWithBasis, + action: Callable[[BasisDirection], VectorInBasis], + ) -> "Linear": + """from_action(i, o)(action) creates a Linear.""" + + matrix = np.zeros((input_space.num_dims, output_space.num_dims)) + for i, direction in enumerate(input_space.basis): + out_vector = action(direction) + if out_vector not in output_space: + raise TypeError(f"image of {direction} from {input_space=} " + f"is not in {output_space=}") + matrix[i, :] = out_vector.magnitudes + + return Linear(input_space, output_space, matrix) + + @classmethod + def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear": + """Combines multiple parallel linear functions into a single one.""" + joint_input_space = bases.join_vector_spaces( + *[fn.input_space for fn in fns]) + joint_output_space = bases.join_vector_spaces( + *[fn.output_space for fn in fns]) + + def action(x: bases.BasisDirection) -> bases.VectorInBasis: + out = joint_output_space.null_vector() + for fn in fns: + if x in fn.input_space: + x_vec = fn.input_space.vector_from_basis_direction(x) + out += fn(x_vec).project(joint_output_space) + return out + + return cls.from_action(joint_input_space, joint_output_space, action) + + +def project( + from_space: VectorSpaceWithBasis, + to_space: VectorSpaceWithBasis, +) -> Linear: + """Creates a projection.""" + + def action(direction: bases.BasisDirection) -> VectorInBasis: + if direction in to_space: + return to_space.vector_from_basis_direction(direction) + else: + return to_space.null_vector() + + return Linear.from_action(from_space, to_space, action=action) + + +@dataclasses.dataclass +class ScalarBilinear: + """A scalar-valued bilinear operator.""" + left_space: VectorSpaceWithBasis + right_space: VectorSpaceWithBasis + matrix: np.ndarray + + def __post_init__(self): + """Ensure matrix acts in sorted bases and typecheck sizes.""" + left_size, right_size = self.matrix.shape + assert left_size == self.left_space.num_dims + assert right_size == self.right_space.num_dims + + def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float: + """Describes the action of the operator on vectors.""" + if x not in self.left_space: + raise TypeError(f"{x=} not in {self.left_space=}.") + if y not in self.right_space: + raise TypeError(f"{y=} not in {self.right_space=}.") + return (x.magnitudes.T @ self.matrix @ y.magnitudes).item() + + @classmethod + def from_action( + cls, + left_space: VectorSpaceWithBasis, + right_space: VectorSpaceWithBasis, + action: Callable[[BasisDirection, BasisDirection], float], + ) -> "ScalarBilinear": + """from_action(l, r)(action) creates a ScalarBilinear.""" + + matrix = np.zeros((left_space.num_dims, right_space.num_dims)) + for i, left_direction in enumerate(left_space.basis): + for j, right_direction in enumerate(right_space.basis): + matrix[i, j] = action(left_direction, right_direction) + + return ScalarBilinear(left_space, right_space, matrix) diff --git a/craft/vectorspace_fns_test.py b/craft/vectorspace_fns_test.py new file mode 100644 index 0000000000000000000000000000000000000000..60c5a72ed6b1cebbc97c7ccb64b5a2fc81548266 --- /dev/null +++ b/craft/vectorspace_fns_test.py @@ -0,0 +1,166 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for vectorspace_fns.""" + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from tracr.craft import bases +from tracr.craft import tests_common +from tracr.craft import vectorspace_fns as vs_fns + + +class LinearTest(tests_common.VectorFnTestCase): + + def test_identity_from_matrix(self): + vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) + f = vs_fns.Linear(vs, vs, np.eye(3)) + for v in vs.basis_vectors(): + self.assertEqual(f(v), v) + + def test_identity_from_action(self): + vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) + f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction) + for v in vs.basis_vectors(): + self.assertEqual(f(v), v) + + def test_nonidentiy(self): + vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + a = vs.vector_from_basis_direction(bases.BasisDirection("a")) + b = vs.vector_from_basis_direction(bases.BasisDirection("b")) + + f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]])) + + self.assertEqual( + f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7]))) + self.assertEqual( + f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1]))) + + def test_different_vector_spaces(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"]) + a, b = vs1.basis_vectors() + c, d = vs2.basis_vectors() + + f = vs_fns.Linear(vs1, vs2, np.eye(2)) + + self.assertEqual(f(a), c) + self.assertEqual(f(b), d) + + def test_combining_linear_functions_with_different_input(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"]) + vs = bases.direct_sum(vs1, vs2) + a = vs.vector_from_basis_direction(bases.BasisDirection("a")) + b = vs.vector_from_basis_direction(bases.BasisDirection("b")) + c = vs.vector_from_basis_direction(bases.BasisDirection("c")) + d = vs.vector_from_basis_direction(bases.BasisDirection("d")) + + f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]])) + f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]])) + f3 = vs_fns.Linear.combine_in_parallel([f1, f2]) + + self.assertEqual( + f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0]))) + self.assertEqual( + f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0]))) + self.assertEqual( + f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0]))) + self.assertEqual( + f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0]))) + + def test_combining_linear_functions_with_same_input(self): + vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + a = vs.vector_from_basis_direction(bases.BasisDirection("a")) + b = vs.vector_from_basis_direction(bases.BasisDirection("b")) + + f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]])) + f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]])) + f3 = vs_fns.Linear.combine_in_parallel([f1, f2]) + + self.assertEqual( + f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1]))) + self.assertEqual( + f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0]))) + self.assertEqual(f3(a), f1(a) + f2(a)) + self.assertEqual(f3(b), f1(b) + f2(b)) + + +class ProjectionTest(tests_common.VectorFnTestCase): + + def test_projection_to_larger_space(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) + a1, b1 = vs1.basis_vectors() + a2, b2, _, _ = vs2.basis_vectors() + + f = vs_fns.project(vs1, vs2) + + self.assertEqual(f(a1), a2) + self.assertEqual(f(b1), b2) + + def test_projection_to_smaller_space(self): + vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) + vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + a1, b1, c1, d1 = vs1.basis_vectors() + a2, b2 = vs2.basis_vectors() + + f = vs_fns.project(vs1, vs2) + + self.assertEqual(f(a1), a2) + self.assertEqual(f(b1), b2) + self.assertEqual(f(c1), vs2.null_vector()) + self.assertEqual(f(d1), vs2.null_vector()) + + +class ScalarBilinearTest(parameterized.TestCase): + + def test_identity_matrix(self): + vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + a, b = vs.basis_vectors() + + f = vs_fns.ScalarBilinear(vs, vs, np.eye(2)) + + self.assertEqual(f(a, a), 1) + self.assertEqual(f(a, b), 0) + self.assertEqual(f(b, a), 0) + self.assertEqual(f(b, b), 1) + + def test_identity_from_action(self): + vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + a, b = vs.basis_vectors() + + f = vs_fns.ScalarBilinear.from_action(vs, vs, lambda x, y: int(x == y)) + + self.assertEqual(f(a, a), 1) + self.assertEqual(f(a, b), 0) + self.assertEqual(f(b, a), 0) + self.assertEqual(f(b, b), 1) + + def test_non_identity(self): + vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) + a, b = vs.basis_vectors() + + f = vs_fns.ScalarBilinear.from_action(vs, vs, + lambda x, y: int(x.name == "a")) + + self.assertEqual(f(a, a), 1) + self.assertEqual(f(a, b), 1) + self.assertEqual(f(b, a), 0) + self.assertEqual(f(b, b), 0) + + +if __name__ == "__main__": + absltest.main() diff --git a/examples/Visualize_Tracr_Models.ipynb b/examples/Visualize_Tracr_Models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..338750abaedbf490d46e9d5b335be564b55914be --- /dev/null +++ b/examples/Visualize_Tracr_Models.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "99FBiGH7bsfn" + }, + "source": [ + "# Compiling \u0026 Visualizing Tracr Models\n", + "\n", + "This notebook demonstrates how to compile a tracr model and provides some tools visualize the model's residual stream or layer outputs for a given input sequence." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qm-PM1PEawCx" + }, + "outputs": [], + "source": [ + "#@title Imports\n", + "import jax\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# The default of float16 can lead to discrepancies between outputs of\n", + "# the compiled model and the RASP program.\n", + "jax.config.update('jax_default_matmul_precision', 'float32')\n", + "\n", + "from tracr.compiler import compiling\n", + "from tracr.compiler import lib\n", + "from tracr.rasp import rasp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HtOAc_yWawFR" + }, + "outputs": [], + "source": [ + "#@title Plotting functions\n", + "def tidy_label(label, value_width=5):\n", + " if ':' in label:\n", + " label, value = label.split(':')\n", + " else:\n", + " value = ''\n", + " return label + f\":{value:\u003e{value_width}}\"\n", + "\n", + "\n", + "def add_residual_ticks(model, value_width=5, x=False, y=True):\n", + " if y:\n", + " plt.yticks(\n", + " np.arange(len(model.residual_labels))+0.5, \n", + " [tidy_label(l, value_width=value_width)\n", + " for l in model.residual_labels], \n", + " family='monospace',\n", + " fontsize=20,\n", + " )\n", + " if x:\n", + " plt.xticks(\n", + " np.arange(len(model.residual_labels))+0.5, \n", + " [tidy_label(l, value_width=value_width)\n", + " for l in model.residual_labels], \n", + " family='monospace',\n", + " rotation=90,\n", + " fontsize=20,\n", + " )\n", + "\n", + "\n", + "def plot_computation_trace(model,\n", + " input_labels,\n", + " residuals_or_outputs,\n", + " add_input_layer=False,\n", + " figsize=(12, 9)):\n", + " fig, axes = plt.subplots(nrows=1, ncols=len(residuals_or_outputs), figsize=figsize, sharey=True)\n", + " value_width = max(map(len, map(str, input_labels))) + 1\n", + "\n", + " for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes)):\n", + " plt.sca(ax)\n", + " plt.pcolormesh(layer[0].T, vmin=0, vmax=1)\n", + " if i == 0:\n", + " add_residual_ticks(model, value_width=value_width)\n", + " plt.xticks(\n", + " np.arange(len(input_labels))+0.5,\n", + " input_labels,\n", + " rotation=90,\n", + " fontsize=20,\n", + " )\n", + " if add_input_layer and i == 0:\n", + " title = 'Input'\n", + " else:\n", + " layer_no = i - 1 if add_input_layer else i\n", + " layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'\n", + " title = f'{layer_type} {layer_no // 2 + 1}'\n", + " plt.title(title, fontsize=20)\n", + "\n", + "\n", + "def plot_residuals_and_input(model, inputs, figsize=(12, 9)):\n", + " \"\"\"Applies model to inputs, and plots the residual stream at each layer.\"\"\"\n", + " model_out = assembled_model.apply(inputs)\n", + " residuals = np.concatenate([model_out.input_embeddings[None, ...],\n", + " model_out.residuals], axis=0)\n", + " plot_computation_trace(\n", + " model=model,\n", + " input_labels=inputs,\n", + " residuals_or_outputs=residuals,\n", + " add_input_layer=True,\n", + " figsize=figsize)\n", + "\n", + "\n", + "def plot_layer_outputs(model, inputs, figsize=(12, 9)):\n", + " \"\"\"Applies model to inputs, and plots the outputs of each layer.\"\"\"\n", + " model_out = assembled_model.apply(inputs)\n", + " plot_computation_trace(\n", + " model=model,\n", + " input_labels=inputs,\n", + " residuals_or_outputs=model_out.layer_outputs,\n", + " add_input_layer=False,\n", + " figsize=figsize)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "8hV0nv_ISmhM" + }, + "outputs": [], + "source": [ + "#@title Define RASP programs\n", + "def get_program(program_name, max_seq_len):\n", + " \"\"\"Returns RASP program and corresponding token vocabulary.\"\"\"\n", + " if program_name == \"length\":\n", + " vocab = {\"a\", \"b\", \"c\", \"d\"}\n", + " program = lib.make_length()\n", + " elif program_name == \"frac_prevs\":\n", + " vocab = {\"a\", \"b\", \"c\", \"x\"}\n", + " program = lib.make_frac_prevs((rasp.tokens == \"x\").named(\"is_x\"))\n", + " elif program_name == \"dyck-2\":\n", + " vocab = {\"(\", \")\", \"{\", \"}\"}\n", + " program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\"])\n", + " elif program_name == \"dyck-3\":\n", + " vocab = {\"(\", \")\", \"{\", \"}\", \"[\", \"]\"}\n", + " program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\", \"[]\"])\n", + " elif program_name == \"sort\":\n", + " vocab = {1, 2, 3, 4, 5}\n", + " program = lib.make_sort(\n", + " rasp.tokens, rasp.tokens, max_seq_len=max_seq_len, min_key=1)\n", + " elif program_name == \"sort_unique\":\n", + " vocab = {1, 2, 3, 4, 5}\n", + " program = lib.make_sort_unique(rasp.tokens, rasp.tokens)\n", + " elif program_name == \"hist\":\n", + " vocab = {\"a\", \"b\", \"c\", \"d\"}\n", + " program = lib.make_hist()\n", + " elif program_name == \"sort_freq\":\n", + " vocab = {\"a\", \"b\", \"c\", \"d\"}\n", + " program = lib.make_sort_freq(max_seq_len=max_seq_len)\n", + " elif program_name == \"pair_balance\":\n", + " vocab = {\"(\", \")\"}\n", + " program = lib.make_pair_balance(\n", + " sop=rasp.tokens, open_token=\"(\", close_token=\")\")\n", + " else:\n", + " raise NotImplementedError(f\"Program {program_name} not implemented.\")\n", + " return program, vocab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "L_m_ufaua9ri" + }, + "outputs": [], + "source": [ + "#@title: Assemble model\n", + "program_name = \"sort_unique\" #@param [\"length\", \"frac_prevs\", \"dyck-2\", \"dyck-3\", \"sort\", \"sort_unique\", \"hist\", \"sort_freq\", \"pair_balance\"]\n", + "max_seq_len = 5 #@param {label: \"Test\", type: \"integer\"}\n", + "\n", + "program, vocab = get_program(program_name=program_name,\n", + " max_seq_len=max_seq_len)\n", + "\n", + "print(f\"Compiling...\")\n", + "print(f\" Program: {program_name}\")\n", + "print(f\" Input vocabulary: {vocab}\")\n", + "print(f\" Context size: {max_seq_len}\")\n", + "\n", + "assembled_model = compiling.compile_rasp_to_model(\n", + " program=program,\n", + " vocab=vocab,\n", + " max_seq_len=max_seq_len,\n", + " causal=False,\n", + " compiler_bos=\"bos\",\n", + " compiler_pad=\"pad\",\n", + " mlp_exactness=100)\n", + "\n", + "print(\"Done.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wtwiE-JiXF3F" + }, + "outputs": [], + "source": [ + "#@title Forward pass\n", + "assembled_model.apply([\"bos\", 3, 4, 1]).decoded" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RkEkVcEHa2gf" + }, + "outputs": [], + "source": [ + "#@title Plot residual stream\n", + "plot_residuals_and_input(\n", + " model=assembled_model,\n", + " inputs=[\"bos\", 3, 4, 1],\n", + " figsize=(10, 9)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8c4LakWHa4ey" + }, + "outputs": [], + "source": [ + "#@title Plot layer outputs\n", + "plot_layer_outputs(\n", + " model=assembled_model,\n", + " inputs = [\"bos\", 3, 4, 1],\n", + " figsize=(8, 9)\n", + ")" + ] + } + ], + "metadata": { + "colab": { + "private_outputs": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/rasp/causal_eval.py b/rasp/causal_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..8ebc5dcea3724505bf13dbe5c2eb255eeeb2a227 --- /dev/null +++ b/rasp/causal_eval.py @@ -0,0 +1,39 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RASP Evaluator which applies causal masks to selectors.""" + +from typing import Sequence, Union + +import numpy as np +from tracr.rasp import rasp + + +class CausalEvaluator(rasp.DefaultRASPEvaluator): + """Evaluates RASP with causal masking.""" + + def evaluate( + self, expr: rasp.RASPExpr, xs: Sequence[rasp.Value] + ) -> Union[Sequence[rasp.Value], rasp.SelectorValue]: + out = super().evaluate(expr, xs) + + if not isinstance(expr, rasp.Selector): + return out + + out = np.array(out) + causal_mask = np.tril(np.full(out.shape, 1)) + return np.logical_and(causal_mask, out).tolist() + + +evaluate = CausalEvaluator().evaluate diff --git a/rasp/causal_eval_test.py b/rasp/causal_eval_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ad06721953e1ac849d8f8fbb3ef84b74bed1c0 --- /dev/null +++ b/rasp/causal_eval_test.py @@ -0,0 +1,61 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for causal_eval.""" + +from absl.testing import absltest +from absl.testing import parameterized + +from tracr.rasp import causal_eval +from tracr.rasp import rasp + + +class CausalEvalTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="constant_selector_3x3_1", + program=rasp.ConstantSelector([ + [True, True, True], + [True, True, True], + [True, True, True], + ]), + input_sequence=[True, True, True], + expected_output=[ + [True, False, False], + [True, True, False], + [True, True, True], + ]), + dict( + testcase_name="constant_selector_3x3_2", + program=rasp.ConstantSelector([ + [True, True, True], + [False, True, True], + [True, False, True], + ]), + input_sequence=[True, True, True], + expected_output=[ + [True, False, False], + [False, True, False], + [True, False, True], + ])) + def test_evaluations(self, program, input_sequence, expected_output): + self.assertListEqual( + causal_eval.evaluate(program, input_sequence), + expected_output, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/rasp/rasp.py b/rasp/rasp.py new file mode 100644 index 0000000000000000000000000000000000000000..5569a5ed8fa1a2f5abba818b31a9a20be78d9021 --- /dev/null +++ b/rasp/rasp.py @@ -0,0 +1,932 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RASP program objects. + +Every object in the RASP language is a function. + +The most important type is S-Op, which is a function list[Value] -> list[Value]. + +An S-Op represents a state inside the residual stream of the transformer. +Therefore, any RASP program that represents a transformer computation must +define a final S-Op that represents the state of the residual stream at the +end of the computation. In particular, given an S-Op `x`, +`x([1, 2, 3])` represents something like the state of the residual stream +at location `x` when the transformer is fed [1, 2, 3] as input. + +A secondary (but still important) type is Selector, which is a function +list[Value] -> list[list[bool]]. Given a Selector `sel`, sel([1, 2, 3]) +represents something like an attention matrix in the transformer. + +For a full reference on RASP, see https://arxiv.org/abs/2106.06981. +""" + +import abc +import collections.abc +import copy +import enum +import functools +import itertools +from typing import (Any, Callable, Generic, Mapping, Optional, Protocol, + Sequence, TypeVar, Union) +from absl import logging + +import numpy as np + +SelectorValue = list[list[bool]] +NumericValue = Union[int, float] +Value = Union[None, int, float, str, bool] +VT = TypeVar("VT", bound=Value) +RASPExprT = TypeVar("RASPExprT", bound="RASPExpr") +SOpT = TypeVar("SOpT", bound="SOp") +T = TypeVar("T") + +_NAME_KEY = "name" +_ENCODING_KEY = "encoding" + +# These are run on every expression when it's initialised. +# Add your own annotators to this dict to add custom default annotations. +# +# For example, DEFAULT_ANNOTATORS['foo'] will provide the default value for +# expr.annotations['foo]. The annotator will get called lazily the first time +# that key is accessed. +# +# See the `default_name` annotator for a full example. +DEFAULT_ANNOTATORS: dict[str, "Annotator"] = {} + + +class Annotator(Protocol): + + def __call__(self, expr: "RASPExpr") -> Any: + """What annotation to add to `expr`.""" + + +class _Annotations(collections.abc.Mapping): + """Holds the expression's annotations. + + It's immutable to the user, but will attempt to generate default values + lazily when missing keys are requested. + """ + + def __init__(self, expr, **kwargs: Any): + self._expr = expr + self._inner_dict: dict[str, Any] = {**kwargs} + + def __getitem__(self, key: str) -> Any: + if key not in self._inner_dict: + if key not in DEFAULT_ANNOTATORS: + raise KeyError( + f"No annotation exists for key '{key}'. " + f"Available keys: {list(*self.keys(), *DEFAULT_ANNOTATORS.keys())}") + self._inner_dict[key] = DEFAULT_ANNOTATORS[key](self._expr) + + return self._inner_dict[key] + + def __iter__(self): + return iter(self._inner_dict) + + def __len__(self): + return len(self._inner_dict) + + +class RASPExpr(abc.ABC): + """A class distinguishing RASP expressions from other objects.""" + _ids = itertools.count(1) + + def __init__(self): + self._annotations: Mapping[str, Any] = _Annotations(self) + + @abc.abstractmethod + def __call__(self, + xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]: + """Evaluates the RASPExpr using the standard evaluator.""" + + @property + def annotations(self) -> Mapping[str, Any]: + """The annotations of this expression instance.""" + return self._annotations + + @annotations.setter + def annotations(self, annotations: Mapping[str, Any]): + self._annotations = _Annotations(self, **annotations) + + @property + def name(self) -> str: + """The name of this expression.""" + return self.annotations[_NAME_KEY] + + @property + @abc.abstractmethod + def children(self) -> Sequence["RASPExpr"]: + """Direct dependencies of this expression.""" + + @functools.cached_property + def unique_id(self): + """A unique id for every expression instance.""" + return next(self._ids) + + def copy(self: RASPExprT) -> RASPExprT: + """Returns a shallow copy of this RASPExpr with a new ID.""" + return copy.copy(self) + + @property + def label(self) -> str: + return f"{self.name}_{self.unique_id}" + + def named(self: RASPExprT, name: str) -> RASPExprT: + """Convenience method for adding a name.""" + return annotate(self, name=name) + + def annotated(self: RASPExprT, **annotations) -> RASPExprT: + """Convenience method for adding annotations.""" + return annotate(self, **annotations) + + +def annotate(expr: RASPExprT, **annotations) -> RASPExprT: + """Creates a new expr with added annotations.""" + new = expr.copy() + # Note that new annotations will overwrite existing ones with matching keys. + new.annotations = {**expr.annotations, **annotations} + return new + + +### S-Ops. + + +class SOp(RASPExpr): + """A Sequence Operation.""" + + def __call__(self, xs: Sequence[Value]) -> Sequence[Value]: + return evaluate(self, xs) # pytype: disable=bad-return-type + + # Allow construction of SOps using numeric operators with constant values. + # Note: if inheriting SOp by a dataclass, make sure to disable eq and order, + # as they will override these. + + def __lt__(self, other: Value) -> "SOp": + """self < other.""" + return Map(lambda x: x < other, self) + + def __le__(self, other: Value) -> "SOp": + """self <= other.""" + return Map(lambda x: x <= other, self) + + def __eq__(self, other: Value) -> "SOp": + """self == other.""" + return Map(lambda x: x == other, self) + + def __ne__(self, other: Value) -> "SOp": + """self != other.""" + return Map(lambda x: x != other, self) + + def __gt__(self, other: Value) -> "SOp": + """self > other.""" + return Map(lambda x: x > other, self) + + def __ge__(self, other: Value) -> "SOp": + """self >= other.""" + return Map(lambda x: x >= other, self) + + def __add__(self, other: Union["SOp", Value]) -> "SOp": + """self + other.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x + y, self, other) + return Map(lambda x: x + other, self) + + def __radd__(self, other: Union["SOp", Value]) -> "SOp": + """other + self.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x + y, other, self) + return Map(lambda x: other + x, self) + + def __sub__(self, other: Union["SOp", NumericValue]) -> "SOp": + """self - other.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x - y, self, other) + return Map(lambda x: x - other, self) + + def __rsub__(self, other: Union["SOp", NumericValue]) -> "SOp": + """other - self.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x - y, other, self) + return Map(lambda x: other - x, self) + + def __mul__(self, other: Union["SOp", NumericValue]) -> "SOp": + """self * other.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x * y, self, other) + return Map(lambda x: x * other, self) + + def __rmul__(self, other: Union["SOp", NumericValue]) -> "SOp": + """other * self.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x * y, other, self) + return Map(lambda x: other * x, self) + + def __truediv__(self, other: Union["SOp", NumericValue]) -> "SOp": + """self / other.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x / y, self, other) + return Map(lambda x: x / other, self) + + def __rtruediv__(self, other: Union["SOp", NumericValue]) -> "SOp": + """other / self.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x / y, other, self) + return Map(lambda x: other / x, self) + + def __invert__(self) -> "SOp": + return Map(lambda x: not x, self) + + def __and__(self, other: Union["SOp", NumericValue]) -> "SOp": + """self & other.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x and y, self, other) + return Map(lambda x: x and other, self) + + def __or__(self, other: Union["SOp", NumericValue]) -> "SOp": + """self | other.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x or y, self, other) + return Map(lambda x: x or other, self) + + def __rand__(self, other: Union["SOp", NumericValue]) -> "SOp": + """other & self.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x and y, other, self) + return Map(lambda x: other and x, self) + + def __ror__(self, other: Union["SOp", NumericValue]) -> "SOp": + """other | self.""" + if isinstance(other, SOp): + return SequenceMap(lambda x, y: x or y, other, self) + return Map(lambda x: x or other, self) + + +class TokensType(SOp): + """Primitive SOp returning the original input tokens.""" + + @property + def children(self) -> Sequence[RASPExpr]: + return [] + + @property + def label(self) -> str: + return "tokens" + + def __repr__(self): + return "tokens" + + +class IndicesType(SOp): + """Primitive SOp returning the position index at each token.""" + + @property + def children(self) -> Sequence[RASPExpr]: + return [] + + @property + def label(self) -> str: + return "indices" + + def __repr__(self): + return "indices" + + +class LengthType(SOp): + """Primitive SOp returning the total length of the input.""" + + @property + def children(self) -> Sequence[RASPExpr]: + return [] + + @property + def label(self) -> str: + return "length" + + def __repr__(self): + return "length" + + +tokens = TokensType() +indices = IndicesType() +length = LengthType() + + +class Map(SOp): + """SOp that evaluates the function elementwise on the input SOp. + + Map(lambda x: x + 1, tokens).eval([1, 2, 3]) == [2, 3, 4] + """ + + def __init__(self, f: Callable[[Value], Value], inner: SOp): + super().__init__() + self.f = f + self.inner = inner + + assert isinstance(self.inner, SOp) + assert callable(self.f) and not isinstance(self.f, RASPExpr) + + if isinstance(self.inner, Map): + # Combine the functions into just one. + inner_f = self.inner.f + self.f = lambda t: f(inner_f(t)) + self.inner = self.inner.inner + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.inner] + + +class SequenceMap(SOp): + """SOp that evaluates the function elementwise on the two given SOp's. + + SequenceMap(lambda x, y: x - y, length, tokens).eval([1, 2, 3]) == [2, 1, 0] + """ + + def __init__(self, f: Callable[[Value, Value], Value], fst: SOp, snd: SOp): + super().__init__() + + if fst == snd: + logging.warning("Creating a SequenceMap with both inputs being the same " + "SOp is discouraged. You should use a Map instead.") + + self.f = f + self.fst = fst + self.snd = snd + assert isinstance(self.fst, SOp) + assert isinstance(self.snd, SOp) + assert callable(self.f) and not isinstance(self.f, RASPExpr) + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.fst, self.snd] + + +class LinearSequenceMap(SequenceMap): + """SOp that evaluates a linear function elementwise on the two given SOp's.""" + + def __init__(self, fst: SOp, snd: SOp, fst_fac: float, snd_fac: float): + super().__init__(fst=fst, snd=snd, f=lambda x, y: fst_fac * x + snd_fac * y) + self.fst_fac = fst_fac + self.snd_fac = snd_fac + + +class Full(SOp): + """A SOp evaluating to [fill]*len(input_values).""" + + def __init__(self, fill: Value): + super().__init__() + self.fill = fill + + @property + def children(self) -> Sequence[RASPExpr]: + return [] + + +def sop_not(sop: SOp) -> SOp: + return Map(lambda t: not t, sop) + + +class ConstantSOp(SOp, Generic[VT]): + """A constant S-Op for testing purposes.""" + + def __init__(self, value: Sequence[VT], check_length: bool = True): + super().__init__() + self.value = value + self.check_length = check_length + + @property + def children(self) -> Sequence[RASPExpr]: + return [] + + +### Selectors. + + +class Predicate(Protocol): + + def __call__(self, key: Value, query: Value) -> bool: + """Applies the predicate.""" + + +class Comparison(enum.Enum): + """A two-place boolean comparison predicate for use in Select.""" + EQ = "==" + LT = "<" + LEQ = "<=" + GT = ">" + GEQ = ">=" + NEQ = "!=" + TRUE = "True" + FALSE = "False" + + def __call__(self, key: Value, query: Value) -> bool: + if key is None: + raise ValueError("key is None!") + if query is None: + raise ValueError("query is None!") + return _comparison_table[self](key, query) + + +_comparison_table = { + Comparison.EQ: lambda key, query: key == query, + Comparison.LT: lambda key, query: key < query, + Comparison.LEQ: lambda key, query: key <= query, + Comparison.GT: lambda key, query: key > query, + Comparison.GEQ: lambda key, query: key >= query, + Comparison.NEQ: lambda key, query: key != query, + Comparison.TRUE: lambda key, query: True, + Comparison.FALSE: lambda key, query: False, +} + + +class Selector(RASPExpr): + """RASP Selector. Represents something like an attention head's weights.""" + + def __call__(self, xs: Sequence[Value]) -> SelectorValue: + return evaluate(self, xs) # pytype: disable=bad-return-type + + # Allow construction of Selector combinations using Python logical operators. + def __and__(self, other: "Selector") -> "Selector": + """self & other.""" + return selector_and(self, other) + + def __rand__(self, other: "Selector") -> "Selector": + """other & self.""" + return selector_and(other, self) + + def __or__(self, other: "Selector") -> "Selector": + """self | other.""" + return selector_or(self, other) + + def __ror__(self, other: "Selector") -> "Selector": + """other | self.""" + return selector_or(other, self) + + def __invert__(self) -> "Selector": + """~self.""" + return selector_not(self) + + +class Select(Selector): + """Primitive that creates a Selector.""" + + def __init__(self, keys: SOp, queries: SOp, predicate: Predicate): + super().__init__() + self.keys = keys + self.queries = queries + self.predicate = predicate + assert isinstance(self.keys, SOp) + assert isinstance(self.queries, SOp) + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.keys, self.queries] + + +class ConstantSelector(Selector): + """A constant selector for testing purposes.""" + + def __init__(self, value: SelectorValue, check_length: bool = True): + super().__init__() + self.value = value + self.check_length = check_length + + @property + def children(self) -> Sequence[RASPExpr]: + return [] + + +class SelectorWidth(SOp): + """SelectorWidth primitive.""" + + def __init__(self, selector: Selector): + super().__init__() + self.selector = selector + assert isinstance(self.selector, Selector) + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.selector] + + +class SelectorAnd(Selector): + """Implements elementwise `and` between selectors.""" + + def __init__(self, fst: Selector, snd: Selector): + super().__init__() + self.fst = fst + self.snd = snd + assert isinstance(self.fst, Selector) + assert isinstance(self.snd, Selector) + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.fst, self.snd] + + +class SelectorOr(Selector): + """Implements elementwise `or` between selectors.""" + + def __init__(self, fst: Selector, snd: Selector): + super().__init__() + self.fst = fst + self.snd = snd + assert isinstance(self.fst, Selector) + assert isinstance(self.snd, Selector) + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.fst, self.snd] + + +class SelectorNot(Selector): + """Implements elementwise `not` on a selector.""" + + def __init__(self, inner: Selector): + self.inner = inner + super().__init__() + assert isinstance(self.inner, Selector) + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.inner] + + +def selector_not( + inner: Selector, + simplify: bool = True, +) -> Selector: + """Returns a SelectorNot, or a Select if simplifying is possible.""" + if simplify and isinstance(inner, Select): + predicate = lambda k, q: not inner.predicate(k, q) + return Select(inner.keys, inner.queries, predicate=predicate) + + return SelectorNot(inner) + + +def selector_and( + fst: Selector, + snd: Selector, + simplify: bool = True, +) -> Selector: + """Returns a SelectorAnd, or a Select if simplifying is possible.""" + if simplify and isinstance(fst, Select) and isinstance(snd, Select): + simplified = _attempt_simplify(fst, snd, lambda l, r: l and r) + if simplified: + return simplified + + return SelectorAnd(fst, snd) + + +def selector_or( + fst: Selector, + snd: Selector, + simplify: bool = True, +) -> Selector: + """Returns a SelectorOr, or a Select if simplifying is possible.""" + if simplify and isinstance(fst, Select) and isinstance(snd, Select): + simplified = _attempt_simplify(fst, snd, lambda l, r: l or r) + if simplified: + return simplified + + return SelectorOr(fst, snd) + + +def _attempt_simplify( + fst: Select, + snd: Select, + combine: Callable[[bool, bool], bool], +) -> Optional[Select]: + """Simplifies two Selects if possible. + + If two Selects in a compound Selector have matching keys and queries, they can + be simplified into one Select with a compound predicate: + + lambda k,q: combine(fst.predicate(k,q), snd.predicate(k,q)) + + This function returns a Select with this predicate if possible, + and None otherwise. + + A Full SOp in a key or query position is a special case that always matches + any SOp in the corresponding position in the other selector. In that case, + we bake in the fill value into the corresponding Select's predicate before + combining. This allows us to use the other SOp as the input to the simplified + Select. + + Args: + fst: the first Select. + snd: the second Select. + combine: how to combine the outputs of the individual predicates. + + Returns: + A combined Select, if possible. + """ + fst_predicate = fst.predicate + snd_predicate = snd.predicate + common_keys = None + common_queries = None + + if isinstance(fst.keys, Full): + common_keys = snd.keys + # We pass the predicate in as a default arg to avoid unintended recursion. + fst_predicate = lambda key, query, p=fst_predicate: p(fst.keys.fill, query) + if isinstance(snd.keys, Full): + common_keys = fst.keys + snd_predicate = lambda key, query, p=snd_predicate: p(snd.keys.fill, query) + if isinstance(fst.queries, Full): + common_queries = snd.queries + fst_predicate = lambda key, query, p=fst_predicate: p(key, fst.queries.fill) + if isinstance(snd.queries, Full): + common_queries = fst.queries + snd_predicate = lambda key, query, p=snd_predicate: p(key, snd.queries.fill) + if fst.keys is snd.keys: + common_keys = fst.keys + if fst.queries is snd.queries: + common_queries = fst.queries + + if not common_keys or not common_queries: + return None + + def predicate(key, query): + return combine(fst_predicate(key, query), snd_predicate(key, query)) + + return Select(common_keys, common_queries, predicate=predicate) + + +class Aggregate(SOp, Generic[VT]): + """Aggregate primitive.""" + + def __init__(self, + selector: Selector, + sop: SOp, + default: Optional[VT] = None): + """Initialises. The default is used where nothing is selected.""" + super().__init__() + self.selector = selector + self.sop = sop + self.default = default + assert isinstance(self.selector, Selector) + assert isinstance(self.sop, SOp) + assert (self.default is None or isinstance(self.default, + (str, float, bool, int))) + + @property + def children(self) -> Sequence[RASPExpr]: + return [self.selector, self.sop] + + +### SOp encodings. + + +class Encoding(enum.Enum): + """The encoding used by a SOp. Only number-valued SOps support numerical.""" + CATEGORICAL = "categorical" + NUMERICAL = "numerical" + + +def numerical(sop: SOpT) -> SOpT: + return annotate(sop, encoding=Encoding.NUMERICAL) + + +def categorical(sop: SOpT) -> SOpT: + return annotate(sop, encoding=Encoding.CATEGORICAL) + + +def get_encoding(sop: SOp) -> Encoding: + return sop.annotations["encoding"] + + +def is_numerical(sop: SOp) -> bool: + """Check if the SOp is numerically encoded.""" + return get_encoding(sop) == Encoding.NUMERICAL + + +def is_categorical(sop: SOp) -> bool: + """Check if the SOp is categorically encoded.""" + return get_encoding(sop) == Encoding.CATEGORICAL + + +def default_encoding(expr: RASPExpr) -> Optional[Encoding]: + """Adds an 'encoding' annotation, default is Categorical.""" + if not isinstance(expr, SOp): + raise TypeError(f"expr {expr} is not a SOp.") + + return Encoding.CATEGORICAL + + +DEFAULT_ANNOTATORS[_ENCODING_KEY] = default_encoding + +### naming. + +# Subclasses must appear here before superclasses in order for +# the most specific entry to be used. + +_default_name_by_class = { + # Primitives + TokensType: "tokens", + IndicesType: "indices", + LengthType: "length", + # SOps + LinearSequenceMap: "linear_sequence_map", + SequenceMap: "sequence_map", + Map: "map", + Full: "full", + ConstantSOp: "constant_sop", + SelectorWidth: "selector_width", + Aggregate: "aggregate", + SOp: "sop", + # Selectors + Select: "select", + SelectorAnd: "selector_and", + SelectorOr: "selector_or", + SelectorNot: "selector_not", + ConstantSelector: "constant_selector", + Selector: "selector", +} + + +def default_name(expr: RASPExpr) -> dict[str, str]: + for cls, name in _default_name_by_class.items(): + if isinstance(expr, cls): + return name + + raise NotImplementedError(f"{expr} was not given a default name!") + + +DEFAULT_ANNOTATORS[_NAME_KEY] = default_name + +### evaluation. + + +class RASPEvaluator(abc.ABC): + """ABC for RASP evaluators.""" + + @abc.abstractmethod + def evaluate(self, expr: RASPExpr, + xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]: + """Evaluates the RASP expression on input `xs`.""" + + +class DefaultRASPEvaluator(abc.ABC): + """Default evaluator for RASP.""" + + def evaluate(self, expr: RASPExpr, + xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]: + """Evaluates the RASP expression on input `xs`.""" + return self._eval_fn_by_expr_type[type(expr)](expr, xs) + + def __init__(self): + self._eval_fn_by_expr_type = { + # Primitives + TokensType: self.eval_tokens, + IndicesType: self.eval_indices, + LengthType: self.eval_length, + # SOps + LinearSequenceMap: self.eval_sequence_map, + SequenceMap: self.eval_sequence_map, + Map: self.eval_map, + Full: self.eval_full, + ConstantSOp: self.eval_constant_sop, + SelectorWidth: self.eval_selector_width, + Aggregate: self.eval_aggregate, + SOp: _raise_not_implemented, + # Selectors + Select: self.eval_select, + SelectorAnd: self.eval_selector_and, + SelectorOr: self.eval_selector_or, + SelectorNot: self.eval_selector_not, + ConstantSelector: self.eval_constant_selector, + Selector: _raise_not_implemented, + } + + def eval_tokens(self, sop: TokensType, + xs: Sequence[Value]) -> Sequence[Value]: + del sop + return list(xs) + + def eval_indices(self, sop: IndicesType, + xs: Sequence[Value]) -> Sequence[Value]: + del sop + return list(range(len(xs))) + + def eval_length(self, sop: LengthType, xs: Sequence[Value]) -> Sequence[int]: + del sop + return [len(xs)] * len(xs) + + def eval_sequence_map(self, sop: SequenceMap, + xs: Sequence[Value]) -> Sequence[Value]: + fst_values = self.evaluate(sop.fst, xs) + snd_values = self.evaluate(sop.snd, xs) + return [ + sop.f(x, y) if None not in [x, y] else None + for x, y in zip(fst_values, snd_values) + ] + + def eval_map(self, sop: Map, xs: Sequence[Value]) -> Sequence[Value]: + return [ + sop.f(x) if x is not None else None + for x in self.evaluate(sop.inner, xs) + ] + + def eval_full(self, sop: Full, xs: Sequence[Value]) -> Sequence[Value]: + return [sop.fill] * len(xs) + + def eval_constant_sop(self, sop: ConstantSOp, + xs: Sequence[Value]) -> Sequence[Value]: + if sop.check_length and (len(xs) != len(sop.value)): + raise ValueError( + f"Constant len {len(sop.value)} doesn't match input len {len(xs)}.") + return sop.value + + def eval_selector_width(self, sop: SelectorWidth, + xs: Sequence[Value]) -> Sequence[Value]: + selector_values = self.evaluate(sop.selector, xs) + return [sum(row) for row in selector_values] + + def eval_aggregate(self, sop: Aggregate, + xs: Sequence[Value]) -> Sequence[Value]: + selector_value = self.evaluate(sop.selector, xs) + values = self.evaluate(sop.sop, xs) + default = sop.default + + return [ + _mean(_get_selected(row, values), default) for row in selector_value + ] + + def eval_select(self, sel: Select, xs: Sequence[Value]) -> SelectorValue: + """Evaluates a Select on `xs`.""" + key_values = self.evaluate(sel.keys, xs) + query_values = self.evaluate(sel.queries, xs) + + key_len = len(key_values) + query_len = len(query_values) + out = np.zeros((query_len, key_len), dtype=bool).tolist() + for row, query in enumerate(query_values): + for col, key in enumerate(key_values): + out[row][col] = bool(sel.predicate(key, query)) + return out + + def eval_constant_selector(self, sel: ConstantSelector, + xs: Sequence[Value]) -> SelectorValue: + if sel.check_length and (len(xs) != len(sel.value)): + raise ValueError( + f"Constant len {len(xs)} doesn't match input len {len(sel.value)}.") + return sel.value + + def eval_selector_and(self, sel: SelectorAnd, + xs: Sequence[Value]) -> SelectorValue: + fst_values = self.evaluate(sel.fst, xs) + snd_values = self.evaluate(sel.snd, xs) + return np.logical_and(np.array(fst_values), np.array(snd_values)).tolist() + + def eval_selector_or(self, sel: SelectorOr, + xs: Sequence[Value]) -> SelectorValue: + fst_values = self.evaluate(sel.fst, xs) + snd_values = self.evaluate(sel.snd, xs) + return np.logical_or(np.array(fst_values), np.array(snd_values)).tolist() + + def eval_selector_not(self, sel: SelectorNot, + xs: Sequence[Value]) -> SelectorValue: + values = self.evaluate(sel.inner, xs) + return np.logical_not(np.array(values)).tolist() + + +def _get_selected( + selector_row: list[bool], + values: Sequence[VT], +) -> Sequence[VT]: + """Helper for aggregate. [T T F], [a b c] -> [a b].""" + return [v for s, v in zip(selector_row, values) if s] + + +def _mean(xs: Sequence[VT], default: VT) -> VT: + """Takes the mean for numbers and concats for strings.""" + if not xs: + return default + exemplar = xs[0] + if isinstance(exemplar, (int, bool)): + return sum(xs) / len(xs) + elif len(xs) == 1: + return exemplar + else: + raise ValueError(f"Unsupported type for aggregation: {xs}") + + +def _raise_not_implemented(expr: RASPExpr, xs: Sequence[Value]): + raise NotImplementedError(f"Evaluation of {expr} is not defined.") + + +evaluate = DefaultRASPEvaluator().evaluate diff --git a/rasp/rasp_test.py b/rasp/rasp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f73eaf2c1a665e547ad53c5f8631bc2d67951f7d --- /dev/null +++ b/rasp/rasp_test.py @@ -0,0 +1,580 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for rasp.rasp.""" + +import itertools + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from tracr.rasp import rasp + +# Note that the example text labels must match their default names. + +_SOP_PRIMITIVE_EXAMPLES = lambda: [ # pylint: disable=g-long-lambda + ("tokens", rasp.tokens), + ("length", rasp.length), + ("indices", rasp.indices), +] + +_NONPRIMITIVE_SOP_EXAMPLES = lambda: [ # pylint: disable=g-long-lambda + ("map", rasp.Map(lambda x: x, rasp.tokens)), + ( + "sequence_map", + rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens), + ), + ( + "linear_sequence_map", + rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, 0.1, 0.2), + ), + ( + "aggregate", + rasp.Aggregate( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), + rasp.tokens, + ), + ), + ( + "selector_width", + rasp.SelectorWidth( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)), + ), +] + +_SOP_EXAMPLES = lambda: _SOP_PRIMITIVE_EXAMPLES() + _NONPRIMITIVE_SOP_EXAMPLES() + +_SELECTOR_EXAMPLES = lambda: [ # pylint: disable=g-long-lambda + ("select", rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)), + ("selector_and", + rasp.SelectorAnd( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), + rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ), + )), + ("selector_or", + rasp.SelectorOr( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), + rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ), + )), + ("selector_not", + rasp.SelectorNot( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),)), +] + +_ALL_EXAMPLES = lambda: _SOP_EXAMPLES() + _SELECTOR_EXAMPLES() + + +class LabelTest(parameterized.TestCase): + + def test_primitive_labels(self): + self.assertEqual(rasp.tokens.label, "tokens") + self.assertEqual(rasp.indices.label, "indices") + self.assertEqual(rasp.length.label, "length") + + @parameterized.parameters(*_ALL_EXAMPLES()) + def test_default_names(self, default_name: str, expr: rasp.RASPExpr): + self.assertEqual(expr.name, default_name) + + +class SOpTest(parameterized.TestCase): + """Tests for S-Ops.""" + + @parameterized.parameters( + ("hello", ["h", "e", "l", "l", "o"]), + ("h", ["h"]), + (["h", "e", "l", "l", "o"], ["h", "e", "l", "l", "o"]), + (["h"], ["h"]), + ([1, 2], [1, 2]), + ([0.1, 0.2], [0.1, 0.2]), + ) + def test_tokens(self, input_sequence, expected): + self.assertEqual(rasp.tokens(input_sequence), expected) + + @parameterized.parameters( + ("hello", [0, 1, 2, 3, 4]), + ("h", [0]), + (["h", "e", "l", "l", "o"], [0, 1, 2, 3, 4]), + (["h"], [0]), + ([1, 2], [0, 1]), + ([0.1, 0.2], [0, 1]), + ) + def test_indices(self, input_sequence, expected): + self.assertEqual(rasp.indices(input_sequence), expected) + + @parameterized.parameters( + ("hello", [5, 5, 5, 5, 5]), + ("h", [1]), + (["h", "e", "l", "l", "o"], [5, 5, 5, 5, 5]), + (["h"], [1]), + ([1, 2], [2, 2]), + ([0.1, 0.2], [2, 2]), + ) + def test_length(self, input_sequence, expected): + self.assertEqual(rasp.length(input_sequence), expected) + + def test_prims_are_sops(self): + self.assertIsInstance(rasp.tokens, rasp.SOp) + self.assertIsInstance(rasp.indices, rasp.SOp) + self.assertIsInstance(rasp.length, rasp.SOp) + + def test_prims_are_raspexprs(self): + self.assertIsInstance(rasp.tokens, rasp.RASPExpr) + self.assertIsInstance(rasp.indices, rasp.RASPExpr) + self.assertIsInstance(rasp.length, rasp.RASPExpr) + + @parameterized.parameters( + (lambda x: x + "a", "hello", ["ha", "ea", "la", "la", "oa"]), + (lambda x: x + "t", "h", ["ht"]), + (lambda x: x + 1, [1, 2], [2, 3]), + (lambda x: x / 2, [0.1, 0.2], [0.05, 0.1]), + ) + def test_map(self, f, input_sequence, expected): + self.assertEqual(rasp.Map(f, rasp.tokens)(input_sequence), expected) + + def test_nested_elementwise_ops_results_in_only_one_map_object(self): + map_sop = ((rasp.tokens * 2) + 2) / 2 + self.assertEqual(map_sop.inner, rasp.tokens) + self.assertEqual(map_sop([1]), [2]) + + @parameterized.parameters( + (lambda x, y: x + y, "hello", ["hh", "ee", "ll", "ll", "oo"]), + (lambda x, y: x + y, "h", ["hh"]), + (lambda x, y: x + y, [1, 2], [2, 4]), + (lambda x, y: x * y, [1, 2], [1, 4]), + ) + def test_sequence_map(self, f, input_sequence, expected): + self.assertEqual( + rasp.SequenceMap(f, rasp.tokens, rasp.tokens)(input_sequence), expected) + + def test_sequence_map_with_same_inputs_logs_warning(self): + with self.assertLogs(level="WARNING"): + rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens) + + @parameterized.parameters( + (1, 1, [1, 2], [2, 4]), + (1, -1, [1, 2], [0, 0]), + (1, -2, [1, 2], [-1, -2]), + ) + def test_linear_sequence_map(self, fst_fac, snd_fac, input_sequence, + expected): + self.assertEqual( + rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, fst_fac, + snd_fac)(input_sequence), expected) + + @parameterized.parameters( + ([5, 5, 5, 5, 5], "hello", [5, 5, 5, 5, 5]), + (["e"], "h", ["e"]), + ([1, 2, 3, 4, 5], ["h", "e", "l", "l", "o"], [1, 2, 3, 4, 5]), + ([2, 2], [1, 2], [2, 2]), + ) + def test_constant(self, const, input_sequence, expected): + self.assertEqual(rasp.ConstantSOp(const)(input_sequence), expected) + + def test_constant_complains_if_sizes_dont_match(self): + with self.assertRaisesRegex( + ValueError, + r"^.*Constant len .* doesn't match input len .*$",): + rasp.ConstantSOp([1, 2, 3])("longer string") + + def test_can_turn_off_constant_complaints(self): + rasp.ConstantSOp([1, 2, 3], check_length=False)("longer string") + + def test_numeric_dunders(self): + # We don't check all the cases here -- only a few representative ones. + self.assertEqual( + (rasp.tokens > 1)([0, 1, 2]), + [0, 0, 1], + ) + self.assertEqual( + (1 < rasp.tokens)([0, 1, 2]), + [0, 0, 1], + ) + self.assertEqual( + (rasp.tokens < 1)([0, 1, 2]), + [1, 0, 0], + ) + self.assertEqual( + (1 > rasp.tokens)([0, 1, 2]), + [1, 0, 0], + ) + self.assertEqual( + (rasp.tokens == 1)([0, 1, 2]), + [0, 1, 0], + ) + self.assertEqual( + (rasp.tokens + 1)([0, 1, 2]), + [1, 2, 3], + ) + self.assertEqual( + (1 + rasp.tokens)([0, 1, 2]), + [1, 2, 3], + ) + + def test_dunders_with_sop(self): + self.assertEqual( + (rasp.tokens + rasp.indices)([0, 1, 2]), + [0, 2, 4], + ) + self.assertEqual( + (rasp.length - 1 - rasp.indices)([0, 1, 2]), + [2, 1, 0], + ) + self.assertEqual( + (rasp.length * rasp.length)([0, 1, 2]), + [9, 9, 9], + ) + + def test_logical_dunders(self): + self.assertEqual( + (rasp.tokens & True)([True, False]), + [True, False], + ) + self.assertEqual( + (rasp.tokens & False)([True, False]), + [False, False], + ) + self.assertEqual( + (rasp.tokens | True)([True, False]), + [True, True], + ) + self.assertEqual( + (rasp.tokens | False)([True, False]), + [True, False], + ) + self.assertEqual( + (True & rasp.tokens)([True, False]), + [True, False], + ) + self.assertEqual( + (False & rasp.tokens)([True, False]), + [False, False], + ) + self.assertEqual( + (True | rasp.tokens)([True, False]), + [True, True], + ) + self.assertEqual( + (False | rasp.tokens)([True, False]), + [True, False], + ) + + self.assertEqual( + (~rasp.tokens)([True, False]), + [False, True], + ) + + self.assertEqual( + (rasp.ConstantSOp([True, True, False, False]) + & rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]), + [True, False, False, False], + ) + + self.assertEqual( + (rasp.ConstantSOp([True, True, False, False]) + | rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]), + [True, True, True, False], + ) + + +class EncodingTest(parameterized.TestCase): + """Tests for SOp encodings.""" + + @parameterized.named_parameters(*_SOP_EXAMPLES()) + def test_all_sops_are_categorical_by_default(self, sop: rasp.SOp): + self.assertTrue(rasp.is_categorical(sop)) + + @parameterized.named_parameters(*_SOP_EXAMPLES()) + def test_is_numerical(self, sop: rasp.SOp): + self.assertTrue(rasp.is_numerical(rasp.numerical(sop))) + self.assertFalse(rasp.is_numerical(rasp.categorical(sop))) + + @parameterized.named_parameters(*_SOP_EXAMPLES()) + def test_is_categorical(self, sop: rasp.SOp): + self.assertTrue(rasp.is_categorical(rasp.categorical(sop))) + self.assertFalse(rasp.is_categorical(rasp.numerical(sop))) + + @parameterized.named_parameters(*_SOP_EXAMPLES()) + def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp): + num_sop = rasp.numerical(sop) + cat_num_sop = rasp.categorical(num_sop) + self.assertTrue(rasp.is_numerical(num_sop)) + self.assertTrue(rasp.is_categorical(cat_num_sop)) + + +class SelectorTest(parameterized.TestCase): + """Tests for Selectors.""" + + def test_select_eq_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ) + self.assertEqual( + selector("hey"), [ + [True, False, False], + [False, True, False], + [False, False, True], + ]) + + def test_select_lt_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LT) + self.assertEqual(selector([0, 1]), [ + [False, False], + [True, False], + ]) + + def test_select_leq_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LEQ) + self.assertEqual(selector([0, 1]), [ + [True, False], + [True, True], + ]) + + def test_select_gt_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GT) + self.assertEqual(selector([0, 1]), [ + [False, True], + [False, False], + ]) + + def test_select_geq_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GEQ) + self.assertEqual(selector([0, 1]), [ + [True, True], + [False, True], + ]) + + def test_select_neq_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.NEQ) + self.assertEqual(selector([0, 1]), [ + [False, True], + [True, False], + ]) + + def test_select_true_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE) + self.assertEqual(selector([0, 1]), [ + [True, True], + [True, True], + ]) + + def test_select_false_has_correct_value(self): + selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.FALSE) + self.assertEqual(selector([0, 1]), [ + [False, False], + [False, False], + ]) + + def test_selector_and_gets_simplified_when_keys_and_queries_match(self): + selector = rasp.selector_and( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ), + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ), + ) + self.assertIsInstance(selector, rasp.Select) + self.assertIs(selector.keys, rasp.tokens) + self.assertIs(selector.queries, rasp.indices) + + def test_selector_and_doesnt_get_simplified_when_keys_queries_different(self): + selector = rasp.selector_and( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ), + rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ), + ) + self.assertIsInstance(selector, rasp.SelectorAnd) + + def test_selector_and_gets_simplified_when_keys_are_full(self): + selector = rasp.selector_and( + rasp.Select(rasp.Full(1), rasp.indices, rasp.Comparison.GEQ), + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ), + ) + self.assertIsInstance(selector, rasp.Select) + self.assertIs(selector.keys, rasp.tokens) + self.assertIs(selector.queries, rasp.indices) + + def test_selector_and_gets_simplified_when_queries_are_full(self): + selector = rasp.selector_and( + rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ), + rasp.Select(rasp.tokens, rasp.Full(1), rasp.Comparison.LEQ), + ) + self.assertIsInstance(selector, rasp.Select) + self.assertIs(selector.keys, rasp.tokens) + self.assertIs(selector.queries, rasp.indices) + + @parameterized.parameters( + itertools.product( + (rasp.tokens, rasp.indices, rasp.Full(1)), + (rasp.tokens, rasp.indices, rasp.Full(1)), + list(rasp.Comparison), + (rasp.tokens, rasp.indices, rasp.Full(1)), + (rasp.tokens, rasp.indices, rasp.Full(1)), + list(rasp.Comparison), + )) + def test_simplified_selector_and_works_the_same_way_as_not( + self, fst_k, fst_q, fst_p, snd_k, snd_q, snd_p): + fst = rasp.Select(fst_k, fst_q, fst_p) + snd = rasp.Select(snd_k, snd_q, snd_p) + + simplified = rasp.selector_and(fst, snd)([0, 1, 2, 3]) + not_simplified = rasp.selector_and(fst, snd, simplify=False)([0, 1, 2, 3]) + + np.testing.assert_array_equal( + np.array(simplified), + np.array(not_simplified), + ) + + def test_select_is_selector(self): + self.assertIsInstance( + rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), + rasp.Selector, + ) + + def test_select_is_raspexpr(self): + self.assertIsInstance( + rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), + rasp.RASPExpr, + ) + + def test_constant_selector(self): + self.assertEqual( + rasp.ConstantSelector([[True, True], [False, False]])([1, 2]), + [[True, True], [False, False]], + ) + + +class CopyTest(parameterized.TestCase): + + @parameterized.named_parameters(*_ALL_EXAMPLES()) + def test_copy_preserves_name(self, expr: rasp.RASPExpr): + expr = expr.named("foo") + self.assertEqual(expr.copy().name, expr.name) + + @parameterized.named_parameters(*_ALL_EXAMPLES()) + def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr): + expr = expr.named("foo") + expr.copy().named("bar") + self.assertEqual(expr.name, "foo") + + @parameterized.named_parameters(*_ALL_EXAMPLES()) + def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr): + expr = expr.named("foo") + copy = expr.copy() + expr.named("bar") + self.assertEqual(copy.name, "foo") + + @parameterized.named_parameters(*_ALL_EXAMPLES()) + def test_copy_changes_id(self, expr: rasp.RASPExpr): + self.assertNotEqual(expr.copy().unique_id, expr.unique_id) + + @parameterized.named_parameters(*_ALL_EXAMPLES()) + def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr): + copy_child_ids = [c.unique_id for c in expr.copy().children] + child_ids = [c.unique_id for c in expr.children] + for child_id, copy_child_id in zip(child_ids, copy_child_ids): + self.assertEqual(child_id, copy_child_id) + + +class AggregateTest(parameterized.TestCase): + """Tests for Aggregate.""" + + @parameterized.parameters( + dict( + selector=rasp.ConstantSelector([ + [True, False], + [False, True], + ]), + sop=rasp.ConstantSOp(["h", "e"]), + default=None, + expected_value=["h", "e"], + ), + dict( + selector=rasp.ConstantSelector([ + [False, True], + [False, False], + ]), + sop=rasp.ConstantSOp(["h", "e"]), + default=None, + expected_value=["e", None], + ), + dict( + selector=rasp.ConstantSelector([ + [True, False], + [False, False], + ]), + sop=rasp.ConstantSOp(["h", "e"]), + default=None, + expected_value=["h", None], + ), + dict( + selector=rasp.ConstantSelector([ + [True, True], + [False, True], + ]), + sop=rasp.ConstantSOp([0, 1]), + default=0, + expected_value=[0.5, 1], + ), + dict( + selector=rasp.ConstantSelector([ + [False, False], + [True, True], + ]), + sop=rasp.ConstantSOp([0, 1]), + default=0, + expected_value=[0, 0.5], + ), + dict( + selector=rasp.ConstantSelector([ + [False, False], + [True, True], + ]), + sop=rasp.ConstantSOp([0, 1]), + default=None, + expected_value=[None, 0.5], + ), + ) + def test_aggregate_on_size_2_inputs(self, selector, sop, default, + expected_value): + # The 0, 0 input is ignored as it's overridden by the constant SOps. + self.assertEqual( + rasp.Aggregate(selector, sop, default)([0, 0]), + expected_value, + ) + + +class RaspProgramTest(parameterized.TestCase): + """Each testcase implements and tests a RASP program.""" + + def test_has_prev(self): + + def has_prev(seq: rasp.SOp) -> rasp.SOp: + prev_copy = rasp.SelectorAnd( + rasp.Select(seq, seq, rasp.Comparison.EQ), + rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LT), + ) + return rasp.Aggregate(prev_copy, rasp.Full(1), default=0) > 0 + + self.assertEqual( + has_prev(rasp.tokens)("hello"), + [0, 0, 0, 1, 0], + ) + + self.assertEqual( + has_prev(rasp.tokens)("helllo"), + [0, 0, 0, 1, 1, 0], + ) + + self.assertEqual( + has_prev(rasp.tokens)([0, 2, 3, 2, 1, 0, 2]), + [0, 0, 0, 1, 0, 1, 1], + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/transformer/attention.py b/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ab415d3d66c5c1dd1f45a9af023c1f1af2feaf5c --- /dev/null +++ b/transformer/attention.py @@ -0,0 +1,160 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Instrumented attention layer (forked from the Haiku library implementation). +""" + +from typing import Optional +import warnings + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +@chex.dataclass +class AttentionOutput: + out: jax.Array # [..., T', D'] + logits: jax.Array # [..., H, T', T] + + +class MultiHeadAttention(hk.Module): + """Multi-headed attention (MHA) module. + + This module is intended for attending over sequences of vectors. + + Rough sketch: + - Compute keys (K), queries (Q), and values (V) as projections of inputs. + - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)). + - Output is another projection of WV^T. + + For more detail, see the original Transformer paper: + "Attention is all you need" https://arxiv.org/abs/1706.03762. + + Glossary of shapes: + - T: Sequence length. + - D: Vector (embedding) size. + - H: Number of attention heads. + """ + + def __init__( + self, + num_heads: int, + key_size: int, + # TODO(b/240019186): Remove `w_init_scale`. + w_init_scale: Optional[float] = None, + *, + w_init: Optional[hk.initializers.Initializer] = None, + value_size: Optional[int] = None, + model_size: Optional[int] = None, + name: Optional[str] = None, + ): + """Initialises the module. + + Args: + num_heads: Number of independent attention heads (H). + key_size: The size of keys (K) and queries used for attention. + w_init_scale: DEPRECATED. Please use w_init instead. + w_init: Initialiser for weights in the linear map. + value_size: Optional size of the value projection (V). If None, defaults + to the key size (K). + model_size: Optional size of the output embedding (D'). If None, defaults + to the key size multiplied by the number of heads (K * H). + name: Optional name for this module. + """ + super().__init__(name=name) + self.num_heads = num_heads + self.key_size = key_size + self.value_size = value_size or key_size + self.model_size = model_size or key_size * num_heads + + # Backwards-compatibility for w_init_scale. + if w_init_scale is not None: + warnings.warn( + "w_init_scale is deprecated; please pass an explicit weight " + "initialiser instead.", DeprecationWarning) + if w_init and w_init_scale: + raise ValueError("Please provide only `w_init`, not `w_init_scale`.") + if w_init is None and w_init_scale is None: + raise ValueError("Please provide a weight initializer: `w_init`.") + if w_init is None: + w_init = hk.initializers.VarianceScaling(w_init_scale) + self.w_init = w_init + + def __call__( + self, + query: jnp.ndarray, + key: jnp.ndarray, + value: jnp.ndarray, + mask: Optional[jnp.ndarray] = None, + ) -> AttentionOutput: + """Computes (optionally masked) MHA with queries, keys & values. + + This module broadcasts over zero or more 'batch-like' leading dimensions. + + Args: + query: Embeddings sequence used to compute queries; shape [..., T', D_q]. + key: Embeddings sequence used to compute keys; shape [..., T, D_k]. + value: Embeddings sequence used to compute values; shape [..., T, D_v]. + mask: Optional mask applied to attention weights; shape [..., H=1, T', T]. + + Returns: + A new sequence of embeddings, consisting of a projection of the + attention-weighted value projections; shape [..., T', D']. + """ + + # In shape hints below, we suppress the leading dims [...] for brevity. + # Hence e.g. [A, B] should be read in every case as [..., A, B]. + *leading_dims, sequence_length, _ = query.shape + projection = self._linear_projection + + # Compute key/query/values (overload K/Q/V to denote the respective sizes). + query_heads = projection(query, self.key_size, "query") # [T', H, Q=K] + key_heads = projection(key, self.key_size, "key") # [T, H, K] + value_heads = projection(value, self.value_size, "value") # [T, H, V] + + # Compute attention weights. + attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads) + attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype) + if mask is not None: + if mask.ndim != attn_logits.ndim: + raise ValueError( + f"Mask dimensionality {mask.ndim} must match logits dimensionality " + f"{attn_logits.ndim}.") + attn_logits = jnp.where(mask, attn_logits, -1e30) + attn_weights = jax.nn.softmax(attn_logits) # [H, T', T] + + # Weight the values by the attention and flatten the head vectors. + attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads) + attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1)) # [T', H*V] + + # Apply another projection to get the final embeddings. + final_projection = hk.Linear(self.model_size, w_init=self.w_init) + return AttentionOutput( + out=final_projection(attn), + logits=attn_logits, + ) + + @hk.transparent + def _linear_projection( + self, + x: jnp.ndarray, + head_size: int, + name: Optional[str] = None, + ) -> jnp.ndarray: + y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, name=name)(x) + *leading_dims, _ = x.shape + return y.reshape((*leading_dims, self.num_heads, head_size)) diff --git a/transformer/compressed_model.py b/transformer/compressed_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b612f80728c6aae0c887e6a695e746a092014599 --- /dev/null +++ b/transformer/compressed_model.py @@ -0,0 +1,185 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Modified transformer to learn a linear compression of the residual stream. + +CompressedTransformer adds three arguments compared to Transformer: +- embedding_size: the size of the compressed residual stream. +- unembed_at_every_layer: whether to apply the unembedding before applying + attention and MLP layers +- return_activations: whether to return all model activations rather than just + the outputs +""" + +import collections +import dataclasses +from typing import Optional + +import haiku as hk +import jax +import numpy as np + +from tracr.transformer import attention +from tracr.transformer import model + + +@dataclasses.dataclass +class CompressedTransformer(hk.Module): + """A transformer stack with linearly compressed residual stream.""" + + config: model.TransformerConfig + name: Optional[str] = None + + def __call__( + self, + embeddings: jax.Array, # [B, T, D] + mask: jax.Array, # [B, T] + *, + use_dropout: bool = True, + embedding_size: Optional[int] = None, + unembed_at_every_layer: bool = False, + ) -> model.TransformerOutput: # [B, T, D] + """Transforms input embedding sequences to output embedding sequences. + + Args: + embeddings: Input embeddings to pass through the model. + mask: Boolean mask to restrict the inputs the model uses. + use_dropout: Turns dropout on/off. + embedding_size: Dimension to compress the residual stream to. + unembed_at_every_layer: Whether to unembed the residual stream when + reading the input for every layer (keeping the layer input sizes) or to + only unembed before the model output (compressing the layer inputs). + + Returns: + The outputs of the forward pass through the transformer. + """ + + def layer_norm(x: jax.Array) -> jax.Array: + """Applies a unique LayerNorm to x with default settings.""" + if self.config.layer_norm: + return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) + return x + + initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers) + dropout_rate = self.config.dropout_rate if use_dropout else 0. + _, seq_len, model_size = embeddings.shape + + # To compress the model, we multiply with a matrix W when reading from + # the residual stream, and with W^T when writing to the residual stream. + if embedding_size is not None: + # [to_size, from_size] + w_emb = hk.get_parameter( + "w_emb", (embedding_size, model_size), + init=hk.initializers.RandomNormal()) + + write_to_residual = lambda x: x @ w_emb.T + read_from_residual = lambda x: x @ w_emb + + if not unembed_at_every_layer: + model_size = embedding_size + else: + write_to_residual = lambda x: x + read_from_residual = lambda x: x + + # Compute causal mask for autoregressive sequence modelling. + mask = mask[:, None, None, :] # [B, H=1, T'=1, T] + mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T] + + if self.config.causal: + causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T] + causal_mask = np.tril(causal_mask) + mask = mask * causal_mask # [B, H=1, T, T] + + # Set up activation collection. + collected = collections.defaultdict(list) + + def collect(**kwargs): + for k, v in kwargs.items(): + collected[k].append(v) + + residual = write_to_residual(embeddings) + + for layer in range(self.config.num_layers): + with hk.experimental.name_scope(f"layer_{layer}"): + # First the attention block. + attn_block = attention.MultiHeadAttention( + num_heads=self.config.num_heads, + key_size=self.config.key_size, + model_size=model_size, + w_init=initializer, + name="attn") + + attn_in = residual + if unembed_at_every_layer: + attn_in = read_from_residual(attn_in) + attn_in = layer_norm(attn_in) + attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask) + attn_out, attn_logits = attn_out.out, attn_out.logits + if dropout_rate > 0: + attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out) + + if unembed_at_every_layer: + collect(layer_outputs=attn_out, attn_logits=attn_logits) + else: + collect( + layer_outputs=read_from_residual(attn_out), + attn_logits=attn_logits, + ) + + if unembed_at_every_layer: + attn_out = write_to_residual(attn_out) + residual = residual + attn_out + + collect(residuals=residual) + + # Then the dense block. + with hk.experimental.name_scope("mlp"): + dense_block = hk.Sequential([ + hk.Linear( + self.config.mlp_hidden_size, + w_init=initializer, + name="linear_1"), + self.config.activation_function, + hk.Linear(model_size, w_init=initializer, name="linear_2"), + ]) + + dense_in = residual + if unembed_at_every_layer: + dense_in = read_from_residual(dense_in) + dense_in = layer_norm(dense_in) + dense_out = dense_block(dense_in) + if dropout_rate > 0: + dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out) + + if unembed_at_every_layer: + collect(layer_outputs=dense_out) + else: + collect(layer_outputs=read_from_residual(dense_out)) + + if unembed_at_every_layer: + dense_out = write_to_residual(dense_out) + residual = residual + dense_out + + collect(residuals=residual) + + output = read_from_residual(residual) + output = layer_norm(output) + + return model.TransformerOutput( + layer_outputs=collected["layer_outputs"], + residuals=collected["residuals"], + attn_logits=collected["attn_logits"], + output=output, + input_embeddings=embeddings, + ) diff --git a/transformer/compressed_model_test.py b/transformer/compressed_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e8648815363d1c7a235b7ec1fbb0fc84f1041304 --- /dev/null +++ b/transformer/compressed_model_test.py @@ -0,0 +1,318 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for transformer.model.""" + +from absl.testing import absltest +from absl.testing import parameterized +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from tracr.transformer import compressed_model +from tracr.transformer import model + + +class CompressedTransformerTest(parameterized.TestCase): + + def _check_layer_naming(self, params): + # Modules should be named for example + # For MLPs: "compressed_transformer/layer_{i}/mlp/linear_1" + # For Attention: "compressed_transformer/layer_{i}/attn/key" + # For Layer Norm: "compressed_transformer/layer_{i}/layer_norm" + for key in params.keys(): + levels = key.split("/") + self.assertEqual(levels[0], "compressed_transformer") + if len(levels) == 1: + self.assertEqual(list(params[key].keys()), ["w_emb"]) + continue + if levels[1].startswith("layer_norm"): + continue # output layer norm + self.assertStartsWith(levels[1], "layer") + if levels[2] == "mlp": + self.assertIn(levels[3], {"linear_1", "linear_2"}) + elif levels[2] == "attn": + self.assertIn(levels[3], {"key", "query", "value", "linear"}) + else: + self.assertStartsWith(levels[2], "layer_norm") + + def _zero_mlps(self, params): + for module in params: + if "mlp" in module: + for param in params[module]: + params[module][param] = jnp.zeros_like(params[module][param]) + return params + + @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False)) + def test_layer_norm(self, layer_norm): + # input = [1, 1, 1, 1] + # If layer norm is used, this should give all-0 output for a freshly + # initialized model because LN will subtract the mean after each layer. + # Else we expect non-zero outputs. + + @hk.transform + def forward(emb, mask): + transformer = compressed_model.CompressedTransformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + layer_norm=layer_norm)) + return transformer(emb, mask).output + + seq_len = 4 + emb = jnp.ones((1, seq_len, 1)) + mask = jnp.ones((1, seq_len)) + rng = hk.PRNGSequence(1) + params = forward.init(next(rng), emb, mask) + out = forward.apply(params, next(rng), emb, mask) + + self._check_layer_naming(params) + if layer_norm: + np.testing.assert_allclose(out, 0) + else: + self.assertFalse(np.allclose(out, 0)) + + @parameterized.parameters(dict(causal=True), dict(causal=False)) + def test_causal_attention(self, causal): + # input = [0, random, random, random] + # mask = [1, 0, 1, 1] + # For causal attention the second token can only attend to the first one, so + # it should be the same. For non-causal attention all tokens should change. + + @hk.transform + def forward(emb, mask): + transformer = compressed_model.CompressedTransformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + layer_norm=False, + causal=causal)) + return transformer(emb, mask).output + + seq_len = 4 + emb = np.random.random((1, seq_len, 1)) + emb[:, 0, :] = 0 + mask = np.array([[1, 0, 1, 1]]) + emb, mask = jnp.array(emb), jnp.array(mask) + + rng = hk.PRNGSequence(1) + params = forward.init(next(rng), emb, mask) + params = self._zero_mlps(params) + out = forward.apply(params, next(rng), emb, mask) + + self._check_layer_naming(params) + if causal: + self.assertEqual(0, out[0, 0, 0]) + self.assertEqual(emb[0, 1, 0], out[0, 1, 0]) + else: + self.assertNotEqual(0, out[0, 0, 0]) + self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0]) + self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0]) + self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0]) + + def test_setting_activation_function_to_zero(self): + # An activation function that always returns zeros should result in the + # same model output as setting all MLP weights to zero. + + @hk.transform + def forward_zero(emb, mask): + transformer = compressed_model.CompressedTransformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False, + activation_function=jnp.zeros_like)) + return transformer(emb, mask).output + + @hk.transform + def forward(emb, mask): + transformer = compressed_model.CompressedTransformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False, + activation_function=jax.nn.gelu)) + return transformer(emb, mask).output + + seq_len = 4 + emb = np.random.random((1, seq_len, 1)) + mask = np.ones((1, seq_len)) + emb, mask = jnp.array(emb), jnp.array(mask) + + rng = hk.PRNGSequence(1) + params = forward.init(next(rng), emb, mask) + params_no_mlps = self._zero_mlps(params) + + out_zero_activation = forward_zero.apply(params, next(rng), emb, mask) + out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask) + + self._check_layer_naming(params) + np.testing.assert_allclose(out_zero_activation, out_no_mlps) + self.assertFalse(np.allclose(out_zero_activation, 0)) + + def test_not_setting_embedding_size_produces_same_output_as_default_model( + self): + config = model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False) + + @hk.without_apply_rng + @hk.transform + def forward_model(emb, mask): + return model.Transformer(config)(emb, mask).output + + @hk.without_apply_rng + @hk.transform + def forward_superposition(emb, mask): + return compressed_model.CompressedTransformer(config)(emb, mask).output + + seq_len = 4 + emb = np.random.random((1, seq_len, 1)) + mask = np.ones((1, seq_len)) + emb, mask = jnp.array(emb), jnp.array(mask) + + rng = hk.PRNGSequence(1) + params = forward_model.init(next(rng), emb, mask) + params_superposition = { + k.replace("transformer", "compressed_transformer"): v + for k, v in params.items() + } + + out_model = forward_model.apply(params, emb, mask) + out_superposition = forward_superposition.apply(params_superposition, emb, + mask) + + self._check_layer_naming(params_superposition) + np.testing.assert_allclose(out_model, out_superposition) + + @parameterized.parameters( + dict(embedding_size=2, unembed_at_every_layer=True), + dict(embedding_size=2, unembed_at_every_layer=False), + dict(embedding_size=6, unembed_at_every_layer=True), + dict(embedding_size=6, unembed_at_every_layer=False)) + def test_embbeding_size_produces_correct_shape_of_residuals_and_layer_outputs( + self, embedding_size, unembed_at_every_layer): + + @hk.transform + def forward(emb, mask): + transformer = compressed_model.CompressedTransformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False)) + return transformer( + emb, + mask, + embedding_size=embedding_size, + unembed_at_every_layer=unembed_at_every_layer, + ) + + seq_len = 4 + model_size = 16 + + emb = np.random.random((1, seq_len, model_size)) + mask = np.ones((1, seq_len)) + emb, mask = jnp.array(emb), jnp.array(mask) + + rng = hk.PRNGSequence(1) + params = forward.init(next(rng), emb, mask) + activations = forward.apply(params, next(rng), emb, mask) + + self._check_layer_naming(params) + + for residual in activations.residuals: + self.assertEqual(residual.shape, (1, seq_len, embedding_size)) + + for layer_output in activations.layer_outputs: + self.assertEqual(layer_output.shape, (1, seq_len, model_size)) + + @parameterized.parameters( + dict(model_size=2, unembed_at_every_layer=True), + dict(model_size=2, unembed_at_every_layer=False), + dict(model_size=6, unembed_at_every_layer=True), + dict(model_size=6, unembed_at_every_layer=False)) + def test_identity_embedding_produces_same_output_as_standard_model( + self, model_size, unembed_at_every_layer): + + config = model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False) + + @hk.without_apply_rng + @hk.transform + def forward_model(emb, mask): + return model.Transformer(config)(emb, mask).output + + @hk.without_apply_rng + @hk.transform + def forward_superposition(emb, mask): + return compressed_model.CompressedTransformer(config)( + emb, + mask, + embedding_size=model_size, + unembed_at_every_layer=unembed_at_every_layer).output + + seq_len = 4 + emb = np.random.random((1, seq_len, model_size)) + mask = np.ones((1, seq_len)) + emb, mask = jnp.array(emb), jnp.array(mask) + + rng = hk.PRNGSequence(1) + params = forward_model.init(next(rng), emb, mask) + params_superposition = { + k.replace("transformer", "compressed_transformer"): v + for k, v in params.items() + } + params_superposition["compressed_transformer"] = { + "w_emb": jnp.identity(model_size) + } + + out_model = forward_model.apply(params, emb, mask) + out_superposition = forward_superposition.apply(params_superposition, emb, + mask) + + self._check_layer_naming(params_superposition) + np.testing.assert_allclose(out_model, out_superposition) + + +if __name__ == "__main__": + absltest.main() diff --git a/transformer/encoder.py b/transformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..98f631bff9a5e79fb2769fc1297b5ed11a82b201 --- /dev/null +++ b/transformer/encoder.py @@ -0,0 +1,135 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Basic encoder for inputs with a fixed vocabulary.""" + +import abc +from typing import Any, Sequence, Optional + +from tracr.craft import bases + + +class Encoder(abc.ABC): + """Encodes a list of tokens into a list of inputs for a transformer model. + + The abstract class does not make assumptions on the input and output types, + and we have different encoders for different input types. + """ + + @abc.abstractmethod + def encode(self, inputs: list[Any]) -> list[Any]: + return list() + + @abc.abstractmethod + def decode(self, encodings: list[Any]) -> list[Any]: + return list() + + @property + def pad_token(self) -> Optional[str]: + return None + + @property + def bos_token(self) -> Optional[str]: + return None + + @property + def pad_encoding(self) -> Optional[int]: + return None + + @property + def bos_encoding(self) -> Optional[int]: + return None + + +class NumericalEncoder(Encoder): + """Encodes numerical variables (simply using the identity mapping).""" + + def encode(self, inputs: list[float]) -> list[float]: + return inputs + + def decode(self, encodings: list[float]) -> list[float]: + return encodings + + +class CategoricalEncoder(Encoder): + """Encodes categorical variables with a fixed vocabulary.""" + + def __init__( + self, + basis: Sequence[bases.BasisDirection], + enforce_bos: bool = False, + bos_token: Optional[str] = None, + pad_token: Optional[str] = None, + max_seq_len: Optional[int] = None, + ): + """Initialises. If enforce_bos is set, ensures inputs start with it.""" + if enforce_bos and not bos_token: + raise ValueError("BOS token must be specified if enforcing BOS.") + + self.encoding_map = {} + for i, direction in enumerate(basis): + val = direction.value + self.encoding_map[val] = i + + if bos_token and bos_token not in self.encoding_map: + raise ValueError("BOS token missing in encoding.") + + if pad_token and pad_token not in self.encoding_map: + raise ValueError("PAD token missing in encoding.") + + self.enforce_bos = enforce_bos + self._bos_token = bos_token + self._pad_token = pad_token + self._max_seq_len = max_seq_len + + def encode(self, inputs: list[bases.Value]) -> list[int]: + if self.enforce_bos and inputs[0] != self.bos_token: + raise ValueError("First input token must be BOS token. " + f"Should be '{self.bos_token}', but was '{inputs[0]}'.") + if missing := set(inputs) - set(self.encoding_map.keys()): + raise ValueError(f"Inputs {missing} not found in encoding ", + self.encoding_map.keys()) + if self._max_seq_len is not None and len(inputs) > self._max_seq_len: + raise ValueError(f"{inputs=} are longer than the maximum " + f"sequence length {self._max_seq_len}") + + return [self.encoding_map[x] for x in inputs] + + def decode(self, encodings: list[int]) -> list[bases.Value]: + """Recover the tokens that corresponds to `ids`. Inverse of __call__.""" + decoding_map = {val: key for key, val in self.encoding_map.items()} + if missing := set(encodings) - set(decoding_map.keys()): + raise ValueError(f"Inputs {missing} not found in decoding map ", + decoding_map.keys()) + return [decoding_map[x] for x in encodings] + + @property + def vocab_size(self) -> int: + return len(self.encoding_map) + + @property + def bos_token(self) -> Optional[str]: + return self._bos_token + + @property + def pad_token(self) -> Optional[str]: + return self._pad_token + + @property + def bos_encoding(self) -> Optional[int]: + return None if self.bos_token is None else self.encoding_map[self.bos_token] + + @property + def pad_encoding(self) -> Optional[int]: + return None if self.pad_token is None else self.encoding_map[self.pad_token] diff --git a/transformer/encoder_test.py b/transformer/encoder_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8538f53ff2154f28e3acc6054ca1c17fd346246a --- /dev/null +++ b/transformer/encoder_test.py @@ -0,0 +1,123 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for transformer.encoder.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tracr.craft import bases +from tracr.transformer import encoder + +_BOS_TOKEN = "bos_encoder_test" +_PAD_TOKEN = "pad_encoder_test" + + +class CategoricalEncoderTest(parameterized.TestCase): + + def test_encode_raises_value_error_if_input_doesnt_start_with_bos(self): + vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN}) + basic_encoder = encoder.CategoricalEncoder( + vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) + with self.assertRaisesRegex(ValueError, + r"^.*First input token must be BOS token.*$"): + basic_encoder.encode([1, 1, 1]) + + def test_encode_raises_value_error_if_input_not_in_vocab(self): + vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN}) + basic_encoder = encoder.CategoricalEncoder( + vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) + with self.assertRaisesRegex(ValueError, + r"^.*Inputs .* not found in encoding.*$"): + basic_encoder.encode([_BOS_TOKEN, 1, 2, 3, 4]) + + def test_decode_raises_value_error_if_id_outside_of_vocab_size(self): + vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, _BOS_TOKEN}) + basic_encoder = encoder.CategoricalEncoder( + vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) + with self.assertRaisesRegex(ValueError, + r"^.*Inputs .* not found in decoding map.*$"): + basic_encoder.decode([0, 1, 2, 3]) + + def test_encoder_raises_value_error_if_bos_not_in_basis(self): + vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3}) + with self.assertRaisesRegex(ValueError, + r"^.*BOS token missing in encoding.*$"): + unused_basic_encoder = encoder.CategoricalEncoder( + vs.basis, bos_token=_BOS_TOKEN) + + def test_encoder_raises_value_error_if_pad_not_in_basis(self): + vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3}) + with self.assertRaisesRegex(ValueError, + r"^.*PAD token missing in encoding.*$"): + unused_basic_encoder = encoder.CategoricalEncoder( + vs.basis, pad_token=_PAD_TOKEN) + + def test_encoder_encodes_bos_and_pad_tokens_as_expected(self): + vs = bases.VectorSpaceWithBasis.from_values( + "input", {1, 2, 3, _BOS_TOKEN, _PAD_TOKEN}) + basic_encoder = encoder.CategoricalEncoder( + vs.basis, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) + self.assertEqual( + basic_encoder.encode([_BOS_TOKEN, _PAD_TOKEN]), + [basic_encoder.bos_encoding, basic_encoder.pad_encoding]) + + @parameterized.parameters([ + dict( + vocab={1, 2, 3, _BOS_TOKEN}, # lexicographic order + inputs=[_BOS_TOKEN, 3, 2, 1], + expected=[3, 2, 1, 0]), + dict( + vocab={"a", "b", _BOS_TOKEN, "c"}, # lexicographic order + inputs=[_BOS_TOKEN, "b", "b", "c"], + expected=[2, 1, 1, 3]), + ]) + def test_tokens_are_encoded_in_lexicographic_order(self, vocab, inputs, + expected): + # Expect encodings to be assigned to ids according to a lexicographic + # ordering of the vocab + vs = bases.VectorSpaceWithBasis.from_values("input", vocab) + basic_encoder = encoder.CategoricalEncoder( + vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) + encodings = basic_encoder.encode(inputs) + self.assertEqual(encodings, expected) + + @parameterized.parameters([ + dict(vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, expected=5), + dict(vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b"}, expected=4), + ]) + def test_vocab_size_has_expected_value(self, vocab, expected): + vs = bases.VectorSpaceWithBasis.from_values("input", vocab) + basic_encoder = encoder.CategoricalEncoder( + vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) + self.assertEqual(basic_encoder.vocab_size, expected) + + @parameterized.parameters([ + dict( + vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, inputs=[_BOS_TOKEN, 3, 2, + 1]), + dict( + vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b", "c"}, + inputs=[_BOS_TOKEN, "b", "b", "c"]), + ]) + def test_decode_inverts_encode(self, vocab, inputs): + vs = bases.VectorSpaceWithBasis.from_values("input", vocab) + basic_encoder = encoder.CategoricalEncoder( + vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) + encodings = basic_encoder.encode(inputs) + recovered = basic_encoder.decode(encodings) + self.assertEqual(recovered, inputs) + + +if __name__ == "__main__": + absltest.main() diff --git a/transformer/model.py b/transformer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..79f2519bc07ab820a70d4c956ad3d7c30a36efbe --- /dev/null +++ b/transformer/model.py @@ -0,0 +1,199 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Didactic example of an autoregressive Transformer-based language model. + +Glossary of shapes: +- B: Batch size. +- T: Sequence length. +- D: Model embedding size. +- H: Number of attention heads. +- V: Vocabulary size. + +Forked from: haiku.examples.transformer.model +""" + +import collections +import dataclasses +from typing import Callable, Optional + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from tracr.transformer import attention + +# hk.Modules are not always callable: github.com/deepmind/dm-haiku/issues/52 +# Ideally, we'd want a type: +# CallableHaikuModule = Intersection[Callable[..., jax.Array], hk.Module] +# But Intersection does not exist (yet): github.com/python/typing/issues/213 +CallableHaikuModule = Callable[..., jax.Array] + + +@chex.dataclass +class TransformerOutput: + layer_outputs: list[jax.Array] # [B, T, D] + residuals: list[jax.Array] # [B, T, D] + attn_logits: list[jax.Array] # [B, H, T, T] + output: jax.Array # [B, T, D] + input_embeddings: jax.Array # [B, T, D] + + +@dataclasses.dataclass +class TransformerConfig: + num_heads: int + num_layers: int + key_size: int + mlp_hidden_size: int + dropout_rate: float + activation_function: Callable[[jax.Array], jax.Array] = jax.nn.gelu + layer_norm: bool = True + causal: bool = False + + +@dataclasses.dataclass +class Transformer(hk.Module): + """A transformer stack.""" + + config: TransformerConfig + name: Optional[str] = None + + def __call__( + self, + embeddings: jax.Array, # [B, T, D] + mask: jax.Array, # [B, T] + *, + use_dropout: bool = True, + ) -> TransformerOutput: + """Transforms input embedding sequences to output embedding sequences.""" + + def layer_norm(x: jax.Array) -> jax.Array: + """Applies a unique LayerNorm to x with default settings.""" + if self.config.layer_norm: + return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) + return x + + initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers) + dropout_rate = self.config.dropout_rate if use_dropout else 0. + _, seq_len, model_size = embeddings.shape + + # Compute causal mask for autoregressive sequence modelling. + mask = mask[:, None, None, :] # [B, H=1, T'=1, T] + mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T] + + if self.config.causal: + causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T] + causal_mask = np.tril(causal_mask) + mask = mask * causal_mask # [B, H=1, T, T] + + # Set up activation collection. + collected = collections.defaultdict(list) + + def collect(**kwargs): + for k, v in kwargs.items(): + collected[k].append(v) + + residual = embeddings + for layer in range(self.config.num_layers): + with hk.experimental.name_scope(f"layer_{layer}"): + # First the attention block. + attn_block = attention.MultiHeadAttention( + num_heads=self.config.num_heads, + key_size=self.config.key_size, + model_size=model_size, + w_init=initializer, + name="attn") + attn_in = layer_norm(residual) + attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask) + attn_out, attn_logits = attn_out.out, attn_out.logits + if dropout_rate > 0: + attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out) + residual = residual + attn_out + + collect( + residuals=residual, layer_outputs=attn_out, attn_logits=attn_logits) + + # Then the dense block. + with hk.experimental.name_scope("mlp"): + dense_block = hk.Sequential([ + hk.Linear( + self.config.mlp_hidden_size, + w_init=initializer, + name="linear_1"), + self.config.activation_function, + hk.Linear(model_size, w_init=initializer, name="linear_2"), + ]) + dense_in = layer_norm(residual) + dense_out = dense_block(dense_in) + if dropout_rate > 0: + dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out) + residual = residual + dense_out + + collect(residuals=residual, layer_outputs=dense_out) + + return TransformerOutput( + residuals=collected["residuals"], + layer_outputs=collected["layer_outputs"], + attn_logits=collected["attn_logits"], + output=layer_norm(residual), + input_embeddings=embeddings, + ) + + +@chex.dataclass +class CompiledTransformerModelOutput: + transformer_output: TransformerOutput + unembedded_output: jax.Array # [B, T] + + +@dataclasses.dataclass +class CompiledTransformerModel(hk.Module): + """A transformer model with one-hot embeddings.""" + transformer: Transformer + token_embed: CallableHaikuModule + position_embed: CallableHaikuModule + unembed: CallableHaikuModule + use_unembed_argmax: bool + pad_token: Optional[int] = None + + def embed(self, tokens: jax.Array) -> jax.Array: + token_embeddings = self.token_embed(tokens) + positional_embeddings = self.position_embed(jnp.indices(tokens.shape)[-1]) + return token_embeddings + positional_embeddings # [B, T, D] + + def __call__( + self, + tokens: jax.Array, + use_dropout: bool = True, + ) -> CompiledTransformerModelOutput: + """Embed tokens, pass through model, and unembed output.""" + if self.pad_token is None: + input_mask = jnp.ones_like(tokens) + else: + input_mask = (tokens != self.pad_token) + input_embeddings = self.embed(tokens) + + transformer_output = self.transformer( + input_embeddings, + input_mask, + use_dropout=use_dropout, + ) + return CompiledTransformerModelOutput( + transformer_output=transformer_output, + unembedded_output=self.unembed( + transformer_output.output, + use_unembed_argmax=self.use_unembed_argmax, + ), + ) diff --git a/transformer/model_test.py b/transformer/model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..76b2e7592b9a00f2af5d152072cc612da1a4cc32 --- /dev/null +++ b/transformer/model_test.py @@ -0,0 +1,275 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for transformer.model.""" + +from absl.testing import absltest +from absl.testing import parameterized +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from tracr.transformer import model + + +class TransformerTest(parameterized.TestCase): + + def _check_layer_naming(self, params): + # Modules should be named for example + # For MLPs: "transformer/layer_{i}/mlp/linear_1" + # For Attention: "transformer/layer_{i}/attn/key" + # For Layer Norm: "transformer/layer_{i}/layer_norm" + for key in params.keys(): + levels = key.split("/") + self.assertEqual(levels[0], "transformer") + if levels[1].startswith("layer_norm"): + continue # output layer norm + self.assertStartsWith(levels[1], "layer") + if levels[2] == "mlp": + self.assertIn(levels[3], {"linear_1", "linear_2"}) + elif levels[2] == "attn": + self.assertIn(levels[3], {"key", "query", "value", "linear"}) + else: + self.assertStartsWith(levels[2], "layer_norm") + + def _zero_mlps(self, params): + for module in params: + if "mlp" in module: + for param in params[module]: + params[module][param] = jnp.zeros_like(params[module][param]) + return params + + @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False)) + def test_layer_norm(self, layer_norm): + # input = [1, 1, 1, 1] + # If layer norm is used, this should give all-0 output for a freshly + # initialized model because LN will subtract the mean after each layer. + # Else we expect non-zero outputs. + + @hk.transform + def forward(emb, mask): + transformer = model.Transformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + layer_norm=layer_norm)) + return transformer(emb, mask).output + + seq_len = 4 + emb = jnp.ones((1, seq_len, 1)) + mask = jnp.ones((1, seq_len)) + rng = hk.PRNGSequence(1) + params = forward.init(next(rng), emb, mask) + out = forward.apply(params, next(rng), emb, mask) + + self._check_layer_naming(params) + if layer_norm: + np.testing.assert_allclose(out, 0) + else: + self.assertFalse(np.allclose(out, 0)) + + @parameterized.parameters(dict(causal=True), dict(causal=False)) + def test_causal_attention(self, causal): + # input = [0, random, random, random] + # mask = [1, 0, 1, 1] + # For causal attention the second token can only attend to the first one, so + # it should be the same. For non-causal attention all tokens should change. + + @hk.transform + def forward(emb, mask): + transformer = model.Transformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + layer_norm=False, + causal=causal)) + return transformer(emb, mask).output + + seq_len = 4 + emb = np.random.random((1, seq_len, 1)) + emb[:, 0, :] = 0 + mask = np.array([[1, 0, 1, 1]]) + emb, mask = jnp.array(emb), jnp.array(mask) + + rng = hk.PRNGSequence(1) + params = forward.init(next(rng), emb, mask) + params = self._zero_mlps(params) + out = forward.apply(params, next(rng), emb, mask) + + self._check_layer_naming(params) + if causal: + self.assertEqual(0, out[0, 0, 0]) + self.assertEqual(emb[0, 1, 0], out[0, 1, 0]) + else: + self.assertNotEqual(0, out[0, 0, 0]) + self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0]) + self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0]) + self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0]) + + def test_setting_activation_function_to_zero(self): + # An activation function that always returns zeros should result in the + # same model output as setting all MLP weights to zero. + + @hk.transform + def forward_zero(emb, mask): + transformer = model.Transformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False, + activation_function=jnp.zeros_like)) + return transformer(emb, mask).output + + @hk.transform + def forward(emb, mask): + transformer = model.Transformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False, + activation_function=jax.nn.gelu)) + return transformer(emb, mask).output + + seq_len = 4 + emb = np.random.random((1, seq_len, 1)) + mask = np.ones((1, seq_len)) + emb, mask = jnp.array(emb), jnp.array(mask) + + rng = hk.PRNGSequence(1) + params = forward.init(next(rng), emb, mask) + params_no_mlps = self._zero_mlps(params) + + out_zero_activation = forward_zero.apply(params, next(rng), emb, mask) + out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask) + + self._check_layer_naming(params) + np.testing.assert_allclose(out_zero_activation, out_no_mlps) + self.assertFalse(np.allclose(out_zero_activation, 0)) + + +class CompiledTransformerModelTest(parameterized.TestCase): + + def _get_one_hot_embed_unembed(self, vocab_size, max_seq_len): + # Embeds tokens as one-hot into the first `vocab_size` dimensions + token_embed = hk.Embed( + embedding_matrix=jnp.block( + [jnp.eye(vocab_size), + jnp.zeros((vocab_size, max_seq_len))])) + + # Embeds positions as one-hot into the last `max_seq_len` dimensions + position_embed = hk.Embed( + embedding_matrix=jnp.block( + [jnp.zeros((max_seq_len, vocab_size)), + jnp.eye(max_seq_len)])) + + class Unembed(hk.Module): + + def __call__(self, embeddings): + return jnp.argmax(embeddings[:, :, :vocab_size], axis=-1) + + return token_embed, position_embed, Unembed() + + def test_embedding_gives_desired_result(self): + tokens = jnp.array([[1, 2, 3]]) + vocab_size, max_seq_len, pad_token = 5, 5, 0 + + expected_embeddings = jnp.array([[[0, 1, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]]]) + + @hk.transform + def embed(tokens): + transformer = model.Transformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False, + activation_function=jax.nn.gelu)) + token_embed, position_embed, unembed = self._get_one_hot_embed_unembed( + vocab_size, max_seq_len) + compiled_model = model.CompiledTransformerModel( + transformer=transformer, + token_embed=token_embed, + position_embed=position_embed, + unembed=unembed, + use_unembed_argmax=True, + pad_token=pad_token) + return compiled_model.embed(tokens) + + rng = hk.PRNGSequence(1) + params = embed.init(next(rng), tokens) + embeddings = embed.apply(params, next(rng), tokens) + + np.testing.assert_allclose(embeddings, expected_embeddings) + + def test_embedding_then_unembedding_gives_same_tokens(self): + tokens = jnp.array([[1, 2, 3], [4, 5, 6], [3, 2, 4]]) + vocab_size, max_seq_len, pad_token = 10, 5, 0 + + @hk.transform + def embed_unembed(tokens): + transformer = model.Transformer( + model.TransformerConfig( + num_heads=2, + num_layers=2, + key_size=5, + mlp_hidden_size=64, + dropout_rate=0., + causal=False, + layer_norm=False, + activation_function=jax.nn.gelu)) + token_embed, position_embed, unembed = self._get_one_hot_embed_unembed( + vocab_size, max_seq_len) + compiled_model = model.CompiledTransformerModel( + transformer=transformer, + token_embed=token_embed, + position_embed=position_embed, + unembed=unembed, + use_unembed_argmax=True, + pad_token=pad_token) + embeddings = compiled_model.embed(tokens) + unembeddings = compiled_model.unembed(embeddings) + return embeddings, unembeddings + + rng = hk.PRNGSequence(1) + params = embed_unembed.init(next(rng), tokens) + embeddings, unembeddings = embed_unembed.apply(params, next(rng), tokens) + + self.assertEqual( + embeddings.shape, + (tokens.shape[0], tokens.shape[1], vocab_size + max_seq_len)) + + np.testing.assert_allclose(unembeddings, tokens) + + +if __name__ == "__main__": + absltest.main() diff --git a/utils/debugging.py b/utils/debugging.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d06b08ed7376a1065cad87df2f20d5926239a6 --- /dev/null +++ b/utils/debugging.py @@ -0,0 +1,28 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Useful helpers for model debugging.""" + + +def print_arrays(arrays, labels=None, colwidth=12): + """Pretty-prints a list of [1, T, D] arrays.""" + if labels is not None: + print(" |".join(labels)) + widths = [len(l) for l in labels] + else: + widths = [colwidth] * len(arrays[0].shape[-1]) + for layer in arrays: + print("=" * (colwidth + 1) * layer.shape[1]) + for row in layer[0]: + print(" |".join([f"{x:<{width}.2f}" for x, width in zip(row, widths)])) diff --git a/utils/errors.py b/utils/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ce0254cd0069a127713c26924c25238aeb2608 --- /dev/null +++ b/utils/errors.py @@ -0,0 +1,35 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers for handling errors in user-provided functions.""" + +import functools +import logging +from typing import Any, Callable + + +def ignoring_arithmetic_errors(fun: Callable[..., Any]) -> Callable[..., Any]: + """Makes fun return None instead of raising ArithmeticError.""" + + @functools.wraps(fun) + def fun_wrapped(*args): + try: + return fun(*args) + except ArithmeticError: + logging.warning( + "Encountered arithmetic error in function: for value %s. " + "Assuming this input will never occur.", str(args)) + return None + + return fun_wrapped diff --git a/utils/errors_test.py b/utils/errors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..52f5505d8740b47ed101eb53721d44efa2af986e --- /dev/null +++ b/utils/errors_test.py @@ -0,0 +1,58 @@ +# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for rasp.helper.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tracr.utils import errors + + +class FunIgnoreArithmeticErrorsTest(parameterized.TestCase): + + def test_ignoring_arithmetic_errors(self): + fun = lambda x: 1 / x + fun_ignore = errors.ignoring_arithmetic_errors(fun) + + with self.assertLogs(level="WARNING"): + res = fun_ignore(0) + self.assertIs(res, None) + + self.assertEqual(fun_ignore(1), 1) + self.assertEqual(fun_ignore(2), 0.5) + self.assertEqual(fun_ignore(-2), -0.5) + + def test_ignoring_arithmetic_errors_two_arguments(self): + fun = lambda x, y: 1 / x + 1 / y + fun_ignore = errors.ignoring_arithmetic_errors(fun) + + with self.assertLogs(level="WARNING"): + res = fun_ignore(0, 1) + self.assertIs(res, None) + + with self.assertLogs(level="WARNING"): + res = fun_ignore(0, 0) + self.assertIs(res, None) + + with self.assertLogs(level="WARNING"): + res = fun_ignore(1, 0) + self.assertIs(res, None) + + self.assertEqual(fun_ignore(1, 1), 2) + self.assertEqual(fun_ignore(1, 2), 1.5) + self.assertEqual(fun_ignore(-2, 2), 0) + + +if __name__ == "__main__": + absltest.main()