File size: 1,328 Bytes
495fe55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Dict
import open_clip
from torch import Tensor, nn
class ClipTokenizer(nn.Module):
    """Wrap an ``open_clip`` tokenizer in an ``nn.Module``.

    The tokenizer is selected by the optional ``open_clip_tokenizer`` entry of
    ``cfg["text_cfg"]`` (default ``"ViT-B-16"``), and every call to
    :meth:`forward` produces token ids of length
    ``cfg["text_cfg"]["context_length"]``.
    """

    def __init__(self, cfg, *args, **kwargs):
        """Build the tokenizer described by ``cfg["text_cfg"]``.

        Args:
            cfg: Configuration indexable by ``"text_cfg"``. The text config
                must provide ``"context_length"`` and may provide
                ``"open_clip_tokenizer"``.
            *args: Ignored; accepted for interface compatibility.
            **kwargs: Ignored; accepted for interface compatibility.
        """
        super().__init__()
        text_cfg = cfg["text_cfg"]
        self.context_length = text_cfg["context_length"]
        self.tokenizer = open_clip.get_tokenizer(self._tokenizer_name(text_cfg))

    @staticmethod
    def _tokenizer_name(text_cfg) -> str:
        """Return the configured open_clip tokenizer name or the default.

        ``getattr`` alone never sees the keys of a plain ``dict``, so a
        mapping-style config is additionally queried through ``.get``.
        """
        name = getattr(text_cfg, "open_clip_tokenizer", None)
        if name is None and hasattr(text_cfg, "get"):
            name = text_cfg.get("open_clip_tokenizer")
        return "ViT-B-16" if name is None else name

    def _special_token(self, position: int) -> int:
        """Return the token id at ``position`` of the tokenized empty string.

        An empty string tokenizes to the start-of-text id followed by the
        end-of-text id. A tensor result (e.g. a batched 2-D tensor) is
        flattened first, so indexing works for both a flat list and a tensor.
        """
        tokens = self.tokenizer("")
        if isinstance(tokens, Tensor):
            tokens = tokens.reshape(-1).tolist()
        return int(tokens[position])

    def get_vocab_size(self) -> int:
        """Return the number of entries in the tokenizer's encoder table."""
        return len(self.tokenizer.encoder)

    def get_encodings(self) -> Dict[str, int]:
        """Return the tokenizer's encoder table."""
        return self.tokenizer.encoder

    def get_eot_token(self) -> int:
        """Return the end-of-text token id."""
        # Tokenizing an empty string yields sot_id followed by eot_id.
        return self._special_token(1)

    def get_sot_token(self) -> int:
        """Return the start-of-text token id."""
        # Tokenizing an empty string yields sot_id followed by eot_id.
        return self._special_token(0)

    def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
        """Tokenize ``input_sentence`` to ``context_length`` token ids.

        Args:
            input_sentence: Text passed straight to the wrapped tokenizer.
            *args: Ignored; accepted for interface compatibility.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            The tokenizer output, whose last dimension is ``context_length``.

        Raises:
            AssertionError: If the last dimension is not ``context_length``.
        """
        # The tokenizer returns a tensor of token indices.
        tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
        assert (
            tokenized_sentence.shape[-1] == self.context_length
        ), "Tokenized tensor should be exactly `context_length` long."
        return tokenized_sentence
|