File size: 12,068 Bytes
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee5cdb4
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46567d
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RASP programs only using the subset of RASP supported by the compiler."""

from typing import List, Sequence

from tracr.rasp import rasp

### Programs that work only under non-causal evaluation.


def make_length() -> rasp.SOp:
  """Builds the `length` SOp from the selector width primitive.

  The selector compares tokens against tokens with the always-true predicate,
  and the result is the width of that selector at each position.

  Example usage:
    length = make_length()
    length("abcdefg")
    >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]

  Returns:
    length: SOp mapping an input to a sequence, where every element
      is the length of that sequence.
  """
  select_everything = rasp.Select(rasp.tokens, rasp.tokens,
                                  rasp.Comparison.TRUE)
  select_everything = select_everything.named("all_true_selector")
  width = rasp.SelectorWidth(select_everything)
  return width.named("length")


# Module-level `length` SOp, built once at import time and reused by the
# programs below that need the sequence length (`make_reverse` and
# `make_shuffle_dyck`).
length = make_length()


def make_reverse(sop: rasp.SOp) -> rasp.SOp:
  """Builds an SOp that reverses `sop`, based on the module-level `length`.

  Each position is paired with the position whose index equals
  `length - index - 1`, and the value of `sop` there is aggregated.

  Example usage:
    reverse = make_reverse(rasp.tokens)
    reverse("Hello")
    >> ['o', 'l', 'l', 'e', 'H']

  Args:
    sop: an SOp

  Returns:
    reverse : SOp that reverses the input sequence.
  """
  distance_from_end = (length - rasp.indices).named("opp_idx")
  mirrored_index = (distance_from_end - 1).named("opp_idx-1")
  pick_mirrored = rasp.Select(rasp.indices, mirrored_index,
                              rasp.Comparison.EQ)
  pick_mirrored = pick_mirrored.named("reverse_selector")
  reversed_sop = rasp.Aggregate(pick_mirrored, sop)
  return reversed_sop.named("reverse")


def make_pair_balance(sop: rasp.SOp, open_token: str,
                      close_token: str) -> rasp.SOp:
  """Builds an SOp with the open-token fraction minus the close-token fraction.

   (As implemented in the RASP paper.)

  Both fractions are computed with `make_frac_prevs` and combined with a
  LinearSequenceMap using the coefficients 1 and -1.

  If the outputs are always non-negative and end in 0, that implies the input
  has balanced parentheses.

  Example usage:
    num_l = make_pair_balance(rasp.tokens, "(", ")")
    num_l("a()b(c))")
    >> [0, 1/2, 0, 0, 1/5, 1/6, 0, -1/8]

  Args:
    sop: Input SOp.
    open_token: Token that counts positive.
    close_token: Token that counts negative.

  Returns:
    pair_balance: SOp mapping an input to a sequence, where every element
      is the fraction of previous open tokens minus previous close tokens.
  """
  # Fraction of open tokens seen so far.
  is_open = rasp.numerical(sop == open_token).named("bools_open")
  open_fraction = rasp.numerical(make_frac_prevs(is_open)).named("opens")

  # Fraction of close tokens seen so far.
  is_close = rasp.numerical(sop == close_token).named("bools_close")
  close_fraction = rasp.numerical(make_frac_prevs(is_close)).named("closes")

  # opens * 1 + closes * (-1)
  difference = rasp.LinearSequenceMap(open_fraction, close_fraction, 1, -1)
  return rasp.numerical(difference).named("pair_balance")


def make_shuffle_dyck(pairs: List[str]) -> rasp.SOp:
  """Builds an SOp that is 1 if all parenthesis types are balanced, 0 else.

   (As implemented in the RASP paper.)

  Every pair is checked independently of the others: its running balance must
  never be negative and must be zero at the last position.

  Example usage:
    shuffle_dyck2 = make_shuffle_dyck(pairs=["()", "{}"])
    shuffle_dyck2("({)}")
    >> [1, 1, 1, 1]
    shuffle_dyck2("(){)}")
    >> [0, 0, 0, 0, 0]

  Args:
    pairs: List of pairs of open and close tokens that each should be balanced.

  Returns:
    shuffle_dyck: SOp combining "balances are zero at the last position" with
      "no balance was negative anywhere".
  """
  assert len(pairs) >= 1

  # One running balance per type of parenthesis.
  per_pair_balance = []
  for pair in pairs:
    assert len(pair) == 2
    opener, closer = pair
    running = make_pair_balance(
        rasp.tokens, open_token=opener, close_token=closer)
    per_pair_balance.append(running.named(f"balance_{pair}"))

  first_balance, *other_balances = per_pair_balance

  # A negative balance anywhere means a close token came before its open token.
  any_negative = first_balance < 0
  for running in other_balances:
    any_negative = any_negative | (running < 0)

  # Identity map, used only to obtain a numerically encoded SOp.
  any_negative = rasp.Map(lambda x: x, any_negative)
  any_negative = rasp.numerical(any_negative).named("any_negative")

  # Aggregate over all positions so every position sees the same result.
  select_all = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)
  select_all = select_all.named("select_all")
  has_neg = rasp.Aggregate(select_all, any_negative, default=0)
  has_neg = rasp.numerical(has_neg).named("has_neg")

  # All balances must be 0 at the end, i.e. every parenthesis was closed.
  all_zero = first_balance == 0
  for running in other_balances:
    all_zero = all_zero & (running == 0)

  # Read `all_zero` at index `length - 1` from every position.
  select_last = rasp.Select(rasp.indices, length - 1, rasp.Comparison.EQ)
  select_last = select_last.named("select_last")
  last_zero = rasp.Aggregate(select_last, all_zero).named("last_zero")

  not_has_neg = (~has_neg).named("not_has_neg")
  balanced = last_zero & not_has_neg
  return balanced.named("shuffle_dyck")


def make_shuffle_dyck2() -> rasp.SOp:
  """Builds the shuffle-Dyck SOp for the two pairs "()" and "{}".

  Returns:
    shuffle_dyck2: The SOp from `make_shuffle_dyck` for these two pairs,
      renamed to "shuffle_dyck2".
  """
  shuffle_dyck = make_shuffle_dyck(pairs=["()", "{}"])
  return shuffle_dyck.named("shuffle_dyck2")


def make_hist() -> rasp.SOp:
  """Builds an SOp giving, at each position, how often its token occurs.

   (As implemented in the RASP paper.)

  Example usage:
    hist = make_hist()
    hist("abac")
    >> [2, 1, 2, 1]

  Returns:
    hist: SOp holding the width of the token-equality selector at each
      position.
  """
  equal_tokens = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
  equal_tokens = equal_tokens.named("same_tok")
  occurrences = rasp.SelectorWidth(equal_tokens)
  return occurrences.named("hist")


def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Builds an SOp with vals sorted by the < relation on keys.

  Only supports unique keys.

  The width of the "smaller key" selector gives each element its target
  position; every output position then aggregates the value whose target
  position equals its own index.

  Example usage:
    sort = make_sort_unique(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.

  Returns:
    sort: SOp containing `vals` reordered according to `keys`.
  """
  smaller_keys = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
  destination = rasp.SelectorWidth(smaller_keys).named("target_pos")
  fetch_by_destination = rasp.Select(destination, rasp.indices,
                                     rasp.Comparison.EQ)
  ordered = rasp.Aggregate(fetch_by_destination, vals)
  return ordered.named("sort")


