Spaces:
Runtime error
Runtime error
File size: 469 Bytes
c4bc1f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
from torch import nn
from transformers import ViTModel
class ViTImageSearchModel(nn.Module):
    """Image encoder that embeds images with a pretrained ViT backbone.

    The embedding for each image is the last-layer hidden state at sequence
    position 0, i.e. the ``[CLS]`` token that ViT prepends to the patch
    tokens. No normalisation is applied here; callers that rank results by
    cosine similarity should normalise the returned vectors themselves.
    """

    def __init__(
        self,
        pretrained_model_name: str = "google/vit-base-patch32-224-in21k",
        **from_pretrained_kwargs,
    ) -> None:
        """Load the ViT backbone.

        Args:
            pretrained_model_name: Hugging Face Hub model id or local path
                passed to ``ViTModel.from_pretrained``.
            **from_pretrained_kwargs: Extra keyword arguments forwarded to
                ``ViTModel.from_pretrained`` (e.g. ``revision`` or
                ``cache_dir``). Empty by default, so the default behaviour
                is identical to loading by model name alone.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained(
            pretrained_model_name, **from_pretrained_kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one embedding vector per input image.

        Args:
            x: Batch of preprocessed images, passed to the backbone as
                ``pixel_values``. Presumably shaped
                ``(batch, channels, height, width)`` as produced by the
                matching ViT image processor; confirm against callers.

        Returns:
            The last-layer ``[CLS]`` hidden state, shape
            ``(batch, hidden_size)``.
        """
        outputs = self.vit(pixel_values=x)
        # last_hidden_state is (batch, seq_len, hidden_size); ViT places the
        # [CLS] token at sequence position 0, so select it for every image.
        return outputs.last_hidden_state[:, 0, :]
|