from abc import ABC, abstractmethod
from collections import defaultdict
from typing import List, Optional
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
from SDLens.cache_and_edit.hooks import fix_inf_values_hook, register_general_hook
import torch

class ModelActivationCache(ABC):
    """
    Cache for inference pass of a Diffusion Transformer.
    Used to cache residual-streams and activations.
    """
    def __init__(self):
        # Caches for "double transformer" blocks, created only if the subclass
        # defines NUM_TRANSFORMER_BLOCKS.
        if hasattr(self, 'NUM_TRANSFORMER_BLOCKS'):
            self.image_residual = []
            self.image_activation = []
            self.text_residual = []
            self.text_activation = []

        # Caches for "single transformer" blocks, created only if the subclass
        # defines NUM_SINGLE_TRANSFORMER_BLOCKS.
        if hasattr(self, 'NUM_SINGLE_TRANSFORMER_BLOCKS'):
            self.text_image_residual = []
            self.text_image_activation = []

    def __str__(self):
        lines = [f"{self.__class__.__name__}:"]
        for attr_name, value in self.__dict__.items():
            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):
                shapes = value[0].shape if value else None
                lines.append(f"  {attr_name}: len={len(value)}, shapes={shapes}")
            else:
                lines.append(f"  {attr_name}: {type(value)}")
        return "\n".join(lines)

    def _repr_pretty_(self, p, cycle):
        p.text(str(self))

    @abstractmethod
    def get_cache_info(self):
        """
        Return details about the cache configuration.
        Subclasses must implement this to provide info on their transformer block counts.
        """
        pass


class FluxActivationCache(ModelActivationCache):
    # Define number of blocks for double and single transformer caches
    NUM_TRANSFORMER_BLOCKS = 19
    NUM_SINGLE_TRANSFORMER_BLOCKS = 38

    def __init__(self):
        super().__init__()

    def get_cache_info(self):
        return {
            "transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
            "single_transformer_blocks": self.NUM_SINGLE_TRANSFORMER_BLOCKS,
        }
    
    def __getitem__(self, key):
        return getattr(self, key)


class PixartActivationCache(ModelActivationCache):
    # Define number of blocks for the double transformer cache only
    NUM_TRANSFORMER_BLOCKS = 28

    def __init__(self):
        super().__init__()

    def get_cache_info(self):
        return {
            "double_transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
        }


class ActivationCacheHandler:
    """ Used to manage ModelActivationCache of a Diffusion Transformer.
    """

    def __init__(self, cache: ModelActivationCache, positions_to_cache: Optional[List[str]] = None):
        """Constructor.

        Args:
            cache (ModelActivationCache): cache used to store tensors.
            positions_to_cache (Optional[List[str]]): names of the modules to cache.
                If None, all modules specified by `cache.get_cache_info()` are cached.
                Defaults to None.
        """
        self.cache = cache
        self.positions_to_cache = positions_to_cache

    @torch.no_grad()
    def cache_residual_and_activation_hook(self, *args):
        """ 
            To be used as a forward hook on a Transformer Block.
            It caches both residual_stream and activation (defined as output - residual_stream).
        """

        if len(args) == 3:
            module, input, output = args
        elif len(args) == 4:
            module, input, kwinput, output = args

        if isinstance(module, FluxTransformerBlock):
            # Double-stream block: output is (encoder_hidden_states, hidden_states).
            encoder_hidden_states = output[0]
            hidden_states = output[1]

            self.cache.image_activation.append(hidden_states - kwinput["hidden_states"])
            self.cache.text_activation.append(encoder_hidden_states - kwinput["encoder_hidden_states"])
            self.cache.image_residual.append(kwinput["hidden_states"])
            self.cache.text_residual.append(kwinput["encoder_hidden_states"])

        elif isinstance(module, FluxSingleTransformerBlock):
            # Single-stream block: text and image tokens share one hidden_states stream.
            self.cache.text_image_activation.append(output - kwinput["hidden_states"])
            self.cache.text_image_residual.append(kwinput["hidden_states"])
        else:
            raise NotImplementedError(f"Caching not implemented for {type(module)}")


    @property
    def forward_hooks_dict(self):
        """Map each module name to the list of forward hooks to register on it."""

        hooks = defaultdict(list)

        if self.positions_to_cache is None:
            for block_type, num_layers in self.cache.get_cache_info().items():
                for i in range(num_layers):
                    module_name: str = f"transformer.{block_type}.{i}"
                    hooks[module_name].append(fix_inf_values_hook)
                    hooks[module_name].append(self.cache_residual_and_activation_hook)
        else:
            for module_name in self.positions_to_cache:
                hooks[module_name].append(fix_inf_values_hook)
                hooks[module_name].append(self.cache_residual_and_activation_hook)

        return hooks
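

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library API): assumes a CUDA
# device, a locally available diffusers FluxPipeline, and that every hook
# returned by `forward_hooks_dict` accepts the (module, args, kwargs, output)
# signature used when registering with `with_kwargs=True`. The project may
# instead wire hooks through `register_general_hook`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")

    cache = FluxActivationCache()
    handler = ActivationCacheHandler(cache)

    # Module names produced by forward_hooks_dict are relative to the pipeline,
    # e.g. "transformer.transformer_blocks.0"; resolve them on the transformer.
    handles = []
    for module_name, hook_fns in handler.forward_hooks_dict.items():
        module = pipe.transformer.get_submodule(module_name.removeprefix("transformer."))
        for hook_fn in hook_fns:
            handles.append(module.register_forward_hook(hook_fn, with_kwargs=True))

    pipe("a photo of a cat", num_inference_steps=1, guidance_scale=0.0)

    # One tensor per block per denoising step; FluxActivationCache also
    # supports dictionary-style access via __getitem__.
    print(cache)
    print(cache["image_activation"][0].shape)

    for handle in handles:
        handle.remove()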