import base64
import json
import os
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer


class Transformer(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        tokenizer_name_or_path: Optional[str] = None,
        image_processor_name_or_path: Optional[str] = None,
        max_seq_length: Optional[int] = None,
        config_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        image_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super(Transformer, self).__init__()

        config_kwargs = config_kwargs or {}
        model_kwargs = model_kwargs or {}
        tokenizer_kwargs = tokenizer_kwargs or {}
        image_processor_kwargs = image_processor_kwargs or {}

        config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs)
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=config, **model_kwargs
        )

        if max_seq_length is not None and "model_max_length" not in tokenizer_kwargs:
            tokenizer_kwargs["model_max_length"] = max_seq_length

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path or model_name_or_path,
            **tokenizer_kwargs,
        )
        self.image_processor = AutoImageProcessor.from_pretrained(
            image_processor_name_or_path or model_name_or_path,
            **image_processor_kwargs,
        )

        # No max_seq_length set. Try to infer from model
        if max_seq_length is None:
            if (
                hasattr(self.model, "config")
                and hasattr(self.model.config, "max_position_embeddings")
                and hasattr(self.tokenizer, "model_max_length")
            ):
                max_seq_length = min(
                    self.model.config.max_position_embeddings,
                    self.tokenizer.model_max_length,
                )
        self.max_seq_length = max_seq_length

        if tokenizer_name_or_path is not None:
            self.model.config.tokenizer_class = self.tokenizer.__class__.__name__

    @staticmethod
    def _decode_data_image(data_image_str: str) -> Image.Image:
        header, data = data_image_str.split(",", 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """
        Encodes input samples. Text samples are tokenized. Image URLs, image data
        buffers and PIL images are passed through the image processor.
""" _images = [] _texts = [] _image_or_text_descriptors = [] for sample in texts: if isinstance(sample, str): if sample.startswith("http"): response = requests.get(sample) _images.append(Image.open(BytesIO(response.content)).convert("RGB")) _image_or_text_descriptors.append(0) elif sample.startswith("data:image/"): _images.append(self._decode_data_image(sample).convert("RGB")) _image_or_text_descriptors.append(0) else: try: _images.append(Image.open(sample).convert("RGB")) _image_or_text_descriptors.append(0) except Exception as e: _ = str(e) _texts.append(sample) _image_or_text_descriptors.append(1) elif isinstance(sample, Image.Image): _images.append(sample.convert("RGB")) _image_or_text_descriptors.append(0) encoding = {} if len(_texts): encoding["input_ids"] = self.tokenizer( texts, padding=padding, truncation="longest_first", return_tensors="pt", max_length=self.max_seq_length, ).input_ids if len(_images): encoding["pixel_values"] = self.image_processor( _images, return_tensors="pt" ).pixel_values encoding["image_text_info"] = _image_or_text_descriptors return encoding def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: image_embeddings = [] text_embeddings = [] if "pixel_values" in features: image_embeddings = self.model.get_image_features(features["pixel_values"]) if "input_ids" in features: text_embeddings = self.model.get_text_features(features["input_ids"]) sentence_embedding = [] image_features = iter(image_embeddings) text_features = iter(text_embeddings) for _, _input_type in enumerate(features["image_text_info"]): if _input_type == 0: sentence_embedding.append(next(image_features)) else: sentence_embedding.append(next(text_features)) features["sentence_embedding"] = torch.stack(sentence_embedding).float() return features def save(self, output_path: str, safe_serialization: bool = True) -> None: self.model.save_pretrained(output_path, safe_serialization=safe_serialization) self.tokenizer.save_pretrained(output_path) self.image_processor.save_pretrained(output_path) @staticmethod def load(input_path: str) -> "Transformer": # Old classes used other config names than 'sentence_bert_config.json' for config_name in [ "sentence_bert_config.json", "sentence_roberta_config.json", "sentence_distilbert_config.json", "sentence_camembert_config.json", "sentence_albert_config.json", "sentence_xlm-roberta_config.json", "sentence_xlnet_config.json", ]: sbert_config_path = os.path.join(input_path, config_name) if os.path.exists(sbert_config_path): break with open(sbert_config_path) as fIn: config = json.load(fIn) # Don't allow configs to set trust_remote_code if "config_kwargs" in config and "trust_remote_code" in config["config_kwargs"]: config["config_kwargs"].pop("trust_remote_code") if "model_kwargs" in config and "trust_remote_code" in config["model_kwargs"]: config["model_kwargs"].pop("trust_remote_code") if ( "tokenizer_kwargs" in config and "trust_remote_code" in config["tokenizer_kwargs"] ): config["tokenizer_kwargs"].pop("trust_remote_code") if ( "image_processor_kwargs" in config and "trust_remote_code" in config["image_processor_kwargs"] ): config["image_processor_kwargs"].pop("trust_remote_code") return Transformer(model_name_or_path=input_path, **config)