File size: 1,005 Bytes
e9acf97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from transformers import AutoConfig, AutoModelForCausalLM, SpeechEncoderDecoderConfig

from .auto_wrappers import CustomAutoModelForCTC
from .e_branchformer import Wav2Vec2EBranchformerConfig, Wav2Vec2EBranchformerForCTC
from .multi_head_gpt2 import GPT2LMMultiHeadModel, GPT2MultiHeadConfig
from .residual_clasiffier_gpt2 import (
    GPT2ResidualsLMHeadConfig,
    GPT2ResidualsLMHeadModel,
)

# Import-time registration of the custom architectures with the auto classes.
# Each row is: (model-type string, config class, auto class to extend, model class).
# For every row the config class is registered with AutoConfig under its
# model-type string first, then the (config, model) pair is registered with
# the auto class named in that row -- same call order as writing the calls out
# one by one.
for _model_type, _config_cls, _auto_cls, _model_cls in (
    (
        "gpt2-multi-head",
        GPT2MultiHeadConfig,
        AutoModelForCausalLM,
        GPT2LMMultiHeadModel,
    ),
    (
        "gpt2-residuals-head",
        GPT2ResidualsLMHeadConfig,
        AutoModelForCausalLM,
        GPT2ResidualsLMHeadModel,
    ),
    (
        "wav2vec2-ebranchformer",
        Wav2Vec2EBranchformerConfig,
        CustomAutoModelForCTC,
        Wav2Vec2EBranchformerForCTC,
    ),
):
    AutoConfig.register(_model_type, _config_cls)
    _auto_cls.register(_config_cls, _model_cls)

# Drop the loop temporaries so the module namespace holds only the imported
# names and the class defined below.
del _model_type, _config_cls, _auto_cls, _model_cls


class JointCTCAttentionEncoderDecoderConfig(SpeechEncoderDecoderConfig):
    """Speech encoder-decoder configuration with its own ``model_type``.

    Subclasses ``SpeechEncoderDecoderConfig`` and sets only the two class
    attributes below; no constructor, fields, or methods are added here.
    """

    # Set explicitly on this subclass.
    # NOTE(review): presumably marks the config as composed of sub-configs
    # (encoder + decoder), as in the transformers base config -- confirm
    # against the installed transformers version.
    is_composition = True

    # Model-type identifier string for this configuration class.
    # NOTE(review): this file does not register the string with AutoConfig;
    # presumably that happens elsewhere -- confirm.
    model_type = "joint_aed_ctc_speech-encoder-decoder"