"""
The main idea for this code is to spare users the hassle of spelling out multiple tokens for a single concept.
Instead of typing "a photo of <concept>_0 <concept>_1 ..." and so on, they can just type "a photo of <concept>",
which gets translated to the above. This needs to work for both inference and training.

For inference, the tokenizer encodes the text. So, we want logic in our tokenizer that replaces the placeholder
token with its underlying vector tokens.

For training, we want to abstract away some logic like
1. Adding tokens
2. Updating gradient mask
3. Saving embeddings
to our Util class here.

so
TODO:
1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x
2. have mechanism for adding tokens x
3. have mech for saving embeddings x
4. get mask to update x
5. Loading tokens from embedding x
6. Integrate to training x
7. Test
"""
import copy
import random

from transformers import CLIPTokenizer


class MultiTokenCLIPTokenizer(CLIPTokenizer):
    """CLIP tokenizer that expands a single placeholder token into several tokens.

    ``token_map`` maps every registered placeholder token to the list of tokens that were actually added to the
    vocabulary for it. Text passed to ``__call__`` / ``encode`` has each placeholder replaced by its mapped
    tokens (joined with spaces) before being handed to the parent tokenizer.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent tokenizer and start with an empty placeholder map."""
        super().__init__(*args, **kwargs)
        self.token_map = {}

    def try_adding_tokens(self, placeholder_token, *args, **kwargs):
        """Add `placeholder_token` to the vocabulary via the parent `add_tokens`.

        Extra positional and keyword arguments are forwarded to the parent `add_tokens`.

        Raises:
            ValueError: If the parent reports that no token was added.
        """
        num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

    def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
        """Register `placeholder_token` and add its underlying token(s) to the vocabulary.

        With `num_vec_per_token == 1` the placeholder itself is added. Otherwise the tokens
        `<placeholder_token>_0`, `<placeholder_token>_1`, ... are added, one per index in
        `range(num_vec_per_token)`. The resulting list is stored in `self.token_map[placeholder_token]`.

        Args:
            placeholder_token: The token users will write in their prompts.
            *args: Forwarded to `try_adding_tokens`.
            num_vec_per_token: Number of tokens the placeholder expands to.
            **kwargs: Forwarded to `try_adding_tokens`.

        Raises:
            ValueError: If an already registered placeholder is a substring of `placeholder_token`, or if one of
                the tokens could not be added to the vocabulary.
        """
        # Check for clashes BEFORE mutating the vocabulary, so that a rejected placeholder does not leave
        # tokens behind that have no entry in `token_map`.
        # This handles the case where the new placeholder token contains an existing placeholder token but is
        # larger: the plain substring replacement done at encode time would corrupt it.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} that can get confused with"
                    f" {placeholder_token}. Keep placeholder tokens independent."
                )

        if num_vec_per_token == 1:
            output = [placeholder_token]
        else:
            output = [placeholder_token + f"_{i}" for i in range(num_vec_per_token)]

        for token in output:
            self.try_adding_tokens(token, *args, **kwargs)

        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
        """
        Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder can
        encode them.

        vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119
        where shuffling tokens were found to force the model to learn the concepts more descriptively.

        Args:
            text: A string, or a list of strings (lists are processed element by element).
            vector_shuffle: If True, the mapped tokens are shuffled before being inserted.
            prop_tokens_to_load: Proportion of each placeholder's tokens to use. The first
                `1 + int(len(tokens) * prop_tokens_to_load)` tokens are kept.

        Returns:
            The text (or list of texts) with every registered placeholder replaced by its tokens joined with
            single spaces.
        """
        if isinstance(text, list):
            # Forward BOTH options so batched prompts are treated exactly like single prompts.
            return [
                self.replace_placeholder_tokens_in_text(
                    item, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for item in text
            ]

        for placeholder_token, mapped_tokens in self.token_map.items():
            if placeholder_token not in text:
                continue
            tokens = mapped_tokens[: 1 + int(len(mapped_tokens) * prop_tokens_to_load)]
            if vector_shuffle:
                # Shuffle a copy so the order stored in `token_map` is never disturbed.
                tokens = copy.copy(tokens)
                random.shuffle(tokens)
            text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Expand placeholder tokens in `text`, then call the parent tokenizer with the result.

        `vector_shuffle` and `prop_tokens_to_load` are consumed here (see
        `replace_placeholder_tokens_in_text`); all other arguments are forwarded to the parent `__call__`.
        """
        return super().__call__(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )

    def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Expand placeholder tokens in `text`, then call the parent `encode` with the result.

        `vector_shuffle` and `prop_tokens_to_load` are consumed here (see
        `replace_placeholder_tokens_in_text`); all other arguments are forwarded to the parent `encode`.
        """
        return super().encode(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )