| """ |
| The main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing |
| a photo of <concept>_0 <concept>_1 ... and so on |
| and instead just do |
| a photo of <concept> |
| which gets translated to the above. This needs to work for both inference and training. |
| For inference, |
| the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with |
| it's underlying vectors |
| For training, |
| we would want to abstract away some logic like |
| 1. Adding tokens |
| 2. Updating gradient mask |
| 3. Saving embeddings |
| to our Util class here. |
| so |
| TODO: |
| 1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x |
| 2. have mechanism for adding tokens x |
| 3. have mech for saving emebeddings x |
| 4. get mask to update x |
| 5. Loading tokens from embedding x |
| 6. Integrate to training x |
| 7. Test |
| """ |
|
|
import copy
import random

from transformers import CLIPTokenizer


class MultiTokenCLIPTokenizer(CLIPTokenizer):
    """A CLIPTokenizer that expands a single placeholder token into multiple underlying tokens."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps each placeholder token to the list of tokens it expands to.
        self.token_map = {}

    def try_adding_tokens(self, placeholder_token, *args, **kwargs):
        """Add `placeholder_token` to the vocabulary, raising if it is already present."""
        num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

    def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
        """Register `placeholder_token` and back it with `num_vec_per_token` newly added tokens."""
        # Check for ambiguity before mutating the vocabulary: if an existing placeholder is a
        # substring of the new one, text replacement would be ambiguous.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} that can get confused with"
                    f" {placeholder_token}. Keep placeholder tokens independent."
                )
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            # Back the placeholder with numbered tokens: <concept>_0, <concept>_1, ...
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
        """
        Replace the placeholder tokens recorded in `token_map` with their underlying tokens so that
        the text encoder can encode them.

        `vector_shuffle` was inspired by https://github.com/rinongal/textual_inversion/pull/119,
        where shuffling the tokens was found to force the model to learn the concepts more
        descriptively.
        """
        if isinstance(text, list):
            return [
                self.replace_placeholder_tokens_in_text(
                    t, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for t in text
            ]
        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                # Truncate to roughly `prop_tokens_to_load` of the tokens, always keeping at least one.
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    # Shuffle a copy so `token_map` keeps its canonical token order.
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Expand placeholder tokens in `text`, then run the standard tokenizer call."""
        return super().__call__(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )

    def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Expand placeholder tokens in `text`, then run the standard `encode`."""
        return super().encode(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )
|
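
# A minimal usage sketch, assuming a standard CLIP checkpoint; the checkpoint name and the
# placeholder "<cat-toy>" are illustrative.
if __name__ == "__main__":
    tokenizer = MultiTokenCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer.add_placeholder_tokens("<cat-toy>", num_vec_per_token=4)
    # "a photo of <cat-toy>" expands to "a photo of <cat-toy>_0 ... <cat-toy>_3" before encoding.
    print(tokenizer("a photo of <cat-toy>", vector_shuffle=True).input_ids)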
|