|
from typing import List

import torch
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
|
|
class ConvNextBaseConfig(PretrainedConfig):
    """Configuration object shared by the ConvNeXt-Base wrappers in this file.

    The class declares no fields of its own: every keyword argument it
    receives is handed straight to ``PretrainedConfig``.
    """

    # Kept verbatim; presumably the identifier Transformers stores to
    # recognise this architecture -- confirm before renaming.
    model_type = "ConvNext"

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the parent configuration."""
        super().__init__(**kwargs)
|
|
|
|
|
class ConvNextBaseModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around torchvision's ``convnext_base``.

    The wrapped network is stored on ``self.model`` and is built by calling
    ``convnext_base()`` with no arguments, i.e. with torchvision's defaults.
    """

    config_class = ConvNextBaseConfig

    def __init__(self, config):
        """Build the wrapper.

        Args:
            config: Configuration instance passed on to ``PreTrainedModel``.
        """
        super().__init__(config)
        backbone = convnext_base()
        self.model = backbone

    def forward(self, tensor):
        """Run ``tensor`` through the wrapped network and return its output.

        Args:
            tensor: Input handed unchanged to the torchvision module
                (presumably a batch of images -- confirm against callers).

        Returns:
            Whatever the wrapped ``convnext_base`` module returns.
        """
        output = self.model(tensor)
        return output
|
|
|
|
|
class ConvNextBaseModelForImageClassification(PreTrainedModel):
    """Image-classification wrapper around torchvision's ``convnext_base``.

    The wrapped network is stored on ``self.model`` and is built by calling
    ``convnext_base()`` with no arguments, i.e. with torchvision's defaults.
    ``forward`` returns a dict so that a loss can be reported alongside the
    logits when labels are supplied.
    """

    config_class = ConvNextBaseConfig

    def __init__(self, config):
        """Build the wrapper.

        Args:
            config: Configuration instance passed on to ``PreTrainedModel``.
        """
        super().__init__(config)
        self.model = convnext_base()

    def forward(self, tensor, labels=None):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            tensor: Input handed unchanged to the wrapped torchvision module
                (presumably a batch of images -- confirm against callers).
            labels: Optional targets for the loss. When ``None`` no loss is
                computed.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``; otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}

        # The functional cross-entropy lives in ``torch.nn.functional``;
        # ``torch.nn.cross_entropy`` does not exist and raised AttributeError.
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
|
|
|
# Register the custom classes with the matching Transformers auto classes.
# The config call names its default target ("AutoConfig") explicitly.
ConvNextBaseConfig.register_for_auto_class(auto_class="AutoConfig")
ConvNextBaseModel.register_for_auto_class(auto_class="AutoModel")
ConvNextBaseModelForImageClassification.register_for_auto_class(
    auto_class="AutoModelForImageClassification"
)