File size: 531 Bytes
0f8b77a
 
 
3226856
 
0377af8
 
 
b90a036
0f8b77a
2d1091a
a5c24a9
0f8b77a
b90a036
 
a5c24a9
b90a036
0377af8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers.integrations import TransformersPlugin, replace_target_class

from .llama_xformers_attention import LlamaXFormersAttention

class LlamaXFormersPlugin(TransformersPlugin):
    """Plugin that substitutes ``LlamaXFormersAttention`` for Llama attention.

    All of the work happens in the pre-init hook. There it asks
    ``replace_target_class`` to use ``LlamaXFormersAttention`` for the target
    named ``"LlamaAttention"``, passing the model's own config as the
    constructor keyword ``config``. The post-init hook is a no-op
    pass-through.
    """

    def __init__(self, config):
        # No state is kept: ``config`` is accepted and intentionally unused.
        # NOTE(review): this does not call super().__init__(); confirm that
        # TransformersPlugin requires no base-class initialization.
        pass

    def process_model_pre_init(self, model):
        """Request the xFormers attention replacement and return *model*.

        Args:
            model: object exposing a ``config`` attribute. Its config is
                forwarded as ``init_kwargs={"config": ...}`` so the
                replacement attention is built from the same settings.
                Presumably a transformers Llama model — confirm with callers.

        Returns:
            The same ``model`` object that was passed in. The replacement is
            presumably applied in place by ``replace_target_class``; its exact
            semantics are defined outside this file.
        """
        replace_target_class(
            model,
            LlamaXFormersAttention,
            "LlamaAttention",
            init_kwargs={"config": model.config},
        )
        return model

    def process_model_post_init(self, model):
        """Return *model* untouched; this plugin has no post-init work."""
        return model