from typing import Union

import torch
from transformers import AutoTokenizer

from src.config import TinyCLIPTextConfig


class Tokenizer:
    """Thin wrapper around a Hugging Face tokenizer configured by TinyCLIPTextConfig."""

    def __init__(self, text_config: TinyCLIPTextConfig) -> None:
        # Load the pretrained tokenizer that matches the configured text encoder.
        self.tokenizer = AutoTokenizer.from_pretrained(text_config.text_model)
        self.max_len = text_config.max_len

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        # Tokenize a single string or a batch of strings, truncating to max_len
        # and padding to the longest sequence in the batch.
        return self.tokenizer(
            x,
            max_length=self.max_len,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )  # type: ignore

    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
        # Decode each sequence back to text, keeping only the non-padded tokens
        # (the attention mask sums to the true sequence length).
        return [
            self.tokenizer.decode(sentence[:sentence_len])
            for sentence, sentence_len in zip(
                x["input_ids"], x["attention_mask"].sum(dim=-1)
            )
        ]
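

# Usage sketch (illustrative only): this assumes TinyCLIPTextConfig is a plain
# config object whose constructor accepts `text_model` (a Hugging Face model id)
# and `max_len` (an int). Both field names come from their use above; the
# constructor signature and the model id below are assumptions, not part of the
# original module.
if __name__ == "__main__":
    config = TinyCLIPTextConfig(text_model="distilbert-base-uncased", max_len=77)
    tokenizer = Tokenizer(config)

    batch = tokenizer(["a photo of a dog", "a cat"])
    print(batch["input_ids"].shape)  # (2, longest sequence in the batch)
    print(tokenizer.decode(batch))   # round-trips back to the input text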