File size: 448 Bytes
876699e
 
 
 
 
 
 
1
2
3
4
5
6
7
from transformers import AutoConfig, AutoModelForMaskedLM
from transformers import CONFIG_MAPPING, MODEL_MAPPING

from .modeling_custom_llama import CustomLlamaConfig, CustomLlamaForCausalLM, CustomLlamaForMaskedLM
# Register the custom architecture with the transformers Auto* factories so the
# "custom_llama" model type can be resolved by AutoConfig and
# AutoModelForMaskedLM.
#
# NOTE: CONFIG_MAPPING is a lazy mapping. Its lookups only see the built-in
# table plus entries added through its ``register`` method. A plain ``.update()``
# writes into storage those lookups never read, so the key was silently ignored.
# ``AutoConfig.register`` is the supported public API for this.
# It raises ValueError when CustomLlamaConfig.model_type != "custom_llama".
# TODO(review): confirm model_type in modeling_custom_llama.
AutoConfig.register("custom_llama", CustomLlamaConfig)

# Map the config class to its masked-LM implementation. The config is registered
# first, so the model_type -> config -> model chain is complete.
AutoModelForMaskedLM.register(CustomLlamaConfig, CustomLlamaForMaskedLM)