vitmodel2 / modeling_vitmodel.py
tmshag1's picture
Upload model
01d1aa1
raw
history blame
493 Bytes
import torch
import timm
from transformers import PreTrainedModel
from .configuration_vitmodel import ViTConfig
class VitMemModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a timm Vision Transformer.

    The backbone is created with ``timm.create_model(..., pretrained=False)``,
    so it starts randomly initialised; trained weights are expected to come
    from the ``PreTrainedModel`` loading machinery (e.g. ``from_pretrained``).
    ``forward`` returns the element-wise sigmoid of the backbone output.

    Optional config attributes (fall back to the original hard-coded values):
        timm_model_name: timm architecture name
            (default ``"vit_base_patch16_224_miil"``).
        num_outputs: width of the timm classifier head (default ``1``).
    """

    config_class = ViTConfig

    # Defaults reproduce the originally hard-coded backbone so existing
    # configs/checkpoints keep loading into an identical architecture.
    _DEFAULT_TIMM_MODEL_NAME = "vit_base_patch16_224_miil"
    _DEFAULT_NUM_OUTPUTS = 1

    def __init__(self, config: ViTConfig):
        super().__init__(config)
        # getattr keeps configs that lack these (optional) attributes working.
        # NOTE: `num_labels` is intentionally not used here because
        # PretrainedConfig already defines it with a default of 2.
        model_name = getattr(config, "timm_model_name", self._DEFAULT_TIMM_MODEL_NAME)
        num_outputs = getattr(config, "num_outputs", self._DEFAULT_NUM_OUTPUTS)
        self.model = timm.create_model(
            model_name, pretrained=False, num_classes=num_outputs
        )

    def forward(self, tensor, labels=None):
        """Run the backbone and squash its outputs with a sigmoid.

        Args:
            tensor: Input image batch passed straight to the timm model.
                Presumably (batch, 3, 224, 224) for the default
                ``*_patch16_224`` backbone — TODO confirm against callers.
            labels: Accepted for Hugging Face API compatibility but ignored;
                no loss is computed here.

        Returns:
            ``torch.sigmoid`` of the backbone output, i.e. values in (0, 1)
            with one column per configured output.
        """
        logits = self.model(tensor)
        return torch.sigmoid(logits)