# Snapshot metadata (from page extraction): revision 2f9282b, 560 bytes.
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
from fla.models.delta_net.modeling_delta_net import (DeltaNetForCausalLM,
DeltaNetModel)
# Register DeltaNet with the transformers Auto* factories.
# The config class is keyed by its `model_type` string. Each Auto model
# factory then maps DeltaNetConfig to the matching DeltaNet model class.
AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
for _auto_factory, _model_cls in (
    (AutoModel, DeltaNetModel),
    (AutoModelForCausalLM, DeltaNetForCausalLM),
):
    _auto_factory.register(DeltaNetConfig, _model_cls)
# Drop the loop variables so the module namespace matches the explicit form.
del _auto_factory, _model_cls

# Public API of this package.
__all__ = [
    'DeltaNetConfig',
    'DeltaNetForCausalLM',
    'DeltaNetModel',
]