File size: 448 Bytes
876699e |
1 2 3 4 5 6 7 |
from transformers import AutoModelForMaskedLM
from .modeling_custom_llama import CustomLlamaConfig, CustomLlamaForCausalLM, CustomLlamaForMaskedLM
from transformers import CONFIG_MAPPING, MODEL_MAPPING
# Register the custom architecture with the transformers Auto* factories so that
# checkpoints declaring model_type "custom_llama" resolve to these classes.
#
# NOTE: CONFIG_MAPPING is transformers' lazy config mapping; its lookups only
# consult the built-in table and entries added via ``register``. A plain
# ``CONFIG_MAPPING.update({...})`` writes into the underlying OrderedDict storage
# that ``__getitem__``/``__contains__`` never read, so it silently registers
# nothing. ``register`` is the supported API.
# NOTE(review): presumably CustomLlamaConfig.model_type == "custom_llama" —
# confirm, since config.json's "model_type" is what AutoConfig matches on.
CONFIG_MAPPING.register("custom_llama", CustomLlamaConfig)
# Map the config class to its masked-LM model so AutoModelForMaskedLM can build it.
AutoModelForMaskedLM.register(CustomLlamaConfig, CustomLlamaForMaskedLM)