from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, AutoModelForMaskedLM

from .modeling_custom_llama import (
    CustomLlamaConfig,
    CustomLlamaForCausalLM,
    CustomLlamaForMaskedLM,
)
# Register the custom config and masked-LM model with the transformers Auto*
# factories. Afterwards, AutoConfig resolves the "custom_llama" model type to
# CustomLlamaConfig, and AutoModelForMaskedLM maps CustomLlamaConfig to
# CustomLlamaForMaskedLM.
#
# Why not CONFIG_MAPPING.update(...): CONFIG_MAPPING is a lazy mapping object.
# Its lookups only see built-in entries and entries added through register().
# A plain dict-style update therefore has no effect on AutoConfig resolution.
#
# AutoConfig.register raises ValueError unless CustomLlamaConfig.model_type is
# "custom_llama". That check is intended: a mismatch would make configs saved
# by this class un-reloadable via AutoConfig.
AutoConfig.register("custom_llama", CustomLlamaConfig)

# Register the model only after its config, matching the documented order.
AutoModelForMaskedLM.register(CustomLlamaConfig, CustomLlamaForMaskedLM)

# NOTE(review): CustomLlamaForCausalLM is imported above but not registered.
# If it should load via AutoModelForCausalLM, add
# `AutoModelForCausalLM.register(CustomLlamaConfig, CustomLlamaForCausalLM)`.
# That call requires its config_class to be CustomLlamaConfig; confirm first.