from types import MethodType
from functools import partial

import self_extend_patch as SE


def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
    """
    This function modifies a method of an instance of a model class. It recursively walks the
    instance's attributes and replaces `target_method_name` with `new_method` on every sub-module
    whose class name matches `target_class_name`. It was partly generated with ChatGPT. Currently,
    we only use it to modify the attention method of a model; it has not been tested beyond that
    use case.

    instance: instance of a model to modify.
    target_class_name: name of the attention class to modify, e.g. 'LlamaAttention', 'GPTNeoXAttention', etc.
    target_method_name: name of the method to replace, e.g. 'forward'.
    new_method: new method to replace the original one, e.g. 'self_extend_forward'.
                Its first parameter should be 'self' so it can be bound to the instance.
    """
    target_found = False
    if visited_instances is None:
        visited_instances = set()

    # Unique identifier for the instance (id() is unique per live object).
    instance_id = id(instance)
    if instance_id in visited_instances:
        return False
    # Mark the instance as visited.
    visited_instances.add(instance_id)

    # If this instance is of the target class, bind the new method to it and install it.
    if instance.__class__.__name__ == target_class_name:
        bound_method = MethodType(new_method, instance)
        setattr(instance, target_method_name, bound_method)
        return True
    elif hasattr(instance, '__dict__'):
        for attr_name, attr_value in instance.__dict__.items():
            if isinstance(attr_value, object) and not isinstance(attr_value, (list, tuple, dict, set)):
                _found = modify_method_of_instance(attr_value, target_class_name, target_method_name, new_method, visited_instances)
                if _found:
                    target_found = True
            elif isinstance(attr_value, (list, tuple)):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute is a dictionary, iterate over its values and recurse.
            # E.g. a ModuleList stores its modules in a dictionary: ._modules
            elif isinstance(attr_value, dict):
                for key, value in attr_value.items():
                    if isinstance(value, object):
                        _found = modify_method_of_instance(value, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute is a set, iterate and recurse.
            elif isinstance(attr_value, set):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True

    return target_found


def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
    '''
    loaded_model: model to apply the self-attention extension to.
    group_size: group size for the self-attention extension.
    window_size: neighbor window size for the self-attention extension.
    enable_flash_attention: whether to patch the FlashAttention2 classes instead of the standard ones.
    scale_base: base for the scale, equal to the pretraining length, e.g. 4096 for Llama, 8192 for Gemma.
                Two recommended scale factors:
                    yarn: https://arxiv.org/abs/2309.00071
                    log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
                This scaling helps when retrieving a long sequence (e.g. a long passkey), but on
                real-world data (e.g. LongBench, LEval) the impact is minor. The results reported
                in our paper do not use this scaling except for long passkey retrieval.
    flash_attention_impl: flash-attention backend to use, either 'flash_attn' or 'triton'.
    '''
    arch_name = loaded_model.__class__.__name__
    if 'Llama' in arch_name:
        if enable_flash_attention:
            if flash_attention_impl == "flash_attn":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward,
                                                        group_size_1=group_size,
                                                        group_size_2=window_size,
                                                        scale_base=scale_base)
                modified_1 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "_flash_attention_forward",
                                                       SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
                modified_2 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward",
                                                       self_extend_attention_forward)
                print("Using flash_attn flash self_extend!!")
                if (not modified_1) or (not modified_2):
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
            elif flash_attention_impl == "triton":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward_triton,
                                                        group_size_1=group_size,
                                                        group_size_2=window_size,
                                                        scale_base=scale_base)
                modified = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward",
                                                     self_extend_attention_forward)
                print("Using triton flash self_extend!!")
                if not modified:
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
            else:
                raise Exception("flash_attention_impl must be set to 'flash_attn' or 'triton'.")
        else:
            self_extend_attention_forward = partial(SE.Llama.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            # Note: after transformers 4.36 the default attention class is LlamaSdpaAttention,
            # but before 4.36 (and in 4.38) it is LlamaAttention.
            # print("loaded_model", loaded_model)
            modified_2 = modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Mistral' in arch_name:
        # Mistral shares the same architecture as Llama, so the two implementations are interchangeable.
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Mistral.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "_flash_attention_forward",
                                                   SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "forward",
                                                   self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Mistral.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "MistralAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Gemma' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Gemma.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "_flash_attention_forward",
                                                   SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "forward",
                                                   self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Gemma.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "GemmaAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Qwen2' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Qwen2.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "_flash_attention_forward",
                                                   SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "forward",
                                                   self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Qwen2.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "Qwen2Attention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Phi' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Phi.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "_flash_attention_forward",
                                                   SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "forward",
                                                   self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Phi.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "PhiAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    else:
        raise NotImplementedError(f"SelfExtend does not support {arch_name} yet.")
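

# Usage sketch (illustrative, not part of the patch itself): shows how `apply` might be called on a
# Hugging Face Llama-architecture checkpoint. The model name and the group/window sizes below are
# example values chosen for this sketch, not recommendations; `transformers` is assumed installed.
if __name__ == "__main__":
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Any Llama-architecture checkpoint should work here; this small one is used purely as an example.
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Patch the loaded model in place: tokens within `window_size` keep their exact relative
    # positions, while tokens outside it share grouped positions (group size = `group_size`).
    apply(model, group_size=8, window_size=1024)

    inputs = tokenizer("SelfExtend lets a pretrained LLM read longer contexts.", return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))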