andybi7676 commited on
Commit
a1c5b19
·
verified ·
1 Parent(s): 9592743

Upload model

Browse files
Files changed (4) hide show
  1. config.json +45 -2
  2. configuration_reborn.py +29 -0
  3. modeling_reborn.py +198 -1
  4. pytorch_model.bin +2 -2
config.json CHANGED
@@ -12,7 +12,7 @@
12
  "discriminator_dilation": 1,
13
  "discriminator_dim": 256,
14
  "discriminator_dropout": 0.0,
15
- "discriminator_input_dim": 512,
16
  "discriminator_kernel": 3,
17
  "discriminator_linear_emb": false,
18
  "discriminator_max_pool": false,
@@ -25,14 +25,57 @@
25
  "generator_dropout": 0.0,
26
  "generator_input_dim": 512,
27
  "generator_kernel": 4,
28
- "generator_output_dim": 40,
29
  "generator_stride": 1,
30
  "model_type": "reborn_uasr",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  "segmenter_dropout": 0.1,
32
  "segmenter_hidden_dim": 512,
33
  "segmenter_input_dim": 512,
34
  "segmenter_kernel_size": 7,
35
  "segmenter_type": "cnn",
 
36
  "torch_dtype": "float32",
37
  "transformers_version": "4.24.0"
38
  }
 
12
  "discriminator_dilation": 1,
13
  "discriminator_dim": 256,
14
  "discriminator_dropout": 0.0,
15
+ "discriminator_input_dim": 44,
16
  "discriminator_kernel": 3,
17
  "discriminator_linear_emb": false,
18
  "discriminator_max_pool": false,
 
25
  "generator_dropout": 0.0,
26
  "generator_input_dim": 512,
27
  "generator_kernel": 4,
28
+ "generator_output_dim": 44,
29
  "generator_stride": 1,
30
  "model_type": "reborn_uasr",
31
+ "phones": [
32
+ "AH",
33
+ "N",
34
+ "S",
35
+ "IH",
36
+ "T",
37
+ "L",
38
+ "R",
39
+ "D",
40
+ "K",
41
+ "IY",
42
+ "Z",
43
+ "M",
44
+ "ER",
45
+ "EH",
46
+ "P",
47
+ "AE",
48
+ "B",
49
+ "AA",
50
+ "EY",
51
+ "F",
52
+ "OW",
53
+ "NG",
54
+ "G",
55
+ "V",
56
+ "AO",
57
+ "AY",
58
+ "SH",
59
+ "UW",
60
+ "W",
61
+ "HH",
62
+ "JH",
63
+ "Y",
64
+ "CH",
65
+ "TH",
66
+ "AW",
67
+ "UH",
68
+ "OY",
69
+ "DH",
70
+ "ZH",
71
+ "<SIL>"
72
+ ],
73
  "segmenter_dropout": 0.1,
74
  "segmenter_hidden_dim": 512,
75
  "segmenter_input_dim": 512,
76
  "segmenter_kernel_size": 7,
77
  "segmenter_type": "cnn",
78
+ "special_token_nums": 4,
79
  "torch_dtype": "float32",
80
  "transformers_version": "4.24.0"
81
  }
configuration_reborn.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import PretrainedConfig
2
 
3
  class RebornUASRConfig(PretrainedConfig):
@@ -37,6 +38,10 @@ class RebornUASRConfig(PretrainedConfig):
37
  generator_dropout: float = 0.0,
38
  generator_bn_apply: bool = False,
39
  generator_bn_init_weight: float = 30.0,
 
 
 
 
40
  **kwargs
41
  ):
42
  super().__init__(**kwargs)
@@ -70,3 +75,27 @@ class RebornUASRConfig(PretrainedConfig):
70
  self.generator_bn_apply = generator_bn_apply
71
  self.generator_bn_init_weight = generator_bn_init_weight
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  from transformers import PretrainedConfig
3
 
4
  class RebornUASRConfig(PretrainedConfig):
 
38
  generator_dropout: float = 0.0,
39
  generator_bn_apply: bool = False,
40
  generator_bn_init_weight: float = 30.0,
41
+
42
+ phones: list = [],
43
+ dict_fpath: str = "",
44
+ special_token_nums: int = 4, # [<s>, <pad>, </s>, <unk>]
45
  **kwargs
46
  ):
47
  super().__init__(**kwargs)
 
75
  self.generator_bn_apply = generator_bn_apply
76
  self.generator_bn_init_weight = generator_bn_init_weight
77
 
