mahnerak's picture
Initial Commit πŸš€
ce00289
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
class NodeType(Enum):
AFTER_ATTN = "after_attn"
AFTER_FFN = "after_ffn"
FFN = "ffn"
ORIGINAL = "original" # The original tokens
def _format_block_hierachy_string(blocks: List[str]) -> str:
return " β–Έ ".join(blocks)
@dataclass
class GraphNode:
layer: int
token: int
type: NodeType
def is_in_residual_stream(self) -> bool:
return self.type in [NodeType.AFTER_ATTN, NodeType.AFTER_FFN]
def get_residual_predecessor(self) -> Optional["GraphNode"]:
"""
Get another graph node which points to the state of the residual stream before
this node.
Retun None if current representation is the first one in the residual stream.
"""
scheme = {
NodeType.AFTER_ATTN: GraphNode(
layer=max(self.layer - 1, 0),
token=self.token,
type=NodeType.AFTER_FFN if self.layer > 0 else NodeType.ORIGINAL,
),
NodeType.AFTER_FFN: GraphNode(
layer=self.layer,
token=self.token,
type=NodeType.AFTER_ATTN,
),
NodeType.FFN: GraphNode(
layer=self.layer,
token=self.token,
type=NodeType.AFTER_ATTN,
),
NodeType.ORIGINAL: None,
}
node = scheme[self.type]
if node.layer < 0:
return None
return node
def get_name(self) -> str:
return _format_block_hierachy_string(
[f"L{self.layer}", f"T{self.token}", str(self.type.value)]
)
def get_predecessor_block_name(self) -> str:
"""
Return the name of the block standing between current node and its predecessor
in the residual stream.
"""
scheme = {
NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"],
NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"],
NodeType.FFN: [f"L{self.layer}", "ffn"],
NodeType.ORIGINAL: ["Nothing"],
}
return _format_block_hierachy_string(scheme[self.type])
def get_head_name(self, head: Optional[int]) -> str:
path = [f"L{self.layer}", "attn"]
if head is not None:
path.append(f"H{head}")
return _format_block_hierachy_string(path)
def get_neuron_name(self, neuron: Optional[int]) -> str:
path = [f"L{self.layer}", "ffn"]
if neuron is not None:
path.append(f"N{neuron}")
return _format_block_hierachy_string(path)