# mlm_llama/__init__.py
# Source: Hugging Face Hub, last updated by Shounak (commit 876699e, "Update __init__.py").
from transformers import AutoModelForMaskedLM
from .modeling_custom_llama import CustomLlamaConfig, CustomLlamaForCausalLM, CustomLlamaForMaskedLM
from transformers import CONFIG_MAPPING, MODEL_MAPPING
# Register the custom Llama classes with the transformers Auto* factories so
# that, once this package is imported, AutoConfig can resolve the
# "custom_llama" model type and AutoModelForMaskedLM can build
# CustomLlamaForMaskedLM from a CustomLlamaConfig.
#
# NOTE: CONFIG_MAPPING is a lazy mapping whose lookups only consult the
# built-in config table plus entries added through ``register``; entries
# written with the inherited ``dict.update`` are stored but never found, so
# ``register`` is required here. ``CONFIG_MAPPING.register`` is used rather
# than ``AutoConfig.register`` because the latter additionally requires
# ``CustomLlamaConfig.model_type == "custom_llama"``.
# NOTE(review): "custom_llama" is assumed to be the ``model_type`` stored in
# the checkpoints' config.json — confirm.
CONFIG_MAPPING.register("custom_llama", CustomLlamaConfig)
AutoModelForMaskedLM.register(CustomLlamaConfig, CustomLlamaForMaskedLM)