78
+ self.special_token_nums = special_token_nums
79
+ if os.path.isfile(dict_fpath):
80
+ self.phones = self.read_phns_dict_from_fpath(dict_fpath)
81
+ else:
82
+ self.phones = phones
83
+ if len(self.phones) > 0:
84
+ self.generator_output_dim = len(self.phones) + self.special_token_nums
85
+ self.discriminator_input_dim = self.generator_output_dim
86
+
87
+ def read_phns_dict_from_fpath(self, fpath: str):
88
+ phns = []
89
+ with open(fpath, "r") as f:
90
+ for l in f:
91
+ phn = l.strip().split('\t')[0].split(' ')[0]
92
+ phns.append(phn)
93
+ return phns
94
+
95
+ def main():
96
+ config = RebornUASRConfig(dict_fpath="/home/andybi7676/Desktop/uasr-rl/data/ls_100h_new/text/prep/phones/dict.phn.txt")
97
+ print(config)
98
+ config.save_pretrained("reborn_uasr")
99
+
100
+ if __name__ == "__main__":
101
+ main()
modeling_reborn.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
  from .configuration_reborn import RebornUASRConfig
5
- from typing import Optional, Tuple, Union
6
 
7
  class RebornSegmenter(nn.Module):
8
  def __init__(self, config):
@@ -158,6 +158,176 @@ class RebornGenerator(nn.Module):
158
 
159
  return result
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  class RebornUASRModel(PreTrainedModel):
162
  config_class = RebornUASRConfig
163
 
@@ -166,6 +336,9 @@ class RebornUASRModel(PreTrainedModel):
166
  self.pca = nn.Linear(1024, 512)
167
  self.segmenter = RebornSegmenter(config)
168
  self.generator = RebornGenerator(config)
 
 
 
169
 
