"""BEiT-3 retrieval wrapper: preprocessing plus text/image embedding."""

import os
from functools import lru_cache
from pathlib import Path
from typing import Union

import torch
from PIL import Image
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models import create_model
from torchvision import transforms
from transformers import XLMRobertaTokenizer

# NOTE(review): `modeling_finetune` is not referenced directly in this file;
# presumably it is imported so its model definitions are registered with timm
# before `create_model` is called -- confirm before removing.
from . import modeling_finetune, utils

# Directory containing this file; model assets are resolved relative to it.
CWD = Path(__file__).parent
print(CWD)


class Preprocess:
    """Convert raw text or a PIL image into batched tensors for the model."""

    def __init__(self, tokenizer):
        """Store the tokenizer and build the image transform pipeline.

        Args:
            tokenizer: Tokenizer exposing ``tokenize``,
                ``convert_tokens_to_ids`` and the ``bos_token_id`` /
                ``eos_token_id`` / ``pad_token_id`` attributes.
        """
        self.max_len = 64  # fixed text sequence length, BOS/EOS included
        self.input_size = 384  # images are resized to input_size x input_size
        self.tokenizer = tokenizer
        self.transform = transforms.Compose(
            [
                # interpolation=3 is PIL's integer code for bicubic resampling.
                transforms.Resize(
                    (self.input_size, self.input_size), interpolation=3
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
                ),
            ]
        )
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def _encode_text(self, text: str):
        """Tokenize *text* and pad/truncate it to exactly ``max_len`` ids."""
        token_ids = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.tokenize(text)
        )
        # Reserve two slots for BOS/EOS. Without this truncation a long text
        # would yield a sequence longer than max_len and negative pad counts.
        token_ids = token_ids[: self.max_len - 2]
        token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]

        num_tokens = len(token_ids)
        num_pad = self.max_len - num_tokens
        # Mask convention used here: 0 marks a real token, 1 marks padding.
        padding_mask = [0] * num_tokens + [1] * num_pad
        return (
            torch.LongTensor(
                token_ids + [self.pad_token_id] * num_pad
            ).unsqueeze(0),
            torch.Tensor(padding_mask).unsqueeze(0),
            num_tokens,
        )

    def preprocess(self, input: Union[str, Image.Image]):
        """Prepare a single text or image for the model.

        Args:
            input: A string, or a ``PIL.Image.Image``.

        Returns:
            For text: a tuple ``(token_ids, padding_mask, num_tokens)`` where
            ``token_ids`` is a ``(1, max_len)`` long tensor, ``padding_mask``
            is a ``(1, max_len)`` tensor (0 = token, 1 = padding) and
            ``num_tokens`` counts the non-padding ids (BOS/EOS included).
            For an image: the transformed image tensor with a leading batch
            dimension of 1.

        Raises:
            TypeError: If *input* is neither a string nor a PIL image.
        """
        if isinstance(input, str):
            return self._encode_text(input)
        if isinstance(input, Image.Image):
            return self.transform(input).unsqueeze(0)
        raise TypeError("Invalid input type")


class Beit3Model:
    """Load a BEiT-3 retrieval model and expose text/image embeddings."""

    def __init__(
        self,
        model_name: str = "beit3_base_patch16_384_retrieval",
        model_path: str = os.path.join(
            CWD,
            "beit3_model/beit3_base_patch16_384_f30k_retrieval.pth",
        ),
        device: str = "cuda",
    ):
        """Build the model, load its weights and move it to *device*.

        Args:
            model_name: Architecture name passed to timm's ``create_model``.
            model_path: Path of the checkpoint file to load.
            device: Torch device string the model and inputs are placed on.
        """
        self._load_model(model_name, model_path, device)
        self.device = device

    def _load_model(self, model_name, model_path, device: str = "cpu"):
        """Create the network, load the checkpoint and build the preprocessor.

        Sets ``self.model`` and ``self.preprocessor``.
        """
        self.model = create_model(
            model_name,
            pretrained=False,
            drop_path_rate=0.1,
            vocab_size=64010,
            checkpoint_activations=False,
        )
        # NOTE(review): this guard tests `model_name`, which is always truthy
        # once `create_model` has succeeded; `model_path` may have been
        # intended. Left unchanged to preserve existing behavior.
        if model_name:
            utils.load_model_and_may_interpolate(
                model_path, self.model, "model|module", ""
            )
        self.preprocessor = Preprocess(
            XLMRobertaTokenizer(os.path.join(CWD, "beit3_model/beit3.spm"))
        )
        self.model.to(device)
        # Inference only: switch off stochastic layers (drop-path, dropout) so
        # the same input always produces the same embedding.
        self.model.eval()

    def get_embedding(self, input: Union[str, Image.Image]):
        """Compute the embedding of a single text or image.

        Args:
            input: A string, or a ``PIL.Image.Image``.

        Returns:
            For text: a ``float32`` numpy array on the CPU.
            For an image: a torch tensor left on ``self.device`` (kept as a
            tensor for backward compatibility with existing callers).

        Raises:
            TypeError: If *input* is neither a string nor a PIL image.
        """
        if isinstance(input, str):
            token_ids, padding_mask, _ = self.preprocessor.preprocess(input)
            # Inputs must live on the same device as the model weights.
            token_ids = token_ids.to(self.device)
            padding_mask = padding_mask.to(self.device)
            with torch.no_grad():
                _, vector = self.model(
                    text_description=token_ids,
                    padding_mask=padding_mask,
                    only_infer=True,
                )
            return vector.cpu().detach().numpy().astype("float32")

        if isinstance(input, Image.Image):
            image_input = self.preprocessor.preprocess(input).to(self.device)
            with torch.no_grad():
                vector, _ = self.model(image=image_input, only_infer=True)
            return vector

        raise TypeError("Invalid input type")