# Uploaded by TriEightz — "Upload model" (commit e344930, verified).
from typing import List

import torch
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
from transformers import PretrainedConfig, PreTrainedModel
class ConvNextBaseConfig(PretrainedConfig):
    """Configuration shared by the torchvision-backed ConvNeXt-Base wrappers.

    Declares no architecture fields of its own: every keyword argument is
    forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "ConvNext"

    def __init__(self, **kwargs):
        # Nothing model-specific to store; let the base config handle kwargs.
        super().__init__(**kwargs)
class ConvNextBaseModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around torchvision's ``convnext_base``.

    The torchvision network is constructed without any arguments (so no
    ``weights`` are requested here), and ``forward`` returns its output as-is.
    """

    config_class = ConvNextBaseConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): this is the complete torchvision network, including
        # its classification head, so forward() presumably yields class
        # scores rather than pooled features — confirm that is intended.
        self.model = convnext_base()

    def forward(self, tensor):
        """Pass *tensor* through the wrapped torchvision network."""
        network = self.model
        return network(tensor)
class ConvNextBaseModelForImageClassification(PreTrainedModel):
    """ConvNeXt-Base image classifier exposed through the ``PreTrainedModel`` API.

    ``forward`` returns a dict containing ``"logits"`` and, when ``labels``
    are supplied, a cross-entropy ``"loss"`` as well.
    """

    config_class = ConvNextBaseConfig

    def __init__(self, config):
        super().__init__(config)
        # torchvision network built without arguments (no ``weights`` passed).
        self.model = convnext_base()

    def forward(self, tensor, labels=None):
        """Compute class logits and, optionally, the training loss.

        Args:
            tensor: Input batch passed straight to the torchvision network
                (presumably ``(batch, 3, H, W)`` images — TODO confirm).
            labels: Optional targets in a form accepted by
                ``torch.nn.functional.cross_entropy`` (e.g. class indices).

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}``
            when ``labels`` is given.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        # The functional loss lives in torch.nn.functional; torch.nn itself
        # has no ``cross_entropy`` attribute.
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
# Register the custom classes with the transformers Auto* factories so they
# can be resolved by name when this module is shipped as custom model code.
ConvNextBaseConfig.register_for_auto_class()
for _cls, _auto_name in (
    (ConvNextBaseModel, "AutoModel"),
    (ConvNextBaseModelForImageClassification, "AutoModelForImageClassification"),
):
    _cls.register_for_auto_class(_auto_name)
# Drop the loop temporaries so no extra names linger at module level.
del _cls, _auto_name