|
import torch
from transformers import (
    AutoModelForTokenClassification,
    ModernBertForTokenClassification,
    PretrainedConfig,
)
|
|
|
class ModernBertForTokenClassificationCRF(ModernBertForTokenClassification):
    """ModernBERT token-classification model with an optional CRF layer.

    The encoder and per-token classification head come from
    ``ModernBertForTokenClassification``. When ``config.use_crf`` is truthy, a
    ``torchcrf.CRF`` over ``config.num_labels`` tags is attached as
    ``self.crf``.

    This subclasses the concrete ModernBERT model rather than
    ``AutoModelForTokenClassification``. The ``Auto*`` classes are factories:
    their ``__init__`` raises, and their ``from_pretrained`` returns the mapped
    concrete class. A subclass of them can therefore never be instantiated
    with the CRF attached.
    """

    def __init__(self, config):
        """Build the base model and, if requested, a CRF over the label set.

        Args:
            config: Model configuration. ``use_crf`` is optional and treated
                as ``False`` when absent. ``num_labels`` sizes the CRF.
        """
        super().__init__(config)

        use_crf = getattr(config, "use_crf", False)
        # Record the resolved flag on the config so save_pretrained() writes
        # it to config.json (a reload then rebuilds the CRF), and so reads of
        # ``self.config.use_crf`` cannot raise AttributeError.
        self.config.use_crf = use_crf

        if use_crf:
            # Deferred import: pytorch-crf is only required when the CRF is on.
            from torchcrf import CRF

            # batch_first=True: torchcrf expects batch-major emissions/tags.
            self.crf = CRF(num_tags=config.num_labels, batch_first=True)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a pretrained checkpoint into this CRF-capable model class.

        An explicit ``config=`` keyword is forwarded unchanged. When it is
        omitted, config loading is left to the base implementation. That
        implementation uses ``cls.config_class`` and honours loading options
        such as ``revision``, ``cache_dir``, ``token`` and ``subfolder``.
        Pre-loading here with the generic ``PretrainedConfig`` would lose
        model-specific defaults and ignore those options for the config file.

        Args:
            pretrained_model_name_or_path: Hub id or local directory.
            *args: Extra positional arguments forwarded to the model.
            **kwargs: Loading options, including optional ``config``.

        Returns:
            An instance of ``cls``.
        """
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)