from transformers.integrations import TransformersPlugin, replace_target_class | |
from .llama_xformers_attention import LlamaXFormersAttention | |
class LlamaXFormersPlugin(TransformersPlugin):
    """Plugin that swaps ``LlamaAttention`` for ``LlamaXFormersAttention``.

    The swap is requested in :meth:`process_model_pre_init` via
    ``replace_target_class``. :meth:`process_model_post_init` does nothing
    and hands the model back unchanged.
    """

    def __init__(self, config):
        """Accept ``config`` and ignore it.

        NOTE(review): the base-class initializer is not invoked and
        ``config`` is not stored. Confirm that ``TransformersPlugin``
        requires neither.
        """

    def process_model_pre_init(self, model):
        """Ask ``replace_target_class`` to substitute the attention class.

        The class named ``"LlamaAttention"`` is targeted for replacement by
        ``LlamaXFormersAttention``. ``model.config`` is forwarded as the
        ``config`` init kwarg. Whether the substitution mutates ``model`` in
        place depends on ``replace_target_class``, which is not visible
        here.

        Args:
            model: Object exposing a ``config`` attribute.

        Returns:
            The same ``model`` object that was passed in.
        """
        replacement_kwargs = {"config": model.config}
        replace_target_class(
            model,
            LlamaXFormersAttention,
            "LlamaAttention",
            init_kwargs=replacement_kwargs,
        )
        return model

    def process_model_post_init(self, model):
        """No-op hook: return ``model`` untouched."""
        return model