Garrett Goon committed on
Commit
8680dd4
1 Parent(s): d58e63f
Files changed (3)
  1. .DS_Store +0 -0
  2. learned_embeddings_dict.py +0 -0
  3. utils.py +59 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
learned_embeddings_dict.py ADDED
Binary file (16.2 kB).
 
utils.py ADDED
@@ -0,0 +1,59 @@
+ from typing import List, Sequence, Tuple
+
+ import torch
+ import torch.nn as nn
+
+
+ def add_new_tokens_to_tokenizer(
+     concept_token: str,
+     initializer_tokens: Sequence[str],
+     tokenizer: nn.Module,
+ ) -> Tuple[List[int], List[int], str]:
+     """Helper function for adding new tokens to the tokenizer and extending the corresponding
+     embeddings appropriately, given a single concept token and its sequence of corresponding
+     initializer tokens. Returns the lists of ids for the initializer tokens and their dummy
+     replacements, as well as the string representation of the dummies.
+     """
+     initializer_ids = tokenizer(
+         initializer_tokens,
+         padding="max_length",
+         truncation=True,
+         max_length=tokenizer.model_max_length,
+         return_tensors="pt",
+         add_special_tokens=False,
+     ).input_ids
+
+     try:
+         special_token_ids = tokenizer.all_special_ids
+     except AttributeError:
+         special_token_ids = []
+
+     non_special_initializer_locations = torch.isin(
+         initializer_ids, torch.tensor(special_token_ids), invert=True
+     )
+     non_special_initializer_ids = initializer_ids[non_special_initializer_locations]
+     if len(non_special_initializer_ids) == 0:
+         raise ValueError(
+             f'"{initializer_tokens}" maps to trivial tokens, please choose a different initializer.'
+         )
+
+     # Add a dummy placeholder token for every token in the initializer.
+     dummy_placeholder_token_list = [
+         f"{concept_token}_{n}" for n in range(len(non_special_initializer_ids))
+     ]
+     dummy_placeholder_tokens = " ".join(dummy_placeholder_token_list)
+     num_added_tokens = tokenizer.add_tokens(dummy_placeholder_token_list)
+     if num_added_tokens != len(dummy_placeholder_token_list):
+         raise ValueError(
+             f"Subset of {dummy_placeholder_token_list} tokens already exist in tokenizer."
+         )
+
+     dummy_placeholder_ids = tokenizer.convert_tokens_to_ids(
+         dummy_placeholder_token_list
+     )
+     # Sanity check
+     assert len(dummy_placeholder_ids) == len(
+         non_special_initializer_ids
+     ), 'Length of "dummy_placeholder_ids" and "non_special_initializer_ids" must match.'
+
+     return non_special_initializer_ids, dummy_placeholder_ids, dummy_placeholder_tokens
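
A minimal usage sketch of the added helper, not part of the commit: it assumes a Hugging Face CLIPTokenizer and an illustrative checkpoint name and concept string; only add_new_tokens_to_tokenizer itself comes from utils.py above.

# Usage sketch (assumption: a Hugging Face CLIPTokenizer; checkpoint name and
# concept string below are illustrative, not taken from this commit).
from transformers import CLIPTokenizer

from utils import add_new_tokens_to_tokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

initializer_ids, dummy_ids, dummy_tokens = add_new_tokens_to_tokenizer(
    concept_token="<my-concept>",
    initializer_tokens=["a photo of a cat"],
    tokenizer=tokenizer,
)

# dummy_tokens is a space-separated string like "<my-concept>_0 <my-concept>_1 ...",
# one dummy placeholder per non-special initializer token; initializer_ids are the
# ids whose embeddings can seed the newly added rows of the embedding table.
print(dummy_tokens)
print(initializer_ids, dummy_ids)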