def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
  """Builds an SOp with vals sorted by < on keys; keys may contain duplicates.

  The implementation differs from the RASP paper, as it avoids using
  compositions of selectors to break ties. Instead, it uses the arguments
  max_seq_len and min_key to ensure the keys are unique: each key is offset
  by `min_key * index / max_seq_len` before `make_sort_unique` is applied.

  Note that this approach only works for numerical keys.

  Example usage:
    sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]
    sort([2, 4, 1, 2])
    >> [1, 2, 2, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
    max_seq_len: Maximum sequence length (used to ensure keys are unique)
    min_key: Minimum key value (used to ensure keys are unique)

  Returns:
    Output SOp of sort program.
  """
  # Position-dependent offset that separates equal keys.
  tie_broken_keys = rasp.SequenceMap(
      lambda key, index: key + min_key * index / max_seq_len,
      keys,
      rasp.indices,
  )
  return make_sort_unique(vals, tie_broken_keys)


def make_sort_freq(max_seq_len: int) -> rasp.SOp:
  """Builds an SOp with the tokens sorted by how often they occur in the input.

  The sort key is the negated token count from `make_hist`. Tokens that
  appear the same number of times are output in the same order as in the
  input.

  Example usage:
    sort = make_sort_freq(5)
    sort([2, 4, 2, 1])
    >> [2, 2, 4, 1]

  Args:
    max_seq_len: Maximum sequence length (used to ensure keys are unique)

  Returns:
    sort_freq: SOp containing the input tokens in the sorted order.
  """
  token_counts = make_hist().named("hist")
  negated_counts = -1 * token_counts
  by_frequency = make_sort(
      rasp.tokens, negated_counts, max_seq_len=max_seq_len, min_key=1)
  return by_frequency.named("sort_freq")


### Programs that work under both causal and regular evaluation.


def make_frac_prevs(bools: rasp.SOp) -> rasp.SOp:
  """Builds an SOp with the fraction of tokens so far where `bools` was True.

   (As implemented in the RASP paper.)

  The current position is included in the fraction, as the example shows.

  Example usage:
    num_l = make_frac_prevs(rasp.tokens=="l")
    num_l("hello")
    >> [0, 0, 1/3, 1/2, 2/5]

  Args:
    bools: SOp mapping a sequence to a sequence of booleans.

  Returns:
    frac_prevs: SOp mapping an input to a sequence, where every element
      is the fraction of previous "True" tokens.
  """
  numeric_bools = rasp.numerical(bools)
  up_to_here = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
  fraction = rasp.Aggregate(up_to_here, numeric_bools, default=0)
  return rasp.numerical(fraction).named("frac_prevs")


