nicoladecao committed on
Commit
7d15402
1 Parent(s): 6186793

Create trie.py

Browse files
Files changed (1) hide show
  1. trie.py +93 -0
trie.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree at
6
+ # https://github.com/facebookresearch/GENRE .
7
+
8
+
9
+ from typing import Dict, List
10
+
11
+
12
class Trie(object):
    """Prefix tree over sequences of integer token ids.

    The trie is stored as nested dictionaries in ``trie_dict``: each key is a
    token id and each value is the dictionary of tokens that may follow it.
    A sequence ends at a node whose dictionary is empty.

    A second trie can be attached with :meth:`append`; lookups then fall
    through to it (see :meth:`get`).

    All traversals are iterative, so sequence length and trie depth are not
    limited by the interpreter's recursion limit.
    """

    def __init__(self, sequences: List[List[int]] = None):
        """Build a trie, optionally pre-populated with ``sequences``.

        Args:
            sequences: Token-id sequences to insert. ``None`` (the default)
                or an empty collection creates an empty trie.
        """
        self.trie_dict = {}
        # Number of insertions performed; duplicates are counted every time.
        self.len = 0
        self.append_trie = None
        self.bos_token_id = None

        if sequences:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)
                self.len += 1

    def append(self, trie, bos_token_id):
        """Attach ``trie`` as a continuation of this one.

        Args:
            trie: The trie consulted by :meth:`get` after this one.
            bos_token_id: Token that, when it appears among the candidate
                next tokens, is replaced by the root tokens of ``trie``.
        """
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        """Insert ``sequence`` and increase the length counter by one."""
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        """Return the list of tokens that may follow ``prefix_sequence``.

        Returns an empty list when the prefix is not in the trie and no
        appended trie can resolve it.
        """
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict):
        """Wrap an existing nested dictionary in a ``Trie`` without copying it.

        The length is set to the number of sequences produced by iterating
        the resulting trie.
        """
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        """Insert ``sequence`` below ``trie_dict``, creating nodes as needed.

        An empty ``sequence`` leaves ``trie_dict`` unchanged.
        """
        node = trie_dict
        for token in sequence:
            # setdefault returns the existing child or installs a new one.
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: int = None,
    ):
        """Walk ``prefix_sequence`` from ``trie_dict`` and list the next tokens.

        Args:
            prefix_sequence: Tokens already generated.
            trie_dict: Root dictionary to start from.
            append_trie: Optional trie to fall through to.
            bos_token_id: Token that is swapped for the root tokens of
                ``append_trie`` when it is among the candidates.

        Returns:
            The candidate next tokens, or ``[]`` if the prefix is unknown.
        """
        node = trie_dict
        for position, token in enumerate(prefix_sequence):
            if token in node:
                node = node[token]
                continue
            # The prefix leaves this trie here: hand the unmatched remainder
            # to the appended trie, if there is a non-empty one.
            if append_trie:
                return append_trie.get(prefix_sequence[position:])
            return []

        output = list(node.keys())
        if append_trie and bos_token_id in output:
            # Replace the hand-over token with the appended trie's roots.
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict.keys())
        return output

    def __iter__(self):
        """Iterate over root-to-leaf token sequences, depth first.

        Children are visited in dictionary order. An empty trie yields a
        single empty sequence.
        """

        def _traverse(root):
            if not root:
                yield []
                return

            prefix = []
            # One item-iterator per level currently being explored;
            # len(prefix) is always len(pending) - 1.
            pending = [iter(root.items())]
            while pending:
                entry = next(pending[-1], None)
                if entry is None:
                    # Level exhausted: step back up to the parent.
                    pending.pop()
                    if prefix:
                        prefix.pop()
                    continue
                token, children = entry
                if children:
                    prefix.append(token)
                    pending.append(iter(children.items()))
                else:
                    yield prefix + [token]

        return _traverse(self.trie_dict)

    def __len__(self):
        """Return the stored length counter."""
        return self.len

    def __getitem__(self, value):
        """Alias for :meth:`get`, enabling ``trie[prefix]``."""
        return self.get(value)