"""Hugging Face ``PreTrainedModel`` wrappers around timm MAE-pretrained ViT models."""

from typing import List

import timm
import torch
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from transformers import PretrainedConfig
from transformers import PreTrainedModel


class ViTMAEConfig(PretrainedConfig):
    """Configuration for a timm model wrapped as a Hugging Face model.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        num_classes: Number of outputs for the classification head, passed to
            ``timm.create_model`` as ``num_classes``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "vit_mae_custom"

    def __init__(
        self,
        model_name: str = 'timm/vit_base_patch16_224.mae',
        num_classes: int = 1000,
        **kwargs,
    ):
        self.model_name = model_name
        self.num_classes = num_classes
        super().__init__(**kwargs)


# Alternative checkpoint noted by the author: 'timm/vit_huge_patch14_224.mae'

# NOTE(review): disabled feature-extractor variant, kept for reference.
# class ViTMAEModel(PreTrainedModel):
#     config_class = ViTMAEConfig
#     def __init__(self, config):
#         super().__init__(config)
#         self.model = timm.create_model(config.model_name, num_classes=config.num_classes, pretrained=True)
#     def forward(self, tensor):
#         return self.model.forward_features(tensor)


class ViTMAEModelForImageClassification(PreTrainedModel):
    """Image classifier backed by the timm model named in a ``ViTMAEConfig``."""

    config_class = ViTMAEConfig

    def __init__(self, config):
        """Build the timm model described by ``config``.

        Args:
            config: A ``ViTMAEConfig``. Its ``model_name`` and ``num_classes``
                are forwarded to ``timm.create_model``.
        """
        super().__init__(config)
        # pretrained=True asks timm to load its published weights for
        # config.model_name when this module is constructed.
        self.model = timm.create_model(
            config.model_name, num_classes=config.num_classes, pretrained=True
        )

    def forward(self, tensor, labels=None):
        """Compute logits, plus a cross-entropy loss when ``labels`` is given.

        Args:
            tensor: Input batch fed directly to the timm model. This is assumed
                to be an image batch shaped as the timm model expects, e.g.
                (batch, 3, H, W); confirm against callers.
            labels: Optional targets for ``torch.nn.functional.cross_entropy``.
                These are presumably class indices of shape (batch,); confirm
                against callers.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is not None,
            otherwise ``{"logits": logits}``.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        # The functional form lives in torch.nn.functional; torch.nn has no
        # ``cross_entropy`` attribute, only the CrossEntropyLoss module.
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}