# Provenance (scraped from the Hugging Face Hub page): uploaded by mlame,
# commit bfeaab7 "Upload 6 files" (verified).
from transformers import AutoModelForTokenClassification, PretrainedConfig
import torch
class ModernBertForTokenClassificationCRF(AutoModelForTokenClassification):
    """Auto-loader for a token-classification model with an optional CRF head.

    ``AutoModelForTokenClassification`` is a factory class: it cannot be
    constructed through ``__init__``. Its ``from_pretrained`` and
    ``from_config`` return an instance of the concrete model class resolved
    from the config (e.g. a ModernBERT token classifier), not an instance of
    this class. The CRF layer is therefore attached to that returned model as
    a ``crf`` attribute whenever ``config.use_crf`` is truthy.

    NOTE(review): the concrete model's ``forward`` is unaware of ``crf``.
    Callers must apply ``model.crf`` (loss / ``decode``) to the emission
    logits themselves.
    NOTE(review): the CRF transition parameters are freshly initialised by
    this loader. If a checkpoint stores trained ``crf.*`` weights, they are
    presumably dropped by the base loader as unexpected keys and are NOT
    restored here. TODO: load them explicitly if trained CRFs are shipped.
    """

    def __init__(self, config):
        # Auto classes raise EnvironmentError from __init__ and direct the
        # caller to from_pretrained()/from_config(). Direct construction thus
        # fails loudly. The CRF is attached in those factory methods instead.
        super().__init__(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the concrete token-classification model and attach the CRF.

        Args:
            pretrained_model_name_or_path: Hub model id or local directory.
            *args, **kwargs: Forwarded unchanged to the Auto loader. A
                ``config`` keyword, if supplied, is used as-is.

        Returns:
            The concrete model instance, carrying a ``crf`` attribute when
            ``model.config.use_crf`` is truthy.
        """
        # Let the Auto machinery resolve the config when none is given. It
        # selects the model-specific config class, whereas a bare
        # PretrainedConfig is not a key of the Auto model mapping and would
        # be rejected. It also routes hub/config kwargs (revision, num_labels,
        # use_crf, ...) to the right places.
        model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )
        return cls._attach_crf(model)

    @classmethod
    def from_config(cls, config, **kwargs):
        """Build a freshly initialised model from *config* and attach the CRF.

        Args:
            config: A model-specific configuration object.
            **kwargs: Forwarded unchanged to the Auto ``from_config``.

        Returns:
            The concrete model instance, carrying a ``crf`` attribute when
            ``config.use_crf`` is truthy.
        """
        model = super().from_config(config, **kwargs)
        return cls._attach_crf(model)

    @staticmethod
    def _attach_crf(model):
        """Attach a CRF head to *model* if ``model.config.use_crf`` is truthy."""
        config = model.config
        # use_crf is a custom flag; configs that never set it get no CRF
        # instead of an AttributeError.
        if getattr(config, "use_crf", False):
            # Optional dependency: only required when a CRF is requested.
            from torchcrf import CRF

            # Place the CRF next to the model so emissions and transitions
            # share a device.
            model.crf = CRF(num_tags=config.num_labels, batch_first=True).to(
                model.device
            )
        return model