Plachta commited on
Commit
06e7a0f
1 Parent(s): 52e32c0

updated requirements

Browse files
Files changed (2) hide show
  1. utils/download.py +49 -0
  2. utils/symbol_table.py +287 -0
utils/download.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import requests
3
+
4
+
5
+ def download_file_from_google_drive(id, destination):
6
+ URL = "https://docs.google.com/uc?export=download&confirm=1"
7
+
8
+ session = requests.Session()
9
+
10
+ response = session.get(URL, params={"id": id}, stream=True)
11
+ token = get_confirm_token(response)
12
+
13
+ if token:
14
+ params = {"id": id, "confirm": token}
15
+ response = session.get(URL, params=params, stream=True)
16
+
17
+ save_response_content(response, destination)
18
+
19
+
20
+ def get_confirm_token(response):
21
+ for key, value in response.cookies.items():
22
+ if key.startswith("download_warning"):
23
+ return value
24
+
25
+ return None
26
+
27
+
28
+ def save_response_content(response, destination):
29
+ CHUNK_SIZE = 32768
30
+
31
+ with open(destination, "wb") as f:
32
+ for chunk in response.iter_content(CHUNK_SIZE):
33
+ if chunk: # filter out keep-alive new chunks
34
+ f.write(chunk)
35
+
36
+
37
+ def main():
38
+ if len(sys.argv) >= 3:
39
+ file_id = sys.argv[1]
40
+ destination = sys.argv[2]
41
+ else:
42
+ file_id = "TAKE_ID_FROM_SHAREABLE_LINK"
43
+ destination = "DESTINATION_FILE_ON_YOUR_DISK"
44
+ print(f"dowload {file_id} to {destination}")
45
+ download_file_from_google_drive(file_id, destination)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
utils/symbol_table.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2
+ #
3
+ # See ../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from dataclasses import dataclass
18
+ from dataclasses import field
19
+ from typing import Dict
20
+ from typing import Generic
21
+ from typing import List
22
+ from typing import Optional
23
+ from typing import TypeVar
24
+ from typing import Union
25
+
26
+ Symbol = TypeVar('Symbol')
27
+
28
+
29
+ # Disable __repr__ otherwise it could freeze e.g. Jupyter.
30
+ @dataclass(repr=False)
31
+ class SymbolTable(Generic[Symbol]):
32
+ '''SymbolTable that maps symbol IDs, found on the FSA arcs to
33
+ actual objects. These objects can be arbitrary Python objects
34
+ that can serve as keys in a dictionary (i.e. they need to be
35
+ hashable and immutable).
36
+
37
+ The SymbolTable can only be read to/written from disk if the
38
+ symbols are strings.
39
+ '''
40
+ _id2sym: Dict[int, Symbol] = field(default_factory=dict)
41
+ '''Map an integer to a symbol.
42
+ '''
43
+
44
+ _sym2id: Dict[Symbol, int] = field(default_factory=dict)
45
+ '''Map a symbol to an integer.
46
+ '''
47
+
48
+ _next_available_id: int = 1
49
+ '''A helper internal field that helps adding new symbols
50
+ to the table efficiently.
51
+ '''
52
+
53
+ eps: Symbol = '<eps>'
54
+ '''Null symbol, always mapped to index 0.
55
+ '''
56
+
57
+ def __post_init__(self):
58
+ for idx, sym in self._id2sym.items():
59
+ assert self._sym2id[sym] == idx
60
+ assert idx >= 0
61
+
62
+ for sym, idx in self._sym2id.items():
63
+ assert idx >= 0
64
+ assert self._id2sym[idx] == sym
65
+
66
+ if 0 not in self._id2sym:
67
+ self._id2sym[0] = self.eps
68
+ self._sym2id[self.eps] = 0
69
+ else:
70
+ assert self._id2sym[0] == self.eps
71
+ assert self._sym2id[self.eps] == 0
72
+
73
+ self._next_available_id = max(self._id2sym) + 1
74
+
75
+ @staticmethod
76
+ def from_str(s: str) -> 'SymbolTable':
77
+ '''Build a symbol table from a string.
78
+
79
+ The string consists of lines. Every line has two fields separated
80
+ by space(s), tab(s) or both. The first field is the symbol and the
81
+ second the integer id of the symbol.
82
+
83
+ Args:
84
+ s:
85
+ The input string with the format described above.
86
+ Returns:
87
+ An instance of :class:`SymbolTable`.
88
+ '''
89
+ id2sym: Dict[int, str] = dict()
90
+ sym2id: Dict[str, int] = dict()
91
+
92
+ for line in s.split('\n'):
93
+ fields = line.split()
94
+ if len(fields) == 0:
95
+ continue # skip empty lines
96
+ assert len(fields) == 2, \
97
+ f'Expect a line with 2 fields. Given: {len(fields)}'
98
+ sym, idx = fields[0], int(fields[1])
99
+ assert sym not in sym2id, f'Duplicated symbol {sym}'
100
+ assert idx not in id2sym, f'Duplicated id {idx}'
101
+ id2sym[idx] = sym
102
+ sym2id[sym] = idx
103
+
104
+ eps = id2sym.get(0, '<eps>')
105
+
106
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
107
+
108
+ @staticmethod
109
+ def from_file(filename: str) -> 'SymbolTable':
110
+ '''Build a symbol table from file.
111
+
112
+ Every line in the symbol table file has two fields separated by
113
+ space(s), tab(s) or both. The following is an example file:
114
+
115
+ .. code-block::
116
+
117
+ <eps> 0
118
+ a 1
119
+ b 2
120
+ c 3
121
+
122
+ Args:
123
+ filename:
124
+ Name of the symbol table file. Its format is documented above.
125
+
126
+ Returns:
127
+ An instance of :class:`SymbolTable`.
128
+
129
+ '''
130
+ with open(filename, 'r', encoding='utf-8') as f:
131
+ return SymbolTable.from_str(f.read().strip())
132
+
133
+ def to_str(self) -> str:
134
+ '''
135
+ Returns:
136
+ Return a string representation of this object. You can pass
137
+ it to the method ``from_str`` to recreate an identical object.
138
+ '''
139
+ s = ''
140
+ for idx, symbol in sorted(self._id2sym.items()):
141
+ s += f'{symbol} {idx}\n'
142
+ return s
143
+
144
+ def to_file(self, filename: str):
145
+ '''Serialize the SymbolTable to a file.
146
+
147
+ Every line in the symbol table file has two fields separated by
148
+ space(s), tab(s) or both. The following is an example file:
149
+
150
+ .. code-block::
151
+
152
+ <eps> 0
153
+ a 1
154
+ b 2
155
+ c 3
156
+
157
+ Args:
158
+ filename:
159
+ Name of the symbol table file. Its format is documented above.
160
+ '''
161
+ with open(filename, 'w') as f:
162
+ for idx, symbol in sorted(self._id2sym.items()):
163
+ print(symbol, idx, file=f)
164
+
165
+ def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
166
+ '''Add a new symbol to the SymbolTable.
167
+
168
+ Args:
169
+ symbol:
170
+ The symbol to be added.
171
+ index:
172
+ Optional int id to which the symbol should be assigned.
173
+ If it is not available, a ValueError will be raised.
174
+
175
+ Returns:
176
+ The int id to which the symbol has been assigned.
177
+ '''
178
+ # Already in the table? Return its ID.
179
+ if symbol in self._sym2id:
180
+ return self._sym2id[symbol]
181
+ # Specific ID not provided - use next available.
182
+ if index is None:
183
+ index = self._next_available_id
184
+ # Specific ID provided but not available.
185
+ if index in self._id2sym:
186
+ raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
187
+ f"already occupied by {self._id2sym[index]}")
188
+ self._sym2id[symbol] = index
189
+ self._id2sym[index] = symbol
190
+
191
+ # Update next available ID if needed
192
+ if self._next_available_id <= index:
193
+ self._next_available_id = index + 1
194
+
195
+ return index
196
+
197
+ def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
198
+ '''Get a symbol for an id or get an id for a symbol
199
+
200
+ Args:
201
+ k:
202
+ If it is an id, it tries to find the symbol corresponding
203
+ to the id; if it is a symbol, it tries to find the id
204
+ corresponding to the symbol.
205
+
206
+ Returns:
207
+ An id or a symbol depending on the given `k`.
208
+ '''
209
+ if isinstance(k, int):
210
+ return self._id2sym[k]
211
+ else:
212
+ return self._sym2id[k]
213
+
214
+ def merge(self, other: 'SymbolTable') -> 'SymbolTable':
215
+ '''Create a union of two SymbolTables.
216
+ Raises an AssertionError if the same IDs are occupied by
217
+ different symbols.
218
+
219
+ Args:
220
+ other:
221
+ A symbol table to merge with ``self``.
222
+
223
+ Returns:
224
+ A new symbol table.
225
+ '''
226
+ self._check_compatible(other)
227
+
228
+ id2sym = {**self._id2sym, **other._id2sym}
229
+ sym2id = {**self._sym2id, **other._sym2id}
230
+
231
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
232
+
233
+ def _check_compatible(self, other: 'SymbolTable') -> None:
234
+ # Epsilon compatibility
235
+ assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
236
+ f'{self.eps} != {other.eps}'
237
+ # IDs compatibility
238
+ common_ids = set(self._id2sym).intersection(other._id2sym)
239
+ for idx in common_ids:
240
+ assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
241
+ f'self[idx] = "{self[idx]}", ' \
242
+ f'other[idx] = "{other[idx]}"'
243
+ # Symbols compatibility
244
+ common_symbols = set(self._sym2id).intersection(other._sym2id)
245
+ for sym in common_symbols:
246
+ assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \
247
+ f'self[sym] = "{self[sym]}", ' \
248
+ f'other[sym] = "{other[sym]}"'
249
+
250
+ def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
251
+ return self.get(item)
252
+
253
+ def __contains__(self, item: Union[int, Symbol]) -> bool:
254
+ if isinstance(item, int):
255
+ return item in self._id2sym
256
+ else:
257
+ return item in self._sym2id
258
+
259
+ def __len__(self) -> int:
260
+ return len(self._id2sym)
261
+
262
+ def __eq__(self, other: 'SymbolTable') -> bool:
263
+ if len(self) != len(other):
264
+ return False
265
+
266
+ for s in self.symbols:
267
+ if self[s] != other[s]:
268
+ return False
269
+
270
+ return True
271
+
272
+ @property
273
+ def ids(self) -> List[int]:
274
+ '''Returns a list of integer IDs corresponding to the symbols.
275
+ '''
276
+ ans = list(self._id2sym.keys())
277
+ ans.sort()
278
+ return ans
279
+
280
+ @property
281
+ def symbols(self) -> List[Symbol]:
282
+ '''Returns a list of symbols (e.g., strings) corresponding to
283
+ the integer IDs.
284
+ '''
285
+ ans = list(self._sym2id.keys())
286
+ ans.sort()
287
+ return ans