File size: 2,831 Bytes
ce00289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class NodeType(Enum):
    """
    Kinds of node a GraphNode can represent.

    The string value of each member is what GraphNode.get_name() puts into the
    node's display name.
    """

    # AFTER_ATTN and AFTER_FFN are the two types that
    # GraphNode.is_in_residual_stream() reports as residual-stream nodes.
    AFTER_ATTN = "after_attn"
    AFTER_FFN = "after_ffn"
    # Shares its residual predecessor and block name with AFTER_FFN, but is not
    # itself counted as a residual-stream node.
    FFN = "ffn"
    ORIGINAL = "original"  # The original tokens


def _format_block_hierachy_string(blocks: List[str]) -> str:
    return " ▸ ".join(blocks)


@dataclass
class GraphNode:
    """
    A single node of the graph, addressed by layer index, token index and node type.
    """

    layer: int
    token: int
    type: NodeType

    def is_in_residual_stream(self) -> bool:
        """
        Return True if this node is of type AFTER_ATTN or AFTER_FFN.
        """
        return self.type in (NodeType.AFTER_ATTN, NodeType.AFTER_FFN)

    def get_residual_predecessor(self) -> Optional["GraphNode"]:
        """
        Get another graph node which points to the state of the residual stream before
        this node.

        Return None if current representation is the first one in the residual stream
        (i.e. this node is of type ORIGINAL), or if an AFTER_FFN / FFN node has a
        negative layer index.
        """
        # ORIGINAL has no predecessor. It is handled up front so that we never try
        # to read attributes of a missing node.
        if self.type is NodeType.ORIGINAL:
            return None

        if self.type is NodeType.AFTER_ATTN:
            if self.layer > 0:
                # The attention block of layer N reads the output of layer N-1.
                return GraphNode(
                    layer=self.layer - 1,
                    token=self.token,
                    type=NodeType.AFTER_FFN,
                )
            # The first attention block reads the original tokens; the layer
            # index is clamped to 0.
            return GraphNode(layer=0, token=self.token, type=NodeType.ORIGINAL)

        # AFTER_FFN and FFN: the predecessor is the post-attention state of the
        # same layer.
        if self.layer < 0:
            return None
        return GraphNode(
            layer=self.layer,
            token=self.token,
            type=NodeType.AFTER_ATTN,
        )

    def get_name(self) -> str:
        """
        Return the display name of this node, built from its layer, token and type
        value.
        """
        return _format_block_hierachy_string(
            [f"L{self.layer}", f"T{self.token}", str(self.type.value)]
        )

    def get_predecessor_block_name(self) -> str:
        """
        Return the name of the block standing between current node and its predecessor
        in the residual stream.

        For a node of type ORIGINAL, which has no predecessor, the name is "Nothing".
        """
        scheme = {
            NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"],
            NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"],
            NodeType.FFN: [f"L{self.layer}", "ffn"],
            NodeType.ORIGINAL: ["Nothing"],
        }
        return _format_block_hierachy_string(scheme[self.type])

    def get_head_name(self, head: Optional[int]) -> str:
        """
        Return the name of the attention block of this node's layer. If `head` is
        not None, the name is extended with that head's index.
        """
        path = [f"L{self.layer}", "attn"]
        if head is not None:
            path.append(f"H{head}")
        return _format_block_hierachy_string(path)

    def get_neuron_name(self, neuron: Optional[int]) -> str:
        """
        Return the name of the ffn block of this node's layer. If `neuron` is not
        None, the name is extended with that neuron's index.
        """
        path = [f"L{self.layer}", "ffn"]
        if neuron is not None:
            path.append(f"N{neuron}")
        return _format_block_hierachy_string(path)