File size: 493 Bytes
01d1aa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
import timm
from transformers import PreTrainedModel
from .configuration_vitmodel import ViTConfig
class VitMemModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a timm ViT with a single output unit.

    The backbone is timm's ``vit_base_patch16_224_miil`` created with
    ``num_classes=1``; ``forward`` applies a sigmoid to its output.
    """

    config_class = ViTConfig

    def __init__(self, config: ViTConfig):
        """Build the timm backbone.

        Args:
            config: Configuration object handed to ``PreTrainedModel.__init__``.
                It is not otherwise read here; the architecture name and head
                size are fixed in this constructor.
        """
        super().__init__(config)
        # pretrained=False: timm does not fetch weights at construction time.
        # NOTE(review): presumably weights are loaded afterwards through the
        # usual `from_pretrained` path -- confirm against the caller.
        self.model = timm.create_model(
            "vit_base_patch16_224_miil",
            pretrained=False,
            num_classes=1,
        )

    def forward(self, tensor, labels=None):
        """Run the backbone on ``tensor`` and return the sigmoid of its output.

        Args:
            tensor: Input passed straight to the timm model.
                NOTE(review): presumably an image batch sized for a
                224x224 / patch-16 ViT -- confirm against the caller.
            labels: Accepted but unused; no loss is computed here.

        Returns:
            ``torch.sigmoid`` applied element-wise to the backbone output.
        """
        return torch.sigmoid(self.model(tensor))