|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
from transformers import AutoConfig, AutoModel, AutoProcessor |
|
|
|
|
|
class VisionTransformer(nn.Module):
    """Huggingface ``AutoModel`` wrapper that produces token embeddings for images.

    Loads the matching vision model class for the given checkpoint (e.g. a ViT
    or CLIP vision encoder) together with its ``AutoProcessor``, and exposes the
    per-patch hidden states as ``token_embeddings``.

    Args:
        model_name_or_path: Huggingface model name or local path
            (https://huggingface.co/models)
        model_args: Keyword arguments passed to the Huggingface
            Transformers model (``AutoModel.from_pretrained``)
        tokenizer_args: Keyword arguments passed to the Huggingface
            Transformers processor (``AutoProcessor.from_pretrained``)
        config_args: Keyword arguments passed to the Huggingface
            Transformers config (``AutoConfig.from_pretrained``)
        cache_dir: Cache dir for Huggingface Transformers to store/load
            models
    """

    def __init__(
        self,
        model_name_or_path: str,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        super().__init__()
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}

        # Names of attributes serialized by get_config_dict().
        # BUGFIX: this was never initialized, so get_config_dict() raised
        # AttributeError. There is no extra construction state to persist
        # beyond what save() writes, hence the empty list.
        self.config_keys: List[str] = []

        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, **model_args, cache_dir=cache_dir)
        self.processor = AutoProcessor.from_pretrained(model_name_or_path, config=self.config, **tokenizer_args, cache_dir=cache_dir)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the vision model and add ``token_embeddings`` to *features*.

        Args:
            features: Must contain ``"pixel_values"`` as produced by
                :meth:`tokenize`. Mutated in place.

        Returns:
            The same *features* dict with ``"token_embeddings"`` set to the
            model's last hidden state (first element of the tuple output).
        """
        output_states = self.model(pixel_values=features["pixel_values"], return_dict=False)[0]
        features.update({"token_embeddings": output_states})
        return features

    def get_word_embedding_dimension(self) -> int:
        """Return the embedding dimension (``hidden_size``) of the model."""
        return self.config.hidden_size

    def tokenize(
        self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Preprocess inputs (images) into model-ready tensors.

        NOTE(review): *padding* is accepted for interface compatibility with
        text modules but is not forwarded to the processor — image processors
        do not pad; confirm before relying on it.
        """
        return self.processor(texts, return_tensors="pt")

    def get_config_dict(self) -> Dict[str, Any]:
        """Return the serializable configuration of this module.

        Collects the attributes named in ``self.config_keys`` (currently empty,
        so this returns ``{}``).
        """
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save model weights (and config) plus the processor to *output_path*.

        Args:
            output_path: Target directory.
            safe_serialization: If True, weights are written as safetensors.
        """
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)
|
|