# nomic-embed-vision-v1.5-st / vision_transformer.py
# Author: Tom Aarsen
# Commit c59e72d: "Add custom Sentence Transformer module"
# (Scraped from the Hugging Face Hub file view; file size 2.66 kB; Hub scan: no virus.)
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoProcessor
class VisionTransformer(nn.Module):
    """Sentence Transformers module wrapping a Hugging Face vision model.

    The constructor loads the config, model and processor for
    ``model_name_or_path`` through ``AutoConfig``, ``AutoModel`` and
    ``AutoProcessor``. :meth:`forward` runs the model on
    ``features["pixel_values"]`` and stores the model's first output under
    ``features["token_embeddings"]``.

    Args:
        model_name_or_path: Hugging Face model name
            (https://huggingface.co/models) or local path.
        model_args: Keyword arguments passed to ``AutoModel.from_pretrained``.
        tokenizer_args: Keyword arguments passed to
            ``AutoProcessor.from_pretrained``.
        config_args: Keyword arguments passed to
            ``AutoConfig.from_pretrained``.
        cache_dir: Cache directory for Hugging Face Transformers to
            store/load models.
    """

    # Names of instance attributes reported by ``get_config_dict``. It is
    # defined at class level so ``get_config_dict`` always works. It is empty
    # because no constructor argument is currently stored on the instance
    # for persisting.
    config_keys: List[str] = []

    def __init__(
        self,
        model_name_or_path: str,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        super().__init__()
        model_args = {} if model_args is None else model_args
        tokenizer_args = {} if tokenizer_args is None else tokenizer_args
        config_args = {} if config_args is None else config_args

        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        # Hand the already-loaded config to the model and processor loaders
        # instead of letting each resolve it again.
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, **model_args, cache_dir=cache_dir
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, config=self.config, **tokenizer_args, cache_dir=cache_dir
        )

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the vision model and add its output as ``"token_embeddings"``.

        Args:
            features: Must contain ``"pixel_values"`` (for example the output
                of :meth:`tokenize`).

        Returns:
            The same ``features`` dict, updated in place with
            ``"token_embeddings"``. This is the first element of the model's
            tuple output (``return_dict=False``). It is presumably a
            ``(batch, sequence, hidden_size)`` tensor; confirm this for the
            loaded model. No separate CLS/pooled output is added.
        """
        output_states = self.model(pixel_values=features["pixel_values"], return_dict=False)[0]
        features.update({"token_embeddings": output_states})
        return features

    def get_word_embedding_dimension(self) -> int:
        """Return ``hidden_size`` from the loaded model config."""
        return self.config.hidden_size

    def tokenize(
        self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Preprocess inputs with the model's processor.

        Args:
            texts: Passed unchanged as the processor's first positional
                argument. Despite the name and annotation, for a vision
                processor these are presumably images; verify against callers.
            padding: Accepted but ignored. It is not forwarded to the
                processor.

        Returns:
            The processor output, requested as PyTorch tensors
            (``return_tensors="pt"``).
        """
        return self.processor(texts, return_tensors="pt")

    def get_config_dict(self) -> Dict[str, Any]:
        """Return ``{name: value}`` for every attribute listed in ``config_keys``."""
        # getattr (not self.__dict__) so that properties and nn.Module-managed
        # attributes (kept in _modules/_parameters) resolve as well.
        return {key: getattr(self, key) for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save the model and the processor to ``output_path`` via ``save_pretrained``."""
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)