# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pieces for making transformers.""" | |
import abc | |
import dataclasses | |
from typing import Iterable, List, Optional, Sequence, Union | |
import numpy as np | |
from tracr.craft import bases | |
from tracr.craft import vectorspace_fns | |
project = vectorspace_fns.project | |


def _np_softmax(x, axis=-1):
  # Subtract the per-row max before exponentiating for numerical stability.
  x_max = np.max(x, axis=axis, keepdims=True)
  return np.exp(x - x_max) / np.sum(np.exp(x - x_max), axis=axis, keepdims=True)


def _np_relu(x):
  return np.where(x > 0, x, 0)


def relu(x: bases.VectorInBasis) -> bases.VectorInBasis:
  return bases.VectorInBasis(x.basis_directions, _np_relu(x.magnitudes))


class Block(abc.ABC):
  """Transformer block, acting on a sequence of vector space elements.

  Attributes:
    residual_space: Vector space that contains all subspaces the Block
      interacts with. This can be either the full residual space of a model
      or a subspace.
  """
  residual_space: bases.VectorSpaceWithBasis

  @abc.abstractmethod
  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies self to an input."""


@dataclasses.dataclass
class AttentionHead(Block):
  """A transformer attention head."""
  w_qk: vectorspace_fns.ScalarBilinear
  w_ov: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None
  causal: bool = False

  def __post_init__(self):
    """Infer residual stream and typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.w_qk.left_space,
                                                     self.w_qk.right_space,
                                                     self.w_ov.input_space,
                                                     self.w_ov.output_space)

    assert self.w_qk.left_space.issubspace(self.residual_space)
    assert self.w_qk.right_space.issubspace(self.residual_space)
    assert self.w_ov.input_space.issubspace(self.residual_space)
    assert self.w_ov.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    assert x in self.residual_space
    # seq_len x query_space
    queries = x.project(self.w_qk.left_space)
    # seq_len x key_space
    keys = x.project(self.w_qk.right_space)

    attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T

    if self.causal:
      # The 1 gives us the matrix above the diagonal.
      mask = np.triu(np.full_like(attn_matrix, -np.inf), 1)
      attn_matrix = attn_matrix + mask

    attn_weights = _np_softmax(attn_matrix)  # seq_len_from, seq_len_to
    values = self.w_ov_residual(x).magnitudes  # seq_len_to, d_model

    magnitudes = attn_weights @ values  # seq_len_from, d_model
    return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes)

  def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Wov but acting on the residual space."""
    x = project(self.residual_space, self.w_ov.input_space)(x)
    out = self.w_ov(x)
    return project(self.w_ov.output_space, self.residual_space)(out)

  def num_heads(self) -> int:
    return 1

  def as_multi(self) -> "MultiAttentionHead":
    return MultiAttentionHead([self])
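

# Hedged reference sketch (comments only, not executed): the computation in
# AttentionHead.apply above amounts to
#
#   scores  = Q @ W_qk @ K^T                # (seq_len_from, seq_len_to)
#   weights = softmax(scores + causal_mask) # mask is upper-triangular -inf
#   output  = weights @ (x mapped through W_ov back into residual_space)
#
# where Q and K are the projections of x onto w_qk.left_space and
# w_qk.right_space. The output lives in residual_space, so it can be added
# back onto the residual stream by SeriesWithResiduals below.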


@dataclasses.dataclass
class MultiAttentionHead(Block):
  """Applies attention heads in parallel."""
  sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]]

  def __post_init__(self):
    spaces = [block.residual_space for block in self.sub_blocks]
    self.residual_space, *others = spaces
    assert all(s == self.residual_space for s in others)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    # each element is seq_len x embedding
    outs = [block.apply(x) for block in self.sub_blocks]
    return bases.VectorInBasis.sum(outs)  # seq_len x embedding

  def num_heads(self) -> int:
    return sum(sub_block.num_heads() for sub_block in self.sub_blocks)

  def heads(self) -> Iterable[AttentionHead]:
    for sub_block in self.sub_blocks:
      if isinstance(sub_block, AttentionHead):
        yield sub_block
      elif isinstance(sub_block, MultiAttentionHead):
        yield from sub_block.heads()
      else:
        raise NotImplementedError()

  def as_multi(self) -> "MultiAttentionHead":
    return self


@dataclasses.dataclass
class MLP(Block):
  """A transformer MLP block."""
  fst: vectorspace_fns.Linear
  snd: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None

  def __post_init__(self):
    """Typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.fst.input_space,
                                                     self.snd.output_space)

    assert self.fst.output_space == self.snd.input_space
    assert self.fst.input_space.issubspace(self.residual_space)
    assert self.snd.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    assert x in self.residual_space

    x = project(self.residual_space, self.fst.input_space)(x)
    hidden = self.fst(x)
    hidden = relu(hidden)
    out = self.snd(hidden)
    return project(self.snd.output_space, self.residual_space)(out)

  @classmethod
  def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP":
    fst = vectorspace_fns.Linear.combine_in_parallel(
        [block.fst for block in mlps])
    snd = vectorspace_fns.Linear.combine_in_parallel(
        [block.snd for block in mlps])
    return cls(fst=fst, snd=snd, residual_space=None)
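

# Hedged usage note: MLP.combine_in_parallel merges several MLP blocks from
# the same layer into a single block over the joined spaces. A minimal sketch,
# assuming mlp_a and mlp_b are MLP instances built elsewhere (hypothetical
# names):
#
#   merged = MLP.combine_in_parallel([mlp_a, mlp_b])
#   y = merged.apply(x)  # x must lie in the merged block's residual_space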


# Block that fits into a half-layer, without residual connections.
HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]


@dataclasses.dataclass
class SeriesWithResiduals(Block):
  """A series of blocks with residual connections."""
  blocks: List[HalfLayerBlock]

  def __post_init__(self):
    spaces = [block.residual_space for block in self.blocks]
    self.residual_space = bases.join_vector_spaces(*spaces)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    x = x.project(self.residual_space)
    for block in self.blocks:
      x_in = x.project(block.residual_space)
      x_out = block.apply(x_in).project(self.residual_space)
      x = x + x_out
    return x
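

# --- Hedged usage sketch (illustrative only, not part of the module) ---------
# Shows how the blocks above are intended to compose. `w_qk`, `w_ov`, `fst`,
# `snd` and the input `x` are placeholders assumed to be constructed elsewhere
# (e.g. via tracr.craft.bases and tracr.craft.vectorspace_fns):
#
#   head = AttentionHead(w_qk=w_qk, w_ov=w_ov, causal=True)
#   mlp = MLP(fst=fst, snd=snd)
#   model = SeriesWithResiduals([head.as_multi(), mlp])
#   out = model.apply(x)  # x: VectorInBasis, seq_len x residual dimensions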