def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Returns the sop, shifted by `offset`, None-padded.

  The output at query index `q` is taken from the key index `k` that satisfies
  `q == k + offset`. Where no index is selected, the aggregate's default of
  None is used.

  Args:
    offset: Number of positions to shift by (positional-only).
    sop: The SOp to shift.

  Returns:
    The shifted SOp, named "shift_by(<offset>)".
  """
  pick_source = rasp.Select(
      rasp.indices,
      rasp.indices,
      lambda key, query: query == key + offset,
  )
  shifted = rasp.Aggregate(pick_source, sop, default=None)
  return shifted.named(f"shift_by({offset})")


def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp:
  """Builds an SOp which is True at the final element of the pattern.

  The first len(pattern) - 1 elements of the output SOp are None-padded.

  detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]

  Args:
    sop: the SOp in which to look for patterns.
    pattern: a sequence of values to look for.

  Returns:
    a sop which detects the pattern.

  Raises:
    ValueError: If `pattern` is empty.
  """
  if len(pattern) == 0:
    raise ValueError(f"Length of `pattern` must be at least 1. Got {pattern}")

  # shifted_matches[i] is a boolean-valued SOp which is true at position j iff
  # the i'th (from the end) element of the pattern was found at position j-i.
  shifted_matches = []
  for distance, wanted in enumerate(reversed(pattern)):
    match = sop == wanted
    # The last pattern element (distance 0) needs no shift.
    shifted_matches.append(
        match if distance == 0 else shift_by(distance, match))

  # AND everything together, starting with the detector for the first pattern
  # element and ending with the one for the last.
  first_to_last = shifted_matches[::-1]
  combined = first_to_last[0]
  for match in first_to_last[1:]:
    combined = combined & match

  return combined.named(f"detect_pattern({pattern})")


def make_count_less_freq(n: int) -> rasp.SOp:
  """Builds an SOp counting positions whose token occurs at most n times.

  The comparison is inclusive (`<= n`), and positions are counted rather than
  distinct tokens. The output sequence contains this count in each position.

  Example usage:
    count_less_freq = make_count_less_freq(2)
    count_less_freq(["a", "a", "a", "b", "b", "c"])
    >> [3, 3, 3, 3, 3, 3]
    count_less_freq(["a", "a", "c", "b", "b", "c"])
    >> [6, 6, 6, 6, 6, 6]

  Args:
    n: Integer to compare token frequencies to.

  Returns:
    count_less_freq: SOp holding the count in every position.
  """
  token_counts = make_hist().named("hist")
  # The predicate looks only at the key, so every query selects the same keys.
  infrequent = rasp.Select(token_counts, token_counts,
                           lambda key, _query: key <= n)
  infrequent = infrequent.named("select_less")
  return rasp.SelectorWidth(infrequent).named("count_less_freq")


def make_count(sop, token):
  """Builds an SOp holding the number of times `token` occurs in `sop`.

  The output sequence contains this count in each position.

  Example usage:
    count = make_count(tokens, "a")
    count(["a", "a", "a", "b", "b", "c"])
    >> [3, 3, 3, 3, 3, 3]
    count(["c", "a", "b", "c"])
    >> [1, 1, 1, 1]

  Args:
    sop: Sop to count tokens in.
    token: Token to count.

  Returns:
    SOp named "count_<token>" with the count in every position.
  """
  # The predicate looks only at the key, so every query selects the same keys.
  matches_token = rasp.Select(sop, sop, lambda key, _query: key == token)
  occurrences = rasp.SelectorWidth(matches_token)
  return occurrences.named(f"count_{token}")


def make_nary_sequencemap(f, *sops: rasp.SOp) -> rasp.SOp:
  """Returns an SOp that simulates an n-ary SequenceMap.

  Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n
  into a single SOp arguments that takes n-tuples as value. The n-ary sequence
  map implementing f is then a Map on this resulting SOp.

  With a single SOp no tuples are built and f is applied through a plain Map.

  Note that the intermediate variables representing tuples of varying length
  will be encoded categorically, and can become very high-dimensional. So,
  using this function might lead to very large compiled models.

  Args:
    f: Function with n arguments.
    *sops: Sequence of SOps, one for each argument of f. At least one SOp is
      required.

  Returns:
    SOp whose value at each position is f applied to the values of all `sops`
    at that position.

  Raises:
    ValueError: If no SOps are given.
  """
  if not sops:
    raise ValueError("make_nary_sequencemap requires at least one SOp.")

  first, *remaining = sops
  if not remaining:
    # Unary case: the values are not tuples, so they must not be star-unpacked.
    return rasp.Map(f, first)

  # The first pairing always builds a fresh 2-tuple. Deciding this with an
  # isinstance check would wrongly flatten values of the first SOp that are
  # themselves tuples.
  second, *remaining = remaining
  values = rasp.SequenceMap(lambda x, y: (x, y), first, second)

  # From here on the accumulated value is always a tuple; append one more entry.
  for sop in remaining:
    values = rasp.SequenceMap(lambda args, y: (*args, y), values, sop)

  return rasp.Map(lambda args: f(*args), values)