File size: 12,382 Bytes
7a58a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
from types import MethodType
from functools import partial
import self_extend_patch as SE

def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
    """
    Bind `new_method` as `target_method_name` on every object named `target_class_name`
    that is reachable from `instance`.

    The object graph is walked through each object's `__dict__` attribute values, through
    dict values, and through the elements of list/tuple/set/frozenset containers (including
    containers nested in other containers). For a torch ModuleList, the sub-modules live in
    the `._modules` dict, so they are reached this way.

    Matching is by exact class *name*, so subclasses with a different name do not match.
    Once an object has been patched, the walk does not descend into it. Patching sets an
    instance attribute on each matched object; the class itself is left unchanged.
    Currently this is only used to swap the attention methods of a model.

    Args:
        instance: root object to search, typically a loaded model.
        target_class_name: exact class name to patch, e.g. 'LlamaAttention' or
            'GPTNeoXAttention'.
        target_method_name: attribute to overwrite on each match, e.g. 'forward'.
        new_method: callable bound to each matched object with `types.MethodType`.
            The matched object is passed as its first positional argument ('self').
            A `functools.partial` works too.
        visited_instances: optional set of `id()`s to skip; it is updated in place.
            Callers normally leave it as None.

    Returns:
        True if at least one object was patched, False otherwise.
    """
    if visited_instances is None:
        visited_instances = set()

    target_found = False
    # An explicit stack (rather than recursion) keeps deep object graphs from hitting
    # the interpreter's recursion limit. Objects on the stack stay alive, so their
    # id()s cannot be reused while the walk runs.
    pending = [instance]
    while pending:
        current = pending.pop()
        current_id = id(current)
        # id() is used because many nodes (lists, dicts, modules) are unhashable.
        if current_id in visited_instances:
            continue
        visited_instances.add(current_id)

        if current.__class__.__name__ == target_class_name:
            setattr(current, target_method_name, MethodType(new_method, current))
            target_found = True
            continue

        # extend() copies the children immediately, so later patching cannot mutate
        # a dict while it is being iterated.
        if isinstance(current, dict):
            pending.extend(current.values())
        elif isinstance(current, (list, tuple, set, frozenset)):
            pending.extend(current)
        elif hasattr(current, '__dict__'):
            pending.extend(current.__dict__.values())

    return target_found


# Supported architecture families. They are checked in this order against the model's
# class name, and the first substring match wins. Each entry is also the attribute of
# `self_extend_patch` holding that family's forwards, and the prefix of its attention
# class names (`<Family>Attention` / `<Family>FlashAttention2`).
_SUPPORTED_FAMILIES = ("Llama", "Mistral", "Gemma", "Qwen2", "Phi")


def _patch_attention(loaded_model, arch_name, class_name, method_name, new_method):
    """Patch `method_name` on every `class_name` module of `loaded_model`; raise if none is found."""
    if not modify_method_of_instance(loaded_model, class_name, method_name, new_method):
        raise RuntimeError(f"Failed to modify the attention method of {arch_name}")


def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
    '''
    Patch the attention modules of `loaded_model` in place to use SelfExtend.

    The model family is chosen by substring match on the model's class name. The checks
    run in the order Llama, Mistral, Gemma, Qwen2, Phi, and the first match wins. The
    family's attention modules must appear in the model under the exact class name
    `<Family>Attention`, or `<Family>FlashAttention2` when flash attention is enabled.

    Args:
        loaded_model:
            model to apply the self-attention extension to; modified in place.
        group_size:
            group size for the self-attention extension (passed on as `group_size_1`).
        window_size:
            neighbor window size for the self-attention extension (passed on as
            `group_size_2`).
        enable_flash_attention:
            if True, patch the `<Family>FlashAttention2` modules instead of the eager
            `<Family>Attention` modules.
        scale_base:
            base for the scale, equal to the pretraining length, e.g. 4096 for Llama or
            8192 for Gemma. The default -1 presumably disables scaling; confirm in
            self_extend_patch.

            Two recommended scale factors:
                yarn: https://arxiv.org/abs/2309.00071
                log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
            This helps when retrieving from a long sequence (e.g. a long passkey), but
            the impact on real-world data is minor (e.g. on LongBench, LEval).
            The results reported in our paper do not use this scale, except for long
            passkey retrieval.
        flash_attention_impl:
            only consulted for Llama with `enable_flash_attention=True`. The accepted
            values are:
            - 'flash_attn': patches both `_flash_attention_forward` (with
              `selfextend_flash_attn.flash_attention2_forward_with_window_size`) and
              `forward`.
            - 'triton' (default): patches only `forward`, using the Triton forward.
            Other families always use the flash_attn-style patch.

    Raises:
        NotImplementedError: the model's architecture is not supported.
        ValueError: Llama with flash attention and an unknown `flash_attention_impl`.
        RuntimeError: no attention module of the expected class was found in the model.
        (ValueError and RuntimeError are Exception subclasses, so existing
        `except Exception` handlers still catch them.)
    '''
    arch_name = loaded_model.__class__.__name__
    family = next((name for name in _SUPPORTED_FAMILIES if name in arch_name), None)
    if family is None:
        raise NotImplementedError(
            f"SelfExtend is not implemented for {arch_name}; "
            f"supported families: {', '.join(_SUPPORTED_FAMILIES)}"
        )

    # Validate the Llama-only option before touching any patch code.
    use_triton = False
    if enable_flash_attention and family == "Llama":
        if flash_attention_impl not in ("flash_attn", "triton"):
            raise ValueError("Need to set the flash_attention_impl to 'flash_attn' or 'triton'.")
        use_triton = flash_attention_impl == "triton"

    family_module = getattr(SE, family)
    se_kwargs = dict(group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)

    if not enable_flash_attention:
        # Author's note: in transformers 4.36 the default Llama attention class is
        # LlamaSdpaAttention, while before 4.36 and in 4.38 it is LlamaAttention.
        # Only the exact name `<Family>Attention` is matched here.
        self_extend_attention_forward = partial(family_module.self_extend_forward, **se_kwargs)
        _patch_attention(loaded_model, arch_name, f"{family}Attention", "forward", self_extend_attention_forward)
        return

    flash_class_name = f"{family}FlashAttention2"
    if use_triton:
        self_extend_attention_forward = partial(family_module.flash_self_extend_forward_triton, **se_kwargs)
        _patch_attention(loaded_model, arch_name, flash_class_name, "forward", self_extend_attention_forward)
        print("Using triton flash self_extend!!")
        return

    self_extend_attention_forward = partial(family_module.flash_self_extend_forward, **se_kwargs)
    _patch_attention(loaded_model, arch_name, flash_class_name, "_flash_attention_forward",
                     SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
    _patch_attention(loaded_model, arch_name, flash_class_name, "forward", self_extend_attention_forward)
    if family == "Llama":
        print("Using flash_attn flash self_extend!!")