File size: 493 Bytes
01d1aa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import timm
from transformers import PreTrainedModel
from .configuration_vitmodel import ViTConfig

class VitMemModel(PreTrainedModel):
    """Single-output ViT scorer exposed as a Hugging Face ``PreTrainedModel``.

    Wraps a ``timm`` Vision Transformer created with a one-unit
    classification head. The head's output goes through a sigmoid, so every
    prediction lies in [0, 1].
    """

    config_class = ViTConfig

    # timm architecture used when the config does not name one. This is the
    # backbone the class always built before it became configurable, so it
    # must stay the fallback for existing configs/checkpoints to keep loading.
    default_timm_model_name = "vit_base_patch16_224_miil"

    def __init__(self, config: ViTConfig):
        """Build the timm backbone described by *config*.

        Args:
            config: Model configuration. If it defines a non-empty
                ``timm_model_name`` attribute, that timm architecture is
                created; otherwise ``default_timm_model_name`` is used.
        """
        super().__init__(config)
        model_name = (
            getattr(config, "timm_model_name", None)
            or self.default_timm_model_name
        )
        # pretrained=False: timm's own pretrained weights are not downloaded.
        # Presumably the real weights come from the HF checkpoint via
        # ``from_pretrained`` -- confirm with the loading code.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=1)

    def forward(self, tensor, labels=None):
        """Return sigmoid-activated scores for an input batch.

        Args:
            tensor: Input passed unchanged to the timm backbone. For the
                default backbone this is presumably an image batch of shape
                (batch, 3, 224, 224) -- TODO confirm.
            labels: Accepted but unused by this method. No loss is computed
                or returned.

        Returns:
            ``torch.sigmoid`` of the backbone output. With ``num_classes=1``
            this is presumably shape (batch, 1), with values in [0, 1].
        """
        logits = self.model(tensor)
        return torch.sigmoid(logits)