|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from enum import Enum |
|
from typing import List, Optional |
|
|
|
|
|
class NodeType(Enum):
    """Kinds of graph node; each member's value is its string identifier."""

    # The two types below are the ones GraphNode.is_in_residual_stream()
    # reports as being in the residual stream.
    AFTER_ATTN = "after_attn"
    AFTER_FFN = "after_ffn"

    # Node type for an FFN block.
    FFN = "ffn"

    # Node type with no residual predecessor.
    ORIGINAL = "original"
|
|
|
|
|
def _format_block_hierachy_string(blocks: List[str]) -> str:
    """
    Join hierarchy path components into one display string.

    An empty list yields an empty string; a single component is returned
    unchanged.
    """
    # NOTE(review): this separator looks like a mis-decoded character (possibly
    # an arrow/triangle glyph) -- confirm the intended symbol before changing it.
    separator = " โธ "
    return separator.join(blocks)
|
|
|
|
|
@dataclass
class GraphNode:
    """
    A graph node identified by a (layer, token, type) triple.

    Besides navigation along the residual stream, the class provides helpers
    that build display names for the node and for the blocks around it.
    """

    layer: int
    token: int
    type: NodeType

    def is_in_residual_stream(self) -> bool:
        """Return True if the node type is AFTER_ATTN or AFTER_FFN."""
        return self.type in (NodeType.AFTER_ATTN, NodeType.AFTER_FFN)

    def get_residual_predecessor(self) -> Optional["GraphNode"]:
        """
        Get another graph node which points to the state of the residual stream
        before this node.

        Return None if current representation is the first one in the residual
        stream (an ORIGINAL node), or if an AFTER_FFN/FFN node has a negative
        layer.

        Raises:
            KeyError: if ``self.type`` is not a known ``NodeType`` member.
        """
        # An ORIGINAL node has nothing before it. The previous implementation
        # looked up None for this type and then dereferenced it, raising
        # AttributeError instead of returning None.
        if self.type is NodeType.ORIGINAL:
            return None

        if self.type is NodeType.AFTER_ATTN:
            if self.layer > 0:
                # The state before this layer's attention is the output of the
                # previous layer's FFN.
                return GraphNode(
                    layer=self.layer - 1,
                    token=self.token,
                    type=NodeType.AFTER_FFN,
                )
            # No previous layer: the predecessor is the ORIGINAL node, pinned
            # to layer 0.
            return GraphNode(layer=0, token=self.token, type=NodeType.ORIGINAL)

        if self.type in (NodeType.AFTER_FFN, NodeType.FFN):
            # Preserve the historical guard: a negative layer has no predecessor.
            if self.layer < 0:
                return None
            return GraphNode(
                layer=self.layer,
                token=self.token,
                type=NodeType.AFTER_ATTN,
            )

        # Unknown types used to fail the dict lookup; keep the same exception.
        raise KeyError(self.type)

    def get_name(self) -> str:
        """Return the display name built from layer, token and type value."""
        return _format_block_hierachy_string(
            [f"L{self.layer}", f"T{self.token}", str(self.type.value)]
        )

    def get_predecessor_block_name(self) -> str:
        """
        Return the name of the block standing between current node and its
        predecessor in the residual stream.

        Raises:
            KeyError: if ``self.type`` is not a known ``NodeType`` member.
        """
        scheme = {
            NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"],
            NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"],
            NodeType.FFN: [f"L{self.layer}", "ffn"],
            # ORIGINAL has no predecessor, hence no block in between.
            NodeType.ORIGINAL: ["Nothing"],
        }
        return _format_block_hierachy_string(scheme[self.type])

    def get_head_name(self, head: Optional[int]) -> str:
        """
        Return the display name of this layer's attention block, with an
        ``H{head}`` component appended when ``head`` is not None.
        """
        path = [f"L{self.layer}", "attn"]
        if head is not None:
            path.append(f"H{head}")
        return _format_block_hierachy_string(path)

    def get_neuron_name(self, neuron: Optional[int]) -> str:
        """
        Return the display name of this layer's ffn block, with an
        ``N{neuron}`` component appended when ``neuron`` is not None.
        """
        path = [f"L{self.layer}", "ffn"]
        if neuron is not None:
            path.append(f"N{neuron}")
        return _format_block_hierachy_string(path)
|
|