"""Hugging Face wrappers around torchvision's ConvNeXt-Base network.

Defines a config class and two ``PreTrainedModel`` subclasses that delegate
to ``torchvision.models.convnext_base`` and registers them for use with the
``Auto*`` classes.
"""

from typing import List

import torch
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
from transformers import PretrainedConfig, PreTrainedModel


class ConvNextBaseConfig(PretrainedConfig):
    """Configuration for the ConvNeXt-Base wrappers.

    Adds no fields of its own; every keyword argument is forwarded unchanged
    to ``PretrainedConfig``.
    """

    model_type = "ConvNext"

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)


class ConvNextBaseModel(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper around ``convnext_base``."""

    config_class = ConvNextBaseConfig

    def __init__(self, config):
        """Build the wrapped network.

        Args:
            config: A ``ConvNextBaseConfig`` instance, passed on to
                ``PreTrainedModel``.
        """
        super().__init__(config)
        # Called with no ``weights`` argument, so no torchvision checkpoint
        # is requested here.
        self.model = convnext_base()

    def forward(self, tensor):
        """Run the wrapped network on ``tensor`` and return its raw output.

        Args:
            tensor: Input passed straight to the torchvision model
                (presumably an image batch of shape (N, 3, H, W) -- TODO
                confirm against callers).
        """
        return self.model(tensor)


class ConvNextBaseModelForImageClassification(PreTrainedModel):
    """ConvNeXt-Base wrapper that can also compute a classification loss."""

    config_class = ConvNextBaseConfig

    def __init__(self, config):
        """Build the wrapped network.

        Args:
            config: A ``ConvNextBaseConfig`` instance, passed on to
                ``PreTrainedModel``.
        """
        super().__init__(config)
        # Called with no ``weights`` argument, so no torchvision checkpoint
        # is requested here.
        self.model = convnext_base()

    def forward(self, tensor, labels=None):
        """Compute logits and, when labels are given, the cross-entropy loss.

        Args:
            tensor: Input passed straight to the torchvision model.
            labels: Optional targets. When not ``None`` they are handed to
                ``torch.nn.functional.cross_entropy`` together with the
                logits (presumably integer class indices -- TODO confirm).

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``; otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        logits = self.model(tensor)
        if labels is not None:
            # The functional loss lives in ``torch.nn.functional``; there is
            # no ``torch.nn.cross_entropy`` attribute.
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}


# Register the custom classes so they can be resolved through the Auto* API.
ConvNextBaseConfig.register_for_auto_class()
ConvNextBaseModel.register_for_auto_class("AutoModel")
ConvNextBaseModelForImageClassification.register_for_auto_class("AutoModelForImageClassification")