"""Registry of the vision-language model families available in ``m4.models``.

Each supported model type string (e.g. ``"vllama"``) maps to a configuration
class and a modeling class. Two lookup tables are exposed:

* ``_SUPPORTED_MODELS``: model type -> configuration class.
* ``model_type_to_modeling_class``: model type -> modeling class.

Both tables are derived from a single private registry so their keys can never
drift apart.
"""

# NOTE(review): DecoupledEmbedding / DecoupledLinear are not used in this module;
# presumably imported so callers can reach them via this package — confirm
# before removing.
from m4.models.custom_modules import DecoupledEmbedding, DecoupledLinear  # noqa: F401
from m4.models.vbloom.configuration_vbloom import VBloomConfig
from m4.models.vbloom.modeling_vbloom import VBloomForCausalLM
from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
from m4.models.vgpt2.modeling_vgpt2 import VGPT2LMHeadModel
from m4.models.vllama.configuration_vllama import VLlamaConfig
from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
from m4.models.vopt.configuration_vopt import VOPTConfig
from m4.models.vopt.modeling_vopt import VOPTForCausalLM
from m4.models.vt5.configuration_vt5 import VT5Config
from m4.models.vt5.modeling_vt5 import VT5ForConditionalGeneration


# Single source of truth: model type -> (configuration class, modeling class).
# Add new model families here only; the public lookup tables below are derived
# from it, which keeps their keys (and key order) in sync.
_MODEL_REGISTRY = {
    "vgpt2": (VGPT2Config, VGPT2LMHeadModel),
    "vt5": (VT5Config, VT5ForConditionalGeneration),
    "vbloom": (VBloomConfig, VBloomForCausalLM),
    "vopt": (VOPTConfig, VOPTForCausalLM),
    "vllama": (VLlamaConfig, VLlamaForCausalLM),
}

# Model type -> configuration class.
_SUPPORTED_MODELS = {
    model_type: config_class
    for model_type, (config_class, _modeling_class) in _MODEL_REGISTRY.items()
}

# Model type -> modeling class.
model_type_to_modeling_class = {
    model_type: modeling_class
    for model_type, (_config_class, modeling_class) in _MODEL_REGISTRY.items()
}