File size: 12,125 Bytes
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add craft model blocks to graph of RASPExpr."""

from typing import Any, Callable, Optional

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.craft.chamber import numerical_mlp
from tracr.craft.chamber import selector_width
from tracr.rasp import rasp


def _transform_fun_to_basis_fun(
    fun: Callable[..., Any],
    output_direction_name: Optional[str] = None) -> Callable[..., Any]:
  """Lifts a function on raw values to a function on basis directions.

  Args:
    fun: Function to apply to the `value` attribute of each argument.
    output_direction_name: If truthy, the result of `fun` is wrapped in a
      `bases.BasisDirection` carrying this name. Otherwise the raw result of
      `fun` is returned unchanged.

  Returns:
    A callable taking any number of positional arguments, each of which must
    expose a `value` attribute.
  """

  def apply_to_directions(*directions):
    # All values are extracted before `fun` is invoked.
    outcome = fun(*(direction.value for direction in directions))
    if not output_direction_name:
      return outcome
    return bases.BasisDirection(output_direction_name, outcome)

  return apply_to_directions


def _check_selector_expression(expr, graph):
  """Validates the selector feeding an aggregate or selector width.

  Asserts that the selector of `expr` is a graph predecessor of `expr`, and
  that the selector's keys and queries are graph predecessors of the selector.
  Then verifies that both queries and keys are categorically encoded.

  Args:
    expr: Expression exposing a `selector` attribute and a `label`.
    graph: Graph whose `predecessors(label)` is consulted for the checks.

  Raises:
    AssertionError: If the graph structure does not match the expression.
    ValueError: If the selector's queries or keys are not categorical.
  """
  selector = expr.selector

  # Structural checks: expr <- selector <- (keys, queries).
  assert selector.label in graph.predecessors(expr.label)
  for operand_name in ("keys", "queries"):
    operand = getattr(selector, operand_name)
    assert operand.label in graph.predecessors(selector.label)

  # Encoding checks: queries are inspected first, then keys.
  both_categorical = all(
      rasp.is_categorical(getattr(selector, operand_name))
      for operand_name in ("queries", "keys"))
  if not both_categorical:
    raise ValueError("Selector keys and queries must be categorical.")


def _selector_attn_fn(sel_expr: rasp.Select) -> Callable[..., bool]:
  """Builds the attention function for the predicate of `sel_expr`.

  The returned function takes `(query, key)` basis directions and evaluates
  the selector predicate on their values with the arguments swapped, because
  the argument order is different in craft / transformers than RASP selectors.
  """
  selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate)

  def attn_basis_fn(query: bases.BasisDirection,
                    key: bases.BasisDirection) -> bool:
    return selector_basis_fn(key, query)

  return attn_basis_fn


def _make_map_block(
    expr: rasp.Map,
    node_id: Any,
    node: dict,
    graph: nx.DiGraph,
    one_space: bases.VectorSpaceWithBasis,
    mlp_exactness: float,
) -> Any:
  """Builds the MLP block for a `rasp.Map` node.

  The MLP constructor is picked from the encodings (categorical / numerical)
  of the input expression and of the map itself.
  """
  inner_expr, inner_node = expr.inner, graph.nodes[expr.inner.label]
  assert inner_expr.label in graph.predecessors(node_id)
  input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS])
  output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

  if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr):
    basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
    return categorical_mlp.map_categorical_mlp(
        input_space=input_space,
        output_space=output_space,
        operation=basis_fun)
  if rasp.is_categorical(inner_expr) and rasp.is_numerical(expr):
    return categorical_mlp.map_categorical_to_numerical_mlp(
        input_space=input_space,
        output_space=output_space,
        operation=expr.f,
    )
  if rasp.is_numerical(inner_expr) and rasp.is_categorical(expr):
    return numerical_mlp.map_numerical_to_categorical_mlp(
        f=expr.f,
        input_space=input_space,
        output_space=output_space,
        input_value_set=inner_node[nodes.VALUE_SET],
        one_space=one_space,
        hidden_name=f"_hidden_{expr.label}_",
        large_number=mlp_exactness)
  if rasp.is_numerical(inner_expr) and rasp.is_numerical(expr):
    return numerical_mlp.map_numerical_mlp(
        f=expr.f,
        input_space=input_space,
        output_space=output_space,
        input_value_set=inner_node[nodes.VALUE_SET],
        one_space=one_space,
        hidden_name=f"_hidden_{expr.label}_",
        large_number=mlp_exactness)
  raise NotImplementedError("Map does not support "
                            f"in_type '{inner_expr.type}' and"
                            f" out_type '{expr.type}'!")


def _make_sequence_map_block(
    expr: rasp.SequenceMap,
    node_id: Any,
    node: dict,
    graph: nx.DiGraph,
    one_space: bases.VectorSpaceWithBasis,
) -> Any:
  """Builds the MLP block for a `rasp.SequenceMap` node.

  `rasp.LinearSequenceMap` requires numerical inputs and output; any other
  `rasp.SequenceMap` requires categorical inputs and output.
  """
  fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label]
  snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label]

  # Check graph structure
  assert fst_expr.label in graph.predecessors(node_id)
  assert snd_expr.label in graph.predecessors(node_id)

  fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS])
  snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS])
  out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

  is_linear = isinstance(expr, rasp.LinearSequenceMap)
  operands = (fst_expr, snd_expr, expr)
  if is_linear:
    if not all(rasp.is_numerical(x) for x in operands):
      raise NotImplementedError("Linear SequenceMap only supports numerical "
                                "inputs/outputs.")
  elif not all(rasp.is_categorical(x) for x in operands):
    raise NotImplementedError("(Non-linear) SequenceMap only supports "
                              "categorical inputs/outputs.")

  if is_linear:
    assert len(fst_space.basis) == 1
    assert len(snd_space.basis) == 1
    assert len(out_space.basis) == 1
    return numerical_mlp.linear_sequence_map_numerical_mlp(
        input1_basis_direction=fst_space.basis[0],
        input2_basis_direction=snd_space.basis[0],
        output_basis_direction=out_space.basis[0],
        input1_factor=expr.fst_fac,
        input2_factor=expr.snd_fac,
        hidden_name=f"_hidden_{expr.label}_")

  if fst_space == snd_space:
    # Both inputs live in the same space, so a single-input categorical MLP
    # is built that passes the one input value to both arguments of `expr.f`.
    basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x),
                                            expr.label)
    return categorical_mlp.map_categorical_mlp(
        input_space=fst_space, output_space=out_space, operation=basis_fun)

  basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
  return categorical_mlp.sequence_map_categorical_mlp(
      input1_space=fst_space,
      input2_space=snd_space,
      output_space=out_space,
      operation=basis_fun,
      one_space=one_space,
      hidden_name=f"_hidden_{expr.label}_")


def _make_aggregate_block(
    expr: rasp.Aggregate,
    node_id: Any,
    node: dict,
    graph: nx.DiGraph,
    bos_space: bases.VectorSpaceWithBasis,
    one_space: bases.VectorSpaceWithBasis,
    causal: bool,
) -> Any:
  """Builds the attention block for a `rasp.Aggregate` node.

  Raises:
    TypeError: If the selector is not a plain `rasp.Select`.
    ValueError: If the encodings of the aggregated SOp and the aggregate
      differ, or if the default does not match the encoding (None for
      categorical, 0 for numerical), or if selector keys/queries are not
      categorical.
  """
  sel_expr = expr.selector
  if not isinstance(sel_expr, rasp.Select):
    raise TypeError("Compiling composite Selectors is not supported. "
                    f"Got a {sel_expr}.")

  queries = graph.nodes[sel_expr.queries.label]
  keys = graph.nodes[sel_expr.keys.label]
  sop = graph.nodes[expr.sop.label]

  _check_selector_expression(expr, graph)
  assert expr.sop.label in graph.predecessors(node_id)
  if rasp.get_encoding(expr.sop) != rasp.get_encoding(expr):
    raise ValueError(
        "sop encoding must match output encoding of the aggregate.")
  if rasp.is_categorical(expr) and expr.default is not None:
    raise ValueError("Default for a categorical aggregate must be None. "
                     f"Got {expr.default}")
  if rasp.is_numerical(expr) and expr.default != 0:
    raise ValueError("Default for a numerical aggregate must be 0. "
                     f"Got {expr.default}")

  query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
  key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
  value_space = bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS])
  output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

  return categorical_attn.categorical_attn(
      query_space=query_space,
      key_space=key_space,
      value_space=value_space,
      output_space=output_space,
      bos_space=bos_space,
      one_space=one_space,
      attn_fn=_selector_attn_fn(sel_expr),
      default_output=output_space.null_vector(),
      causal=causal,
      always_attend_to_bos=False,
      use_bos_for_default_output=True,
      softmax_coldness=100)


def _make_selector_width_block(
    expr: rasp.SelectorWidth,
    node: dict,
    graph: nx.DiGraph,
    bos_space: bases.VectorSpaceWithBasis,
    one_space: bases.VectorSpaceWithBasis,
    mlp_exactness: float,
) -> Any:
  """Builds the block for a `rasp.SelectorWidth` node.

  Raises:
    TypeError: If the selector is not a plain `rasp.Select`.
    ValueError: If selector keys/queries are not categorical.
  """
  sel_expr = expr.selector
  # Same guard as for aggregates: without it a composite selector fails with
  # an AttributeError on `.queries` instead of a descriptive error.
  if not isinstance(sel_expr, rasp.Select):
    raise TypeError("Compiling composite Selectors is not supported. "
                    f"Got a {sel_expr}.")

  queries = graph.nodes[sel_expr.queries.label]
  keys = graph.nodes[sel_expr.keys.label]
  _check_selector_expression(expr, graph)

  query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
  key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
  output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

  # NOTE(review): the caller's `causal` flag is not forwarded here; selector
  # width blocks are always built with causal=False. Confirm this is intended.
  return selector_width.selector_width(
      query_space=query_space,
      key_space=key_space,
      output_space=output_space,
      bos_space=bos_space,
      one_space=one_space,
      attn_fn=_selector_attn_fn(sel_expr),
      out_value_set=node[nodes.VALUE_SET],
      categorical_output=rasp.is_categorical(expr),
      causal=False,
      softmax_coldness=100,
      mlp_large_number=mlp_exactness,
      label=expr.label)


def add_craft_components_to_rasp_graph(
    graph: nx.DiGraph,
    bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"),
    one_dir: bases.BasisDirection = bases.BasisDirection("one"),
    causal: bool = False,
    mlp_exactness: float = 100,
) -> None:
  """Translates expressions to craft blocks and attaches them to the graph.

  Sets the `MODEL_BLOCK` attribute for every node in `graph` whose expression
  is a `rasp.SOp`; nodes holding any other expression are skipped. The
  `rasp.tokens` and `rasp.indices` nodes get `None` as their block.

  Args:
    graph: RASP graph with  `VALUE_SET` but not `MODEL_BLOCK` attributes.
    bos_dir: Basis direction representing beginning of sequence (bos) token.
    one_dir: Auxiliary basis direction that must contain 1.
    causal: If True, marks the attention blocks built for aggregates as
      causal.
    mlp_exactness: Controls the approximation of the MLP layers.

  Raises:
    ValueError: On invalid input (if `MODEL_BLOCK` is set already, or
      `VALUE_SET` is not set already), or if an aggregate / selector violates
      the encoding or default-value requirements.
    TypeError: If an aggregate or selector width uses a selector that is not
      a plain `rasp.Select`.
    NotImplementedError: If the graph contains an unsupported expression.
  """
  # Both auxiliary spaces depend only on the arguments, so build them once.
  bos_space = bases.VectorSpaceWithBasis([bos_dir])
  one_space = bases.VectorSpaceWithBasis([one_dir])

  for node_id, node in graph.nodes.items():
    expr = node[nodes.EXPR]

    if not isinstance(expr, rasp.SOp):
      continue

    if node.get(nodes.MODEL_BLOCK):
      raise ValueError("Input graph cannot have model blocks set already.")
    if nodes.VALUE_SET not in node:
      raise ValueError(
          "Craft components can only be added after basis inference.")

    if expr is rasp.tokens or expr is rasp.indices:
      block = None
    elif isinstance(expr, rasp.Map):
      block = _make_map_block(expr, node_id, node, graph, one_space,
                              mlp_exactness)
    elif isinstance(expr, rasp.SequenceMap):
      block = _make_sequence_map_block(expr, node_id, node, graph, one_space)
    elif isinstance(expr, rasp.Aggregate):
      block = _make_aggregate_block(expr, node_id, node, graph, bos_space,
                                    one_space, causal)
    elif isinstance(expr, rasp.SelectorWidth):
      block = _make_selector_width_block(expr, node, graph, bos_space,
                                         one_space, mlp_exactness)
    else:
      raise NotImplementedError(f"Expression {expr} cannot be translated to "
                                "a model component.")

    graph.nodes[node_id][nodes.MODEL_BLOCK] = block