# Spaces: Runtime error
import torch
from torch import nn
from transformers import ViTModel
class ViTImageSearchModel(nn.Module):
    """Image encoder that returns a ViT backbone's first-token embedding.

    Wraps a Hugging Face ``ViTModel`` and, for every item in the batch,
    returns the final-layer hidden state found at sequence position 0.
    """

    def __init__(
        self, pretrained_model_name: str = "google/vit-base-patch32-224-in21k"
    ) -> None:
        """Load the pretrained ViT backbone.

        Args:
            pretrained_model_name: Identifier handed to
                ``ViTModel.from_pretrained`` to select the checkpoint.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa
        """Encode a batch and return one embedding vector per item.

        Args:
            x: Tensor forwarded to the backbone as ``pixel_values``.
                NOTE(review): presumably (batch, channels, height, width)
                images already preprocessed for the checkpoint -- confirm
                against the caller.

        Returns:
            A 2-D tensor obtained by taking index 0 along the second axis of
            the backbone's 3-D ``last_hidden_state``.
        """
        outputs = self.vit(pixel_values=x)
        # Position 0 of the sequence axis is treated as the [CLS] token here;
        # its hidden state serves as the whole-image representation.
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        return cls_hidden_state