from transformers import AutoModelForTokenClassification, PretrainedConfig import torch class ModernBertForTokenClassificationCRF(AutoModelForTokenClassification): def __init__(self, config): super().__init__(config) # Add CRF layer if needed if config.use_crf: from torchcrf import CRF self.crf = CRF(num_tags=config.num_labels, batch_first=True) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): config = kwargs.pop('config', None) if config is None: config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path) return super().from_pretrained(pretrained_model_name_or_path, *args, config=config, **kwargs)