170
  def forward(
171
  self,
@@ -181,4 +354,28 @@ class RebornUASRModel(PreTrainedModel):
181
  'x_segmented': x_segmented,
182
  'x_generated': x_generated
183
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
 
 
 
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
  from .configuration_reborn import RebornUASRConfig
5
+ from typing import Optional, Tuple, Union, List
6
 
7
  class RebornSegmenter(nn.Module):
8
  def __init__(self, config):
 
158
 
159
  return result
160
 
161
+ def get_item(tensor):
162
+ # tpu-comment: making this a no-op for xla devices.
163
+ if torch.is_tensor(tensor) and tensor.device.type == "xla":
164
+ return tensor.detach()
165
+ if hasattr(tensor, "item"):
166
+ return tensor.item()
167
+ if hasattr(tensor, "__getitem__"):
168
+ return tensor[0]
169
+ return tensor
170
+
171
+ def post_process(sentence: str, symbol: str):
172
+ if symbol == "sentencepiece":
173
+ sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
174
+ elif symbol == "wordpiece":
175
+ sentence = sentence.replace(" ", "").replace("_", " ").strip()
176
+ elif symbol == "letter":
177
+ sentence = sentence.replace(" ", "").replace("|", " ").strip()
178
+ elif symbol == "silence":
179
+ import re
180
+ sentence = sentence.replace("<SIL>", "")
181
+ sentence = re.sub(' +', ' ', sentence).strip()
182
+ elif symbol == "_EOW":
183
+ sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
184
+ elif symbol in {"subword_nmt", "@@ ", "@@"}:
185
+ if symbol == "subword_nmt":
186
+ symbol = "@@ "
187
+ sentence = (sentence + " ").replace(symbol, "").rstrip()
188
+ elif symbol == "none":
189
+ pass
190
+ elif symbol is not None:
191
+ raise NotImplementedError(f"Unknown post_process option: {symbol}")
192
+ return sentence
193
+
194
+ class SimpleTokenizer(object):
195
+ def __init__(self,
196
+ phones: List[str],
197
+ bos="<s>",
198
+ pad="<pad>",
199
+ eos="</s>",
200
+ unk="<unk>",
201
+ extra_special_symbols=None,
202
+ ) -> None:
203
+ self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
204
+ self.symbols = []
205
+ self.count = []
206
+ self.indices = {}
207
+ self.bos_index = self.add_symbol(bos)
208
+ self.pad_index = self.add_symbol(pad)
209
+ self.eos_index = self.add_symbol(eos)
210
+ self.unk_index = self.add_symbol(unk)
211
+ if extra_special_symbols:
212
+ for s in extra_special_symbols:
213
+ self.add_symbol(s)
214
+ self.nspecial = len(self.symbols)
215
+ for phone in phones:
216
+ self.add_symbol(phone)
217
+ self.postprocess_code = "silence"
218
+
219
+ def add_symbol(self, word, n=1, overwrite=False):
220
+ """Adds a word to the dictionary"""
221
+ if word in self.indices and not overwrite:
222
+ idx = self.indices[word]
223
+ self.count[idx] = self.count[idx] + n
224
+ return idx
225
+ else:
226
+ idx = len(self.symbols)
227
+ self.indices[word] = idx
228
+ self.symbols.append(word)
229
+ self.count.append(n)
230
+ return idx
231
+
232
+ def __eq__(self, other):
233
+ return self.indices == other.indices
234
+
235
+ def __getitem__(self, idx):
236
+ if idx < len(self.symbols):
237
+ return self.symbols[idx]
238
+ return self.unk_word
239
+
240
+ def get_count(self, idx):
241
+ return self.count[idx]
242
+
243
+ def __len__(self):
244
+ """Returns the number of symbols in the dictionary"""
245
+ return len(self.symbols)
246
+
247
+ def __contains__(self, sym):
248
+ return sym in self.indices
249
+
250
+ def index(self, sym):
251
+ """Returns the index of the specified symbol"""
252
+ assert isinstance(sym, str)
253
+ if sym in self.indices:
254
+ return self.indices[sym]
255
+ return self.unk_index
256
+
257
+ def string(
258
+ self,
259
+ tensor,
260
+ bpe_symbol=None,
261
+ escape_unk=False,
262
+ extra_symbols_to_ignore=None,
263
+ unk_string=None,
264
+ include_eos=False,
265
+ separator=" ",
266
+ ):
267
+ """Helper for converting a tensor of token indices to a string.
268
+
269
+ Can optionally remove BPE symbols or escape <unk> words.
270
+ """
271
+ if torch.is_tensor(tensor) and tensor.dim() == 2:
272
+ return "\n".join(
273
+ self.string(
274
+ t,
275
+ bpe_symbol,
276
+ escape_unk,
277
+ extra_symbols_to_ignore,
278
+ include_eos=include_eos,
279
+ )
280
+ for t in tensor
281
+ )
282
+
283
+ extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
284
+ if not include_eos:
285
+ extra_symbols_to_ignore.add(self.eos())
286
+
287
+ def token_string(i):
288
+ if i == self.unk():
289
+ if unk_string is not None:
290
+ return unk_string
291
+ else:
292
+ return self.unk_string(escape_unk)
293
+ else:
294
+ return self[i]
295
+
296
+ if hasattr(self, "bos_index"):
297
+ extra_symbols_to_ignore.add(self.bos())
298
+
299
+ sent = separator.join(
300
+ token_string(i)
301
+ for i in tensor
302
+ if get_item(i) not in extra_symbols_to_ignore
303
+ )
304
+
305
+ return post_process(sent, bpe_symbol)
306
+
307
+ def unk_string(self, escape=False):
308
+ """Return unknown string, optionally escaped as: <<unk>>"""
309
+ if escape:
310
+ return "<{}>".format(self.unk_word)
311
+ else:
312
+ return self.unk_word
313
+
314
+ def bos(self):
315
+ """Helper to get index of beginning-of-sentence symbol"""
316
+ return self.bos_index
317
+
318
+ def pad(self):
319
+ """Helper to get index of pad symbol"""
320
+ return self.pad_index
321
+
322
+ def eos(self):
323
+ """Helper to get index of end-of-sentence symbol"""
324
+ return self.eos_index
325
+
326
+ def unk(self):
327
+ """Helper to get index of unk symbol"""
328
+ return self.unk_index
329
+
330
+
331
  class RebornUASRModel(PreTrainedModel):
332
  config_class = RebornUASRConfig
333
 
 
336
  self.pca = nn.Linear(1024, 512)
337
  self.segmenter = RebornSegmenter(config)
338
  self.generator = RebornGenerator(config)
339
+ self.tokenizer = None
340
+ if len(config.phones) > 0:
341
+ self.tokenizer = SimpleTokenizer(config.phones)
342
 
343
  def forward(
344
  self,
 
354
  'x_segmented': x_segmented,
355
  'x_generated': x_generated
356
  }
357
+
358
+ def generate(self, x, padding_mask, merge_consecutive=True, remove_silence=True):
359
+ res = self.forward(x, padding_mask)
360
+ y_raw_logits = res['x_generated']['dense_x']
361
+ y_raw_padding = res['x_generated']['dense_padding_mask']
362
+ y_raw_logits[y_raw_padding][..., self.tokenizer.pad_index] = float('inf')
363
+ preds = y_raw_logits.argmax(-1)
364
+ hyps = []
365
+ postprocess_code = "silence" if remove_silence else "none"
366
+ for pred in preds:
367
+ if merge_consecutive:
368
+ # merge consecutive predictions
369
+ pred = torch.unique_consecutive(pred)
370
+ hyp = self.tokenizer.string(pred, bpe_symbol=postprocess_code)
371
+ hyps.append(hyp)
372
+ return hyps
373
+
374
+ def main():
375
+ model_config = RebornUASRConfig.from_pretrained("/home/andybi7676/Desktop/uasr-rl/reborn_uasr/config.json")
376
+ print(model_config)
377
+ model = RebornUASRModel(model_config)
378
+ print(model.tokenizer.indices)
379
 
380
+ if __name__ == "__main__":
381
+ main()
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:319b78a86e5743fd8239760ea2628d8d64cd4bfe293423e2acaa204ee6954f4f
3
- size 12923917
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f6ef4288440fc0e67b955fa0ffabdf48f8762577f304fd72ffd03131c5c840d
3
+ size